Wrapper for GMM #137
…ts because of library restrictions.
…ded methods, but i checked method sample.
…dition. It requires manual entry of params because we can't guarantee that sklearn's from_samples and gmr's are going to be the same
Copilot reviewed 26 out of 26 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
bamt/utils/gmm_wrapper.py:160
- The assignment of given_values using indexing [0] may unintentionally drop dimensions when conditioning on multiple variables. Consider using np.array(given_values) without the [0] indexing so that the full array is preserved.
given_values = np.array(given_values)[0]
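A small NumPy illustration of the shape change the suppressed comment describes. The input values are hypothetical, and which layout the wrapper's callers actually pass is not visible in this thread, so this only shows what the `[0]` index does to each layout:

```python
import numpy as np

# Hypothetical input: one row holding two conditioning variables.
given_values = [[1.5, -0.3]]

first_only = np.array(given_values)[0]  # shape (2,): the row axis is removed
full = np.array(given_values)           # shape (1, 2): the row axis is kept

# If a caller passes a flat list instead, [0] keeps only the first variable.
flat_values = [1.5, -0.3]
scalar = np.array(flat_values)[0]       # zero-dimensional value, second entry lost

print(first_only.shape, full.shape, np.ndim(scalar))
```

If the method is only ever called with a single wrapped row, the `[0]` is harmless; the risk is limited to flat inputs, where it silently discards every variable after the first.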
Updated based on the PR review; added comments, docstrings, and the rest.
The only thing left is to remove gmr from requirements and pyproject.toml.
NB: I didn't run these changes on my own!
Low-Level Observations
- The low-level implementation looks solid overall.
- A few commented-out lines should be removed to clean things up.
- Exception handling is clear and well-structured.
- Computation is efficient and numerically stable, e.g., good use of np.linalg.solve instead of inverse matrix operations.
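To make the `np.linalg.solve` point concrete, here is a minimal sketch of conditioning a single Gaussian on some of its coordinates without forming an explicit inverse. The function name `condition_gaussian` and its arguments are illustrative assumptions, not the code in bamt/utils/gmm_wrapper.py:

```python
import numpy as np

def condition_gaussian(mean, cov, given_idx, given_values):
    """Return mean and covariance of the free coordinates given the others."""
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    given_idx = np.asarray(given_idx)
    free_idx = np.setdiff1d(np.arange(mean.size), given_idx)

    cov_ff = cov[np.ix_(free_idx, free_idx)]
    cov_fg = cov[np.ix_(free_idx, given_idx)]
    cov_gg = cov[np.ix_(given_idx, given_idx)]
    diff = np.asarray(given_values, dtype=float) - mean[given_idx]

    # Solve linear systems with cov_gg instead of computing inv(cov_gg).
    cond_mean = mean[free_idx] + cov_fg @ np.linalg.solve(cov_gg, diff)
    cond_cov = cov_ff - cov_fg @ np.linalg.solve(cov_gg, cov_fg.T)
    return cond_mean, cond_cov

mean = [0.0, 1.0, 2.0]
cov = [[2.0, 0.5, 0.3],
       [0.5, 1.0, 0.2],
       [0.3, 0.2, 1.5]]
cond_mean, cond_cov = condition_gaussian(mean, cov, given_idx=[2], given_values=[3.0])
```

Solving the system directly avoids the extra rounding error and cost of building the inverse, which matters most when the conditioning block is close to singular.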
High-Level Issues
Redundant Validations in from_samples
There's a manual check to ensure X.shape[0] >= 2, like:
if X.shape[0] < 2:
    raise ValueError("Need at least 2 samples...")
But this check already exists within sklearn.GaussianMixture through:
Snippet from fit_predict
X = validate_data(self, X, dtype=[np.float64, np.float32], ensure_min_samples=2)
Which internally uses:
if ensure_min_samples > 0:
    n_samples = _num_samples(array)
    if n_samples < ensure_min_samples:
        raise ValueError(
            "Found array with %d sample(s) (shape=%s) while a"
            " minimum of %d is required%s."
            % (n_samples, array.shape, ensure_min_samples, context)
        )
Recommendation: These checks are redundant and can be removed or delegated to sklearn.
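A quick way to confirm the delegation is to fit on a single sample and capture the error that scikit-learn raises on its own. This is a sketch; the helper name `single_sample_error` is made up for illustration:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def single_sample_error():
    """Fit on one sample and return the message of the resulting ValueError."""
    X = np.array([[0.0, 1.0]])  # one sample, two features
    try:
        GaussianMixture(n_components=1).fit(X)
    except ValueError as exc:
        return str(exc)
    return None  # no error was raised

message = single_sample_error()
print(message)
```

A wrapper that needs a friendlier message can catch this ValueError and re-raise or warn, rather than repeating the sample-count check itself.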
manual_init Duplication
Instead of rewriting the initialization logic, it's more maintainable to reuse sklearn's internal _initialize method. For reference, here's the original:
def _initialize(self, X, resp):
    """Initialization of the Gaussian mixture parameters."""
    n_samples, _ = X.shape
    weights, means, covariances = None, None, None
    if resp is not None:
        weights, means, covariances = _estimate_gaussian_parameters(
            X, resp, self.reg_covar, self.covariance_type
        )
        if self.weights_init is None:
            weights /= n_samples
    self.weights_ = weights if self.weights_init is None else self.weights_init
    self.means_ = means if self.means_init is None else self.means_init
    if self.precisions_init is None:
        self.covariances_ = covariances
        self.precisions_cholesky_ = _compute_precision_cholesky(
            covariances, self.covariance_type
        )
    else:
        self.precisions_cholesky_ = _compute_precision_cholesky_from_precisions(
            self.precisions_init, self.covariance_type
        )
Recommendation: Subclass GaussianMixture and override only what you need, e.g.:
from sklearn.mixture import GaussianMixture
class CustomGMM(GaussianMixture):
    def _initialize(self, X, resp):
        super()._initialize(X, resp)
        # Add any custom behavior here
This avoids code duplication and benefits from future improvements in scikit-learn.
Testing Strategy
- Using gmr as a reference is acceptable for transitional verification.
- However, the long-term goal is to remove gmr from bamt, which means the testing strategy should evolve.
Suggestions:
- Shift from comparison-based testing to behavioral testing:
- Check that likelihood improves after training.
- Verify cluster consistency on toy datasets.
- Add integration tests on synthetic datasets with known structure.
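The suggestions above could be sketched roughly as follows, using plain scikit-learn on a synthetic two-cluster dataset. The helper names, cluster centers, and tolerances are assumptions chosen for illustration; a real test would target the wrapper class instead of `GaussianMixture` directly:

```python
import warnings
import numpy as np
from sklearn.mixture import GaussianMixture

def make_two_blobs(n_per_cluster=200, seed=0):
    """Synthetic data with known structure: two well-separated 2-D clusters."""
    rng = np.random.default_rng(seed)
    centers = np.array([[-5.0, -5.0], [5.0, 5.0]])
    X = np.vstack([rng.normal(c, 1.0, size=(n_per_cluster, 2)) for c in centers])
    return X, centers

def fit_gmm(X, max_iter):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # max_iter=1 emits a ConvergenceWarning
        return GaussianMixture(n_components=2, max_iter=max_iter, random_state=0).fit(X)

def test_likelihood_does_not_decrease():
    X, _ = make_two_blobs()
    early, trained = fit_gmm(X, 1), fit_gmm(X, 100)
    # Average log-likelihood after full training should not be worse.
    assert trained.score(X) >= early.score(X) - 1e-6

def test_recovers_known_centers():
    X, centers = make_two_blobs()
    means = fit_gmm(X, 100).means_
    means = means[np.argsort(means[:, 0])]  # component order is arbitrary
    assert np.allclose(means, centers, atol=0.5)
```

Neither test depends on gmr, so both would keep working after the dependency is dropped.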
Action Items
- Remove commented-out lines and any redundant stability checks.
- Refactor by subclassing sklearn.GaussianMixture and override only required methods.
- Redesign tests to focus on correctness and invariants instead of output comparison with gmr.
@jrzkaminski in that case it should be removed carefully, for example by moving it into the testing group.
Removed redundant validation from the from_samples() method and cleaned up outdated comments. All tests pass. This will allow us to fully drop the dependency on gmr and transition to robust behavioral testing (log-likelihood checks, cluster validation).
Drop-in replacement for GMM from gmr
Implements:
.from_samples(...)
.sample(...)
.predict(...)
.to_responsibilities(...)
.condition(...)
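As a rough illustration of how two of these gmr-style names can sit on top of scikit-learn, here is a hypothetical adapter. The class name `GMRStyleGMM` and the simplified signatures are assumptions; this is not the code in bamt/utils/gmm_wrapper.py and it leaves out `condition` and the other methods:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

class GMRStyleGMM(GaussianMixture):
    """Hypothetical adapter exposing gmr-style names over sklearn's estimator."""

    def from_samples(self, X):
        # Delegate fitting (and input validation) to sklearn, then return self.
        self.fit(np.asarray(X, dtype=float))
        return self

    def to_responsibilities(self, X):
        # Posterior probability of each component for each sample.
        return self.predict_proba(np.asarray(X, dtype=float))

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-4.0, 1.0, size=(50, 2)),
               rng.normal(4.0, 1.0, size=(50, 2))])
gmm = GMRStyleGMM(n_components=2, random_state=0).from_samples(X)
resp = gmm.to_responsibilities(X)
```

Note that scikit-learn's own `predict` returns component labels and its `sample` returns a `(samples, labels)` tuple, so a drop-in wrapper has to decide how to reconcile those names with whatever its callers expect.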
Carefully replicates gmr behavior, including:
Manual initialization when n_samples < 2
Safe handling of degenerate weights (e.g., nan or negative)
Adds warnings instead of exceptions for better debug experience
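One possible shape for the "warn instead of raise" handling of degenerate weights is sketched below. The helper `sanitize_weights` and its fallback policy (zero out bad entries, renormalize, fall back to uniform) are assumptions for illustration, not a description of what gmr or this PR actually does:

```python
import warnings
import numpy as np

def sanitize_weights(weights):
    """Return non-negative weights that sum to 1, warning about any repairs."""
    w = np.asarray(weights, dtype=float).copy()
    bad = ~np.isfinite(w) | (w < 0)
    if bad.any():
        warnings.warn(
            "Replacing non-finite or negative mixture weights with 0.",
            RuntimeWarning,
        )
        w[bad] = 0.0
    total = w.sum()
    if total <= 0:
        warnings.warn(
            "All mixture weights are degenerate; falling back to uniform weights.",
            RuntimeWarning,
        )
        return np.full(w.shape, 1.0 / w.size)
    return w / total
```

For example, `sanitize_weights([0.5, float("nan"), -0.2, 0.5])` zeroes the two invalid entries and renormalizes the rest, emitting a RuntimeWarning rather than stopping the caller.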
This pull request includes several changes across multiple files to improve code readability and consistency. The most important changes involve adding missing imports, reformatting code for better readability, and updating string formatting.
Code Readability Improvements:
bamt/networks/base.py: Reformatted import statements and long lines to improve readability.
bamt/nodes/conditional_mixture_gaussian_node.py: Reformatted long lines and updated the choose and predict methods for better readability.
Consistency Updates:
bamt/external/pyitlib/DiscreteRandomVariableUtils.py: Updated string formatting to use double quotes and reformatted long lines for consistency.
bamt/nodes/conditional_logit_node.py: Reformatted long lines and updated string formatting for consistency.
Minor Additions:
Bug Fixes:
bamt/external/pyitlib/DiscreteRandomVariableUtils.py: Fixed assertion statements and improved error handling in the _estimate_probabilities function.
Import Updates:
bamt/nodes/conditional_mixture_gaussian_node.py: Replaced the import of GMM from gmr with one from bamt.utils.gmm_wrapper.