.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection/plot_grid_search_digits.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py:

============================================================
Custom refit strategy of a grid search with cross-validation
============================================================

This example shows how a classifier is optimized by cross-validation, which is
done using the :class:`~sklearn.model_selection.GridSearchCV` object on a
development set that comprises only half of the available labeled data.

The performance of the selected hyper-parameters and trained model is then
measured on a dedicated evaluation set that was not used during the model
selection step.

More details on tools available for model selection can be found in the
sections on :ref:`cross_validation` and :ref:`grid_search`.

.. GENERATED FROM PYTHON SOURCE LINES 17-21

.. code-block:: Python

    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

.. GENERATED FROM PYTHON SOURCE LINES 22-29

The dataset
-----------

We will work with the `digits` dataset. The goal is to classify handwritten
digit images. We transform the problem into a binary classification for easier
understanding: the goal is to identify whether a digit is `8` or not.

.. GENERATED FROM PYTHON SOURCE LINES 29-33

.. code-block:: Python

    from sklearn import datasets

    digits = datasets.load_digits()

.. GENERATED FROM PYTHON SOURCE LINES 34-37

In order to train a classifier on images, we need to flatten them into vectors.
Each image of 8 by 8 pixels needs to be transformed to a vector of 64 pixels.
Thus, we will get a final data array of shape `(n_images, n_pixels)`.

.. GENERATED FROM PYTHON SOURCE LINES 37-44

.. code-block:: Python

    n_samples = len(digits.images)
    X = digits.images.reshape((n_samples, -1))
    y = digits.target == 8
    print(
        f"The number of images is {X.shape[0]} and each image contains {X.shape[1]} pixels"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    The number of images is 1797 and each image contains 64 pixels

.. GENERATED FROM PYTHON SOURCE LINES 45-47

As presented in the introduction, the data will be split into a training and a
testing set of equal size.

.. GENERATED FROM PYTHON SOURCE LINES 47-51

.. code-block:: Python

    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

.. GENERATED FROM PYTHON SOURCE LINES 52-58

Define our grid-search strategy
-------------------------------

We will select a classifier by searching for the best hyper-parameters on folds
of the training set. To do this, we need to define the scores used to select
the best candidate.

.. GENERATED FROM PYTHON SOURCE LINES 58-61

.. code-block:: Python

    scores = ["precision", "recall"]

.. GENERATED FROM PYTHON SOURCE LINES 62-73

We can also define a function to be passed to the `refit` parameter of the
:class:`~sklearn.model_selection.GridSearchCV` instance. It will implement the
custom strategy to select the best candidate from the `cv_results_` attribute
of the :class:`~sklearn.model_selection.GridSearchCV`. Once the candidate is
selected, it is automatically refitted by the
:class:`~sklearn.model_selection.GridSearchCV` instance.

Here, the strategy is to short-list the models which are the best in terms of
precision and recall. From the selected models, we finally select the fastest
model at predicting. Notice that these custom choices are completely arbitrary.

.. GENERATED FROM PYTHON SOURCE LINES 73-170

.. code-block:: Python

    import pandas as pd


    def print_dataframe(filtered_cv_results):
        """Pretty print for filtered dataframe"""
        for mean_precision, std_precision, mean_recall, std_recall, params in zip(
            filtered_cv_results["mean_test_precision"],
            filtered_cv_results["std_test_precision"],
            filtered_cv_results["mean_test_recall"],
            filtered_cv_results["std_test_recall"],
            filtered_cv_results["params"],
        ):
            print(
                f"precision: {mean_precision:0.3f} (±{std_precision:0.03f}),"
                f" recall: {mean_recall:0.3f} (±{std_recall:0.03f}),"
                f" for {params}"
            )
        print()


    def refit_strategy(cv_results):
        """Define the strategy to select the best estimator.

        The strategy defined here is to filter-out all results below a precision threshold
        of 0.98, rank the remaining by recall and keep all models within one standard
        deviation of the best by recall. Once these models are selected, we can select the
        fastest model to predict.

        Parameters
        ----------
        cv_results : dict of numpy (masked) ndarrays
            CV results as returned by the `GridSearchCV`.

        Returns
        -------
        best_index : int
            The index of the best estimator as it appears in `cv_results`.
        """
        # print the info about the grid-search for the different scores
        precision_threshold = 0.98

        cv_results_ = pd.DataFrame(cv_results)
        print("All grid-search results:")
        print_dataframe(cv_results_)

        # Filter-out all results below the threshold
        high_precision_cv_results = cv_results_[
            cv_results_["mean_test_precision"] > precision_threshold
        ]

        print(f"Models with a precision higher than {precision_threshold}:")
        print_dataframe(high_precision_cv_results)

        high_precision_cv_results = high_precision_cv_results[
            [
                "mean_score_time",
                "mean_test_recall",
                "std_test_recall",
                "mean_test_precision",
                "std_test_precision",
                "rank_test_recall",
                "rank_test_precision",
                "params",
            ]
        ]

        # Select the most performant models in terms of recall
        # (within 1 sigma from the best). Here sigma is the spread of the mean
        # recalls across the remaining candidates, not the cross-validation
        # standard deviation of a single candidate.
        best_recall_std = high_precision_cv_results["mean_test_recall"].std()
        best_recall = high_precision_cv_results["mean_test_recall"].max()
        best_recall_threshold = best_recall - best_recall_std

        high_recall_cv_results = high_precision_cv_results[
            high_precision_cv_results["mean_test_recall"] > best_recall_threshold
        ]
        print(
            "Out of the previously selected high precision models, we keep all the\n"
            "the models within one standard deviation of the highest recall model:"
        )
        print_dataframe(high_recall_cv_results)

        # From the best candidates, select the fastest model to predict
        fastest_top_recall_high_precision_index = high_recall_cv_results[
            "mean_score_time"
        ].idxmin()

        print(
            "\nThe selected final model is the fastest to predict out of the previously\n"
            "selected subset of best models based on precision and recall.\n"
            "Its scoring time is:\n\n"
            f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
        )

        return fastest_top_recall_high_precision_index

.. GENERATED FROM PYTHON SOURCE LINES 171-176

Tuning hyper-parameters
-----------------------

Once we have defined our strategy to select the best model, we define the values
of the hyper-parameters and create the grid-search instance:

.. GENERATED FROM PYTHON SOURCE LINES 177-190

.. code-block:: Python

    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    tuned_parameters = [
        {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
        {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
    ]

    grid_search = GridSearchCV(
        SVC(), tuned_parameters, scoring=scores, refit=refit_strategy
    )
    grid_search.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    All grid-search results:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.968 (±0.039), recall: 0.780 (±0.083), for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.905 (±0.058), recall: 0.889 (±0.074), for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.904 (±0.058), recall: 0.890 (±0.073), for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 0.695 (±0.073), recall: 0.743 (±0.065), for {'C': 1, 'kernel': 'linear'}
    precision: 0.643 (±0.066), recall: 0.757 (±0.066), for {'C': 10, 'kernel': 'linear'}
    precision: 0.611 (±0.028), recall: 0.744 (±0.044), for {'C': 100, 'kernel': 'linear'}
    precision: 0.618 (±0.039), recall: 0.744 (±0.044), for {'C': 1000, 'kernel': 'linear'}

    Models with a precision higher than 0.98:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

    Out of the previously selected high precision models, we keep all the
    the models within one standard deviation of the highest recall model:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}


    The selected final model is the fastest to predict out of the previously
    selected subset of best models based on precision and recall.
    Its scoring time is:

    mean_score_time                                            0.005403
    mean_test_recall                                           0.877206
    std_test_recall                                            0.069196
    mean_test_precision                                             1.0
    std_test_precision                                              0.0
    rank_test_recall                                                  3
    rank_test_precision                                               1
    params                 {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
    Name: 6, dtype: object

.. code-block:: none

    GridSearchCV(estimator=SVC(),
                 param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                              'kernel': ['rbf']},
                             {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
                 refit=<function refit_strategy at 0x7fdec2fd7670>,
                 scoring=['precision', 'recall'])


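The recall short-list printed above can be retraced by hand. The sketch below
is not part of the original example: it copies the rounded mean recalls of the
five candidates that passed the precision filter from the output above, keyed
by their row index, and uses only the standard library. It assumes that
`Series.std()` in pandas defaults to the sample standard deviation, which
`statistics.stdev` also computes. The names `mean_recall`, `recall_spread`,
`recall_threshold` and `kept_rows` are illustrative only.

.. code-block:: Python

    from statistics import stdev

    # Rounded mean recalls of the candidates with precision above 0.98,
    # keyed by their row index in the "All grid-search results" listing.
    mean_recall = {0: 0.854, 1: 0.257, 2: 0.877, 4: 0.877, 6: 0.877}

    # Spread of the mean recalls across candidates (sample standard deviation).
    recall_spread = stdev(list(mean_recall.values()))

    # Keep every candidate whose mean recall is within one spread of the best.
    recall_threshold = max(mean_recall.values()) - recall_spread
    kept_rows = [row for row, recall in mean_recall.items() if recall > recall_threshold]

    print(f"threshold: {recall_threshold:.3f}, kept rows: {kept_rows}")

With these rounded values the threshold falls near 0.60, so only the candidate
with a mean recall of 0.257 is dropped. This matches the four models listed in
the output. The single low-recall candidate inflates the spread, which is why
the short-list is this permissive.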
.. GENERATED FROM PYTHON SOURCE LINES 191-192

The parameters selected by the grid-search with our custom strategy are:

.. GENERATED FROM PYTHON SOURCE LINES 193-195

.. code-block:: Python

    grid_search.best_params_

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

.. GENERATED FROM PYTHON SOURCE LINES 196-202

Finally, we evaluate the fine-tuned model on the left-out evaluation set: the
`grid_search` object **has automatically been refit** on the full training
set with the parameters selected by our custom refit strategy.

We can use the classification report to compute standard classification
metrics on the left-out set:

.. GENERATED FROM PYTHON SOURCE LINES 203-208

.. code-block:: Python

    from sklearn.metrics import classification_report

    y_pred = grid_search.predict(X_test)
    print(classification_report(y_test, y_pred))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

                  precision    recall  f1-score   support

           False       0.99      1.00      0.99       807
            True       1.00      0.87      0.93        92

        accuracy                           0.99       899
       macro avg       0.99      0.93      0.96       899
    weighted avg       0.99      0.99      0.99       899

.. GENERATED FROM PYTHON SOURCE LINES 209-212

.. note::
   The problem is too easy: the hyperparameter plateau is too flat and the
   output model is the same for precision and recall with ties in quality.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.787 seconds)

.. include:: plot_grid_search_digits.recommendations