.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/inspection/plot_permutation_importance_multicollinear.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code or to run this example in your browser via JupyterLite or Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py: ================================================================= Permutation Importance with Multicollinear or Correlated Features ================================================================= In this example, we compute the :func:`~sklearn.inspection.permutation_importance` of the features to a trained :class:`~sklearn.ensemble.RandomForestClassifier` using the :ref:`breast_cancer_dataset`. The model can easily get about 97% accuracy on a test dataset. Because this dataset contains multicollinear features, the permutation importance shows that none of the features are important, in contradiction with the high test accuracy. We demo a possible approach to handling multicollinearity, which consists of hierarchical clustering on the features' Spearman rank-order correlations, picking a threshold, and keeping a single feature from each cluster. .. note:: See also :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` .. GENERATED FROM PYTHON SOURCE LINES 25-29 Random Forest Feature Importance on Breast Cancer Data ------------------------------------------------------ First, we define a function to ease the plotting: .. GENERATED FROM PYTHON SOURCE LINES 29-45 .. code-block:: Python from sklearn.inspection import permutation_importance def plot_permutation_importance(clf, X, y, ax): result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2) perm_sorted_idx = result.importances_mean.argsort() ax.boxplot( result.importances[perm_sorted_idx].T, vert=False, labels=X.columns[perm_sorted_idx], ) ax.axvline(x=0, color="k", linestyle="--") return ax .. GENERATED FROM PYTHON SOURCE LINES 46-48 We then train a :class:`~sklearn.ensemble.RandomForestClassifier` on the :ref:`breast_cancer_dataset` and evaluate its accuracy on a test set: .. GENERATED FROM PYTHON SOURCE LINES 48-59 .. code-block:: Python from sklearn.datasets import load_breast_cancer from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split X, y = load_breast_cancer(return_X_y=True, as_frame=True) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) clf = RandomForestClassifier(n_estimators=100, random_state=42) clf.fit(X_train, y_train) print(f"Baseline accuracy on test data: {clf.score(X_test, y_test):.2}") .. rst-class:: sphx-glr-script-out .. code-block:: none Baseline accuracy on test data: 0.97 .. GENERATED FROM PYTHON SOURCE LINES 60-63 Next, we plot the tree based feature importance and the permutation importance. The permutation importance is calculated on the training set to show how much the model relies on each feature during training. .. GENERATED FROM PYTHON SOURCE LINES 63-81 .. code-block:: Python import matplotlib.pyplot as plt import numpy as np import pandas as pd mdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns) tree_importance_sorted_idx = np.argsort(clf.feature_importances_) tree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5 fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) mdi_importances.sort_values().plot.barh(ax=ax1) ax1.set_xlabel("Gini importance") plot_permutation_importance(clf, X_train, y_train, ax2) ax2.set_xlabel("Decrease in accuracy score") fig.suptitle( "Impurity-based vs. permutation importances on multicollinear features (train set)" ) _ = fig.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_001.png :alt: Impurity-based vs. permutation importances on multicollinear features (train set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 82-98 The plot on the left shows the Gini importance of the model. As the scikit-learn implementation of :class:`~sklearn.ensemble.RandomForestClassifier` uses a random subsets of :math:`\sqrt{n_\text{features}}` features at each split, it is able to dilute the dominance of any single correlated feature. As a result, the individual feature importance may be distributed more evenly among the correlated features. Since the features have large cardinality and the classifier is non-overfitted, we can relatively trust those values. The permutation importance on the right plot shows that permuting a feature drops the accuracy by at most `0.012`, which would suggest that none of the features are important. This is in contradiction with the high test accuracy computed as baseline: some feature must be important. Similarly, the change in accuracy score computed on the test set appears to be driven by chance: .. GENERATED FROM PYTHON SOURCE LINES 98-105 .. code-block:: Python fig, ax = plt.subplots(figsize=(7, 6)) plot_permutation_importance(clf, X_test, y_test, ax) ax.set_title("Permutation Importances on multicollinear features\n(test set)") ax.set_xlabel("Decrease in accuracy score") _ = ax.figure.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_002.png :alt: Permutation Importances on multicollinear features (test set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 106-120 Nevertheless, one can still compute a meaningful permutation importance in the presence of correlated features, as demonstrated in the following section. Handling Multicollinear Features -------------------------------- When features are collinear, permuting one feature has little effect on the models performance because it can get the same information from a correlated feature. Note that this is not the case for all predictive models and depends on their underlying implementation. One way to handle multicollinear features is by performing hierarchical clustering on the Spearman rank-order correlations, picking a threshold, and keeping a single feature from each cluster. First, we plot a heatmap of the correlated features: .. GENERATED FROM PYTHON SOURCE LINES 120-147 .. code-block:: Python from scipy.cluster import hierarchy from scipy.spatial.distance import squareform from scipy.stats import spearmanr fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) corr = spearmanr(X).correlation # Ensure the correlation matrix is symmetric corr = (corr + corr.T) / 2 np.fill_diagonal(corr, 1) # We convert the correlation matrix to a distance matrix before performing # hierarchical clustering using Ward's linkage. distance_matrix = 1 - np.abs(corr) dist_linkage = hierarchy.ward(squareform(distance_matrix)) dendro = hierarchy.dendrogram( dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90 ) dendro_idx = np.arange(0, len(dendro["ivl"])) ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]]) ax2.set_xticks(dendro_idx) ax2.set_yticks(dendro_idx) ax2.set_xticklabels(dendro["ivl"], rotation="vertical") ax2.set_yticklabels(dendro["ivl"]) _ = fig.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_003.png :alt: plot permutation importance multicollinear :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_003.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 148-153 Next, we manually pick a threshold by visual inspection of the dendrogram to group our features into clusters and choose a feature from each cluster to keep, select those features from our dataset, and train a new random forest. The test accuracy of the new random forest did not change much compared to the random forest trained on the complete dataset. .. GENERATED FROM PYTHON SOURCE LINES 153-172 .. code-block:: Python from collections import defaultdict cluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion="distance") cluster_id_to_feature_ids = defaultdict(list) for idx, cluster_id in enumerate(cluster_ids): cluster_id_to_feature_ids[cluster_id].append(idx) selected_features = [v[0] for v in cluster_id_to_feature_ids.values()] selected_features_names = X.columns[selected_features] X_train_sel = X_train[selected_features_names] X_test_sel = X_test[selected_features_names] clf_sel = RandomForestClassifier(n_estimators=100, random_state=42) clf_sel.fit(X_train_sel, y_train) print( "Baseline accuracy on test data with features removed:" f" {clf_sel.score(X_test_sel, y_test):.2}" ) .. rst-class:: sphx-glr-script-out .. code-block:: none Baseline accuracy on test data with features removed: 0.97 .. GENERATED FROM PYTHON SOURCE LINES 173-175 We can finally explore the permutation importance of the selected subset of features: .. GENERATED FROM PYTHON SOURCE LINES 175-182 .. code-block:: Python fig, ax = plt.subplots(figsize=(7, 6)) plot_permutation_importance(clf_sel, X_test_sel, y_test, ax) ax.set_title("Permutation Importances on selected subset of features\n(test set)") ax.set_xlabel("Decrease in accuracy score") ax.figure.tight_layout() plt.show() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_004.png :alt: Permutation Importances on selected subset of features (test set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_multicollinear_004.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 3.843 seconds) .. _sphx_glr_download_auto_examples_inspection_plot_permutation_importance_multicollinear.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/inspection/plot_permutation_importance_multicollinear.ipynb :alt: Launch binder :width: 150 px .. container:: lite-badge .. image:: images/jupyterlite_badge_logo.svg :target: ../../lite/lab/?path=auto_examples/inspection/plot_permutation_importance_multicollinear.ipynb :alt: Launch JupyterLite :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_permutation_importance_multicollinear.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_permutation_importance_multicollinear.py ` .. include:: plot_permutation_importance_multicollinear.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_