.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/mixture/plot_gmm_selection.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_mixture_plot_gmm_selection.py:


================================
Gaussian Mixture Model Selection
================================

This example shows that model selection can be performed with Gaussian Mixture
Models (GMM) using :ref:`information-theory criteria <aic_bic>`. Model selection
concerns both the covariance type and the number of components in the model.

In this case, both the Akaike Information Criterion (AIC) and the Bayesian
Information Criterion (BIC) provide the right result, but we only demonstrate
the latter as BIC is better suited to identify the true model among a set of
candidates. Unlike Bayesian procedures, such inferences are prior-free.

.. GENERATED FROM PYTHON SOURCE LINES 18-25

Data generation
---------------

We generate two components (each one containing `n_samples` points) by randomly
sampling the standard normal distribution as returned by `numpy.random.randn`.
One component is kept spherical yet shifted and re-scaled. The other one is
deformed to have a more general covariance matrix.

.. GENERATED FROM PYTHON SOURCE LINES 25-36

.. code-block:: Python


    import numpy as np

    n_samples = 500
    np.random.seed(0)
    C = np.array([[0.0, -0.1], [1.7, 0.4]])
    component_1 = np.dot(np.random.randn(n_samples, 2), C)  # general
    component_2 = 0.7 * np.random.randn(n_samples, 2) + np.array([-4, 1])  # spherical

    X = np.concatenate([component_1, component_2])

.. GENERATED FROM PYTHON SOURCE LINES 37-38

We can visualize the different components:

..
GENERATED FROM PYTHON SOURCE LINES 38-47

.. code-block:: Python


    import matplotlib.pyplot as plt

    plt.scatter(component_1[:, 0], component_1[:, 1], s=0.8)
    plt.scatter(component_2[:, 0], component_2[:, 1], s=0.8)
    plt.title("Gaussian Mixture components")
    plt.axis("equal")
    plt.show()



.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :alt: Gaussian Mixture components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 48-67

Model training and selection
----------------------------

We vary the number of components from 1 to 6 and the type of covariance
parameters to use:

- `"full"`: each component has its own general covariance matrix.
- `"tied"`: all components share the same general covariance matrix.
- `"diag"`: each component has its own diagonal covariance matrix.
- `"spherical"`: each component has its own single variance.

We score the different models and keep the best model (the lowest BIC). This is
done by using :class:`~sklearn.model_selection.GridSearchCV` and a user-defined
score function which returns the negative BIC score, as
:class:`~sklearn.model_selection.GridSearchCV` is designed to **maximize** a
score (maximizing the negative BIC is equivalent to minimizing the BIC).

The best set of parameters and the best estimator are stored in `best_params_`
and `best_estimator_`, respectively.

.. GENERATED FROM PYTHON SOURCE LINES 67-87

..
code-block:: Python


    from sklearn.mixture import GaussianMixture
    from sklearn.model_selection import GridSearchCV


    def gmm_bic_score(estimator, X):
        """Callable to pass to GridSearchCV that will use the BIC score."""
        # Make it negative since GridSearchCV expects a score to maximize
        return -estimator.bic(X)


    param_grid = {
        "n_components": range(1, 7),
        "covariance_type": ["spherical", "tied", "diag", "full"],
    }
    grid_search = GridSearchCV(
        GaussianMixture(), param_grid=param_grid, scoring=gmm_bic_score
    )
    grid_search.fit(X)



.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <pre>GridSearchCV(estimator=GaussianMixture(), param_grid={'covariance_type': ['spherical', 'tied', 'diag', 'full'], 'n_components': range(1, 7)}, scoring=&lt;function gmm_bic_score at 0x7fdebf689820&gt;)</pre>
    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 88-94

Plot the BIC scores
-------------------

To ease the plotting we can create a `pandas.DataFrame` from the results of the
cross-validation done by the grid search. We negate the score again to recover
the BIC itself, so that lower values are better and the effect of minimizing it
is visible.

.. GENERATED FROM PYTHON SOURCE LINES 94-110

.. code-block:: Python


    import pandas as pd

    df = pd.DataFrame(grid_search.cv_results_)[
        ["param_n_components", "param_covariance_type", "mean_test_score"]
    ]
    df["mean_test_score"] = -df["mean_test_score"]
    df = df.rename(
        columns={
            "param_n_components": "Number of components",
            "param_covariance_type": "Type of covariance",
            "mean_test_score": "BIC score",
        }
    )
    df.sort_values(by="BIC score").head()

..
raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <div>
    <style scoped>
        .dataframe tbody tr th:only-of-type {
            vertical-align: middle;
        }

        .dataframe tbody tr th {
            vertical-align: top;
        }

        .dataframe thead th {
            text-align: right;
        }
    </style>
    <table border="1" class="dataframe">
      <thead>
        <tr style="text-align: right;">
          <th></th>
          <th>Number of components</th>
          <th>Type of covariance</th>
          <th>BIC score</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th>19</th>
          <td>2</td>
          <td>full</td>
          <td>1046.829429</td>
        </tr>
        <tr>
          <th>20</th>
          <td>3</td>
          <td>full</td>
          <td>1084.038689</td>
        </tr>
        <tr>
          <th>21</th>
          <td>4</td>
          <td>full</td>
          <td>1114.517272</td>
        </tr>
        <tr>
          <th>22</th>
          <td>5</td>
          <td>full</td>
          <td>1148.512281</td>
        </tr>
        <tr>
          <th>23</th>
          <td>6</td>
          <td>full</td>
          <td>1179.977890</td>
        </tr>
      </tbody>
    </table>
    </div>
    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 111-122

.. code-block:: Python


    import seaborn as sns

    sns.catplot(
        data=df,
        kind="bar",
        x="Number of components",
        y="BIC score",
        hue="Type of covariance",
    )
    plt.show()



.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :alt: plot gmm selection
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 123-139

In the present case, the model with 2 components and full covariance (which
corresponds to the true generative model) has the lowest BIC score and is
therefore selected by the grid search.

Plot the best model
-------------------

We plot an ellipse to show each Gaussian component of the selected model. For
this purpose, one needs to find the eigenvalues of the covariance matrices as
returned by the `covariances_` attribute. The shape of these matrices depends
on the `covariance_type`:

- `"full"`: (`n_components`, `n_features`, `n_features`)
- `"tied"`: (`n_features`, `n_features`)
- `"diag"`: (`n_components`, `n_features`)
- `"spherical"`: (`n_components`,)

..
GENERATED FROM PYTHON SOURCE LINES 139-174

.. code-block:: Python


    from matplotlib.patches import Ellipse
    from scipy import linalg

    color_iter = sns.color_palette("tab10", 2)[::-1]
    Y_ = grid_search.predict(X)

    fig, ax = plt.subplots()

    for i, (mean, cov, color) in enumerate(
        zip(
            grid_search.best_estimator_.means_,
            grid_search.best_estimator_.covariances_,
            color_iter,
        )
    ):
        # eigenvalues (v) and eigenvectors (w) of the component covariance
        v, w = linalg.eigh(cov)
        if not np.any(Y_ == i):
            continue
        plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], 0.8, color=color)

        angle = np.arctan2(w[0][1], w[0][0])
        angle = 180.0 * angle / np.pi  # convert to degrees
        # scale the eigenvalues into the ellipse width and height
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        ellipse = Ellipse(mean, v[0], v[1], angle=180.0 + angle, color=color)
        ellipse.set_clip_box(fig.bbox)
        ellipse.set_alpha(0.5)
        ax.add_artist(ellipse)

    plt.title(
        f"Selected GMM: {grid_search.best_params_['covariance_type']} model, "
        f"{grid_search.best_params_['n_components']} components"
    )
    plt.axis("equal")
    plt.show()



.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :alt: Selected GMM: full model, 2 components
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_selection_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.314 seconds)


.. _sphx_glr_download_auto_examples_mixture_plot_gmm_selection.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.4.X?urlpath=lab/tree/notebooks/auto_examples/mixture/plot_gmm_selection.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/mixture/plot_gmm_selection.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gmm_selection.ipynb <plot_gmm_selection.ipynb>`

    ..
container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gmm_selection.py <plot_gmm_selection.py>`


.. include:: plot_gmm_selection.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
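
As a complement to the covariance types compared in this example, the sketch
below illustrates why the BIC favors simpler models unless the extra parameters
buy enough likelihood. It assumes the standard definition
``BIC = k * ln(n) - 2 * ln(L)``, where ``k`` is the number of free parameters,
``n`` the number of samples and ``L`` the maximized likelihood. The helper
names `n_free_parameters` and `bic` are hypothetical and not part of
scikit-learn, and the parameter counts follow the usual convention for
Gaussian mixtures (covariances, means, and mixing weights that sum to one).

```python
import math


def n_free_parameters(n_components, n_features, covariance_type):
    """Count the free parameters of a Gaussian mixture (illustrative helper)."""
    if covariance_type == "full":
        # one symmetric matrix per component
        cov_params = n_components * n_features * (n_features + 1) // 2
    elif covariance_type == "tied":
        # a single symmetric matrix shared by all components
        cov_params = n_features * (n_features + 1) // 2
    elif covariance_type == "diag":
        # one variance per feature and per component
        cov_params = n_components * n_features
    elif covariance_type == "spherical":
        # one variance per component
        cov_params = n_components
    else:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")
    mean_params = n_components * n_features
    weight_params = n_components - 1  # the weights sum to one
    return cov_params + mean_params + weight_params


def bic(total_log_likelihood, n_params, n_samples):
    """Return k * ln(n) - 2 * ln(L); lower values are better."""
    return n_params * math.log(n_samples) - 2.0 * total_log_likelihood


# Parameter counts for 2 components in 2 dimensions, as in this example
for cov_type in ["spherical", "tied", "diag", "full"]:
    k = n_free_parameters(n_components=2, n_features=2, covariance_type=cov_type)
    print(f"{cov_type:>9}: {k} free parameters")
```

Under these assumptions, each additional free parameter raises the BIC by
``ln(n)``, so a richer covariance structure is only selected when it improves
the log-likelihood by more than that penalty.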