.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/inspection/plot_permutation_importance.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. or to run this example in your browser via JupyterLite or Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_inspection_plot_permutation_importance.py: ================================================================ Permutation Importance vs Random Forest Feature Importance (MDI) ================================================================ In this example, we will compare the impurity-based feature importance of :class:`~sklearn.ensemble.RandomForestClassifier` with the permutation importance on the titanic dataset using :func:`~sklearn.inspection.permutation_importance`. We will show that the impurity-based feature importance can inflate the importance of numerical features. Furthermore, the impurity-based feature importance of random forests suffers from being computed on statistics derived from the training dataset: the importances can be high even for features that are not predictive of the target variable, as long as the model has the capacity to use them to overfit. This example shows how to use Permutation Importances as an alternative that can mitigate those limitations. .. rubric:: References * :doi:`L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. <10.1023/A:1010933404324>` .. GENERATED FROM PYTHON SOURCE LINES 27-31 .. code-block:: Python # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause .. GENERATED FROM PYTHON SOURCE LINES 32-44 Data Loading and Feature Engineering ------------------------------------ Let's use pandas to load a copy of the titanic dataset. The following shows how to apply separate preprocessing on numerical and categorical features. We further include two random variables that are not correlated in any way with the target variable (``survived``): - ``random_num`` is a high cardinality numerical variable (as many unique values as records). - ``random_cat`` is a low cardinality categorical variable (3 possible values). .. GENERATED FROM PYTHON SOURCE LINES 44-60 .. code-block:: Python import numpy as np from sklearn.datasets import fetch_openml from sklearn.model_selection import train_test_split X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) rng = np.random.RandomState(seed=42) X["random_cat"] = rng.randint(3, size=X.shape[0]) X["random_num"] = rng.randn(X.shape[0]) categorical_columns = ["pclass", "sex", "embarked", "random_cat"] numerical_columns = ["age", "sibsp", "parch", "fare", "random_num"] X = X[categorical_columns + numerical_columns] X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42) .. GENERATED FROM PYTHON SOURCE LINES 61-68 We define a predictive model based on a random forest. Therefore, we will make the following preprocessing steps: - use :class:`~sklearn.preprocessing.OrdinalEncoder` to encode the categorical features; - use :class:`~sklearn.impute.SimpleImputer` to fill missing values for numerical features using a mean strategy. .. GENERATED FROM PYTHON SOURCE LINES 68-95 .. code-block:: Python from sklearn.compose import ColumnTransformer from sklearn.ensemble import RandomForestClassifier from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline from sklearn.preprocessing import OrdinalEncoder categorical_encoder = OrdinalEncoder( handle_unknown="use_encoded_value", unknown_value=-1, encoded_missing_value=-1 ) numerical_pipe = SimpleImputer(strategy="mean") preprocessing = ColumnTransformer( [ ("cat", categorical_encoder, categorical_columns), ("num", numerical_pipe, numerical_columns), ], verbose_feature_names_out=False, ) rf = Pipeline( [ ("preprocess", preprocessing), ("classifier", RandomForestClassifier(random_state=42)), ] ) rf.fit(X_train, y_train) .. raw:: html
Pipeline(steps=[('preprocess',
                     ColumnTransformer(transformers=[('cat',
                                                      OrdinalEncoder(encoded_missing_value=-1,
                                                                     handle_unknown='use_encoded_value',
                                                                     unknown_value=-1),
                                                      ['pclass', 'sex', 'embarked',
                                                       'random_cat']),
                                                     ('num', SimpleImputer(),
                                                      ['age', 'sibsp', 'parch',
                                                       'fare', 'random_num'])],
                                       verbose_feature_names_out=False)),
                    ('classifier', RandomForestClassifier(random_state=42))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 96-116 Accuracy of the Model --------------------- Prior to inspecting the feature importances, it is important to check that the model predictive performance is high enough. Indeed there would be little interest of inspecting the important features of a non-predictive model. Here one can observe that the train accuracy is very high (the forest model has enough capacity to completely memorize the training set) but it can still generalize well enough to the test set thanks to the built-in bagging of random forests. It might be possible to trade some accuracy on the training set for a slightly better accuracy on the test set by limiting the capacity of the trees (for instance by setting ``min_samples_leaf=5`` or ``min_samples_leaf=10``) so as to limit overfitting while not introducing too much underfitting. However let's keep our high capacity random forest model for now so as to illustrate some pitfalls with feature importance on variables with many unique values. .. GENERATED FROM PYTHON SOURCE LINES 116-120 .. code-block:: Python print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}") print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}") .. rst-class:: sphx-glr-script-out .. code-block:: none RF train accuracy: 1.000 RF test accuracy: 0.814 .. GENERATED FROM PYTHON SOURCE LINES 121-142 Tree's Feature Importance from Mean Decrease in Impurity (MDI) -------------------------------------------------------------- The impurity-based feature importance ranks the numerical features to be the most important features. As a result, the non-predictive ``random_num`` variable is ranked as one of the most important features! This problem stems from two limitations of impurity-based feature importances: - impurity-based importances are biased towards high cardinality features; - impurity-based importances are computed on training set statistics and therefore do not reflect the ability of feature to be useful to make predictions that generalize to the test set (when the model has enough capacity). The bias towards high cardinality features explains why the `random_num` has a really large importance in comparison with `random_cat` while we would expect both random features to have a null importance. The fact that we use training set statistics explains why both the `random_num` and `random_cat` features have a non-null importance. .. GENERATED FROM PYTHON SOURCE LINES 142-150 .. code-block:: Python import pandas as pd feature_names = rf[:-1].get_feature_names_out() mdi_importances = pd.Series( rf[-1].feature_importances_, index=feature_names ).sort_values(ascending=True) .. GENERATED FROM PYTHON SOURCE LINES 151-155 .. code-block:: Python ax = mdi_importances.plot.barh() ax.set_title("Random Forest Feature Importances (MDI)") ax.figure.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_001.png :alt: Random Forest Feature Importances (MDI) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 156-164 As an alternative, the permutation importances of ``rf`` are computed on a held out test set. This shows that the low cardinality categorical feature, `sex` and `pclass` are the most important feature. Indeed, permuting the values of these features will lead to most decrease in accuracy score of the model on the test set. Also note that both random features have very low importances (close to 0) as expected. .. GENERATED FROM PYTHON SOURCE LINES 164-181 .. code-block:: Python from sklearn.inspection import permutation_importance result = permutation_importance( rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2 ) sorted_importances_idx = result.importances_mean.argsort() importances = pd.DataFrame( result.importances[sorted_importances_idx].T, columns=X.columns[sorted_importances_idx], ) ax = importances.plot.box(vert=False, whis=10) ax.set_title("Permutation Importances (test set)") ax.axvline(x=0, color="k", linestyle="--") ax.set_xlabel("Decrease in accuracy score") ax.figure.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_002.png :alt: Permutation Importances (test set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 182-187 It is also possible to compute the permutation importances on the training set. This reveals that `random_num` and `random_cat` get a significantly higher importance ranking than when computed on the test set. The difference between those two plots is a confirmation that the RF model has enough capacity to use that random numerical and categorical features to overfit. .. GENERATED FROM PYTHON SOURCE LINES 187-202 .. code-block:: Python result = permutation_importance( rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2 ) sorted_importances_idx = result.importances_mean.argsort() importances = pd.DataFrame( result.importances[sorted_importances_idx].T, columns=X.columns[sorted_importances_idx], ) ax = importances.plot.box(vert=False, whis=10) ax.set_title("Permutation Importances (train set)") ax.axvline(x=0, color="k", linestyle="--") ax.set_xlabel("Decrease in accuracy score") ax.figure.tight_layout() .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_003.png :alt: Permutation Importances (train set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_003.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 203-205 We can further retry the experiment by limiting the capacity of the trees to overfit by setting `min_samples_leaf` at 20 data points. .. GENERATED FROM PYTHON SOURCE LINES 205-207 .. code-block:: Python rf.set_params(classifier__min_samples_leaf=20).fit(X_train, y_train) .. raw:: html
Pipeline(steps=[('preprocess',
                     ColumnTransformer(transformers=[('cat',
                                                      OrdinalEncoder(encoded_missing_value=-1,
                                                                     handle_unknown='use_encoded_value',
                                                                     unknown_value=-1),
                                                      ['pclass', 'sex', 'embarked',
                                                       'random_cat']),
                                                     ('num', SimpleImputer(),
                                                      ['age', 'sibsp', 'parch',
                                                       'fare', 'random_num'])],
                                       verbose_feature_names_out=False)),
                    ('classifier',
                     RandomForestClassifier(min_samples_leaf=20, random_state=42))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 208-211 Observing the accuracy score on the training and testing set, we observe that the two metrics are very similar now. Therefore, our model is not overfitting anymore. We can then check the permutation importances with this new model. .. GENERATED FROM PYTHON SOURCE LINES 211-214 .. code-block:: Python print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}") print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}") .. rst-class:: sphx-glr-script-out .. code-block:: none RF train accuracy: 0.810 RF test accuracy: 0.832 .. GENERATED FROM PYTHON SOURCE LINES 215-223 .. code-block:: Python train_result = permutation_importance( rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2 ) test_results = permutation_importance( rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2 ) sorted_importances_idx = train_result.importances_mean.argsort() .. GENERATED FROM PYTHON SOURCE LINES 224-233 .. code-block:: Python train_importances = pd.DataFrame( train_result.importances[sorted_importances_idx].T, columns=X.columns[sorted_importances_idx], ) test_importances = pd.DataFrame( test_results.importances[sorted_importances_idx].T, columns=X.columns[sorted_importances_idx], ) .. GENERATED FROM PYTHON SOURCE LINES 234-241 .. code-block:: Python for name, importances in zip(["train", "test"], [train_importances, test_importances]): ax = importances.plot.box(vert=False, whis=10) ax.set_title(f"Permutation Importances ({name} set)") ax.set_xlabel("Decrease in accuracy score") ax.axvline(x=0, color="k", linestyle="--") ax.figure.tight_layout() .. rst-class:: sphx-glr-horizontal * .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_004.png :alt: Permutation Importances (train set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_004.png :class: sphx-glr-multi-img * .. image-sg:: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_005.png :alt: Permutation Importances (test set) :srcset: /auto_examples/inspection/images/sphx_glr_plot_permutation_importance_005.png :class: sphx-glr-multi-img .. GENERATED FROM PYTHON SOURCE LINES 242-246 Now, we can observe that on both sets, the `random_num` and `random_cat` features have a lower importance compared to the overfitting random forest. However, the conclusions regarding the importance of the other features are still valid. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 5.424 seconds) .. _sphx_glr_download_auto_examples_inspection_plot_permutation_importance.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/inspection/plot_permutation_importance.ipynb :alt: Launch binder :width: 150 px .. container:: lite-badge .. image:: images/jupyterlite_badge_logo.svg :target: ../../lite/lab/index.html?path=auto_examples/inspection/plot_permutation_importance.ipynb :alt: Launch JupyterLite :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_permutation_importance.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_permutation_importance.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_permutation_importance.zip ` .. include:: plot_permutation_importance.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_