.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection/plot_successive_halving_heatmap.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_model_selection_plot_successive_halving_heatmap.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection_plot_successive_halving_heatmap.py:


Comparison between grid search and successive halving
=====================================================

This example compares the parameter search performed by
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.GridSearchCV`.

.. GENERATED FROM PYTHON SOURCE LINES 10-22

.. code-block:: default


    from time import time

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    from sklearn import datasets
    from sklearn.experimental import enable_halving_search_cv  # noqa
    from sklearn.model_selection import GridSearchCV, HalvingGridSearchCV
    from sklearn.svm import SVC








.. GENERATED FROM PYTHON SOURCE LINES 23-27

We first define the parameter space for an :class:`~sklearn.svm.SVC`
estimator, then measure the time required to fit a
:class:`~sklearn.model_selection.HalvingGridSearchCV` instance and a
:class:`~sklearn.model_selection.GridSearchCV` instance on the same data.
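
To get a feel for the size of this search, the following minimal sketch
enumerates the ``gamma`` and ``C`` values that this example searches over and
counts the cross-validation fits an exhaustive search would need. It assumes
5-fold cross-validation (the default when ``cv`` is not set) and ignores the
final refit on the full dataset. The names ``candidates``, ``n_splits`` and
``n_fits`` are illustrative only.

.. code-block:: python

    from itertools import product

    # Same values as the grid used in this example.
    gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
    Cs = [1, 10, 100, 1e3, 1e4, 1e5]

    # Every (gamma, C) pair is one candidate of the exhaustive search.
    candidates = [{"gamma": g, "C": c} for g, c in product(gammas, Cs)]
    n_candidates = len(candidates)

    # Assumption: 5-fold cross-validation, so each candidate is fitted 5 times.
    n_splits = 5
    n_fits = n_candidates * n_splits

    print(f"{n_candidates} candidates, {n_fits} cross-validation fits")

An exhaustive search runs every one of these fits on the full training folds,
which is the cost that successive halving tries to reduce.
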

.. GENERATED FROM PYTHON SOURCE LINES 27-49

.. code-block:: default


    rng = np.random.RandomState(0)
    X, y = datasets.make_classification(n_samples=1000, random_state=rng)

    gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
    Cs = [1, 10, 100, 1e3, 1e4, 1e5]
    param_grid = {"gamma": gammas, "C": Cs}

    clf = SVC(random_state=rng)

    tic = time()
    gsh = HalvingGridSearchCV(
        estimator=clf, param_grid=param_grid, factor=2, random_state=rng
    )
    gsh.fit(X, y)
    gsh_time = time() - tic

    tic = time()
    gs = GridSearchCV(estimator=clf, param_grid=param_grid)
    gs.fit(X, y)
    gs_time = time() - tic








.. GENERATED FROM PYTHON SOURCE LINES 50-51

We now plot heatmaps for both search estimators.

.. GENERATED FROM PYTHON SOURCE LINES 51-119

.. code-block:: default



    def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
        """Helper to make a heatmap."""
        results = pd.DataFrame(gs.cv_results_)
        results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
            np.float64
        )
        if is_sh:
            # SH dataframe: get mean_test_score values for the highest iter
            scores_matrix = results.sort_values("iter").pivot_table(
                index="param_gamma",
                columns="param_C",
                values="mean_test_score",
                aggfunc="last",
            )
        else:
            scores_matrix = results.pivot(
                index="param_gamma", columns="param_C", values="mean_test_score"
            )

        im = ax.imshow(scores_matrix)

        ax.set_xticks(np.arange(len(Cs)))
        ax.set_xticklabels(["{:.0E}".format(x) for x in Cs])
        ax.set_xlabel("C", fontsize=15)

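        # The pivoted rows are sorted by increasing gamma, so the labels are
        # taken from the matrix index rather than from the `gammas` list.
        ax.set_yticks(np.arange(len(gammas)))
        ax.set_yticklabels(["{:.0E}".format(x) for x in scores_matrix.index])
        ax.set_ylabel("gamma", fontsize=15)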

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        if is_sh:
            iterations = results.pivot_table(
                index="param_gamma", columns="param_C", values="iter", aggfunc="max"
            ).values
            for i in range(len(gammas)):
                for j in range(len(Cs)):
                    ax.text(
                        j,
                        i,
                        iterations[i, j],
                        ha="center",
                        va="center",
                        color="w",
                        fontsize=20,
                    )

        if make_cbar:
            # Use the figure that owns `ax` instead of relying on a global.
            fig = ax.figure
            fig.subplots_adjust(right=0.8)
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            fig.colorbar(im, cax=cbar_ax)
            cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)


    fig, axes = plt.subplots(ncols=2, sharey=True)
    ax1, ax2 = axes

    make_heatmap(ax1, gsh, is_sh=True)
    make_heatmap(ax2, gs, make_cbar=True)

    ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15)
    ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15)

    plt.show()




.. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_successive_halving_heatmap_001.png
   :alt: Successive Halving time = 1.377s, GridSearch time = 5.596s
   :srcset: /auto_examples/model_selection/images/sphx_glr_plot_successive_halving_heatmap_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 120-130

The heatmaps show the mean test score of the parameter combinations for an
:class:`~sklearn.svm.SVC` instance. The heatmap for
:class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
iteration at which each combination was last used. The combinations marked
as ``0`` were only evaluated at the first iteration, while the ones marked
``5`` survived to the last iteration and are considered the best ones.
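
As a rough illustration of why the labels run from ``0`` to ``5``, the sketch
below computes how many candidates remain at each iteration when a grid of 42
combinations is searched with ``factor=2``. It assumes that each iteration
keeps ``ceil(n_candidates / factor)`` candidates and that there are enough
samples to run every iteration. ``halving_schedule`` is a hypothetical helper
written for this illustration, not part of scikit-learn. On a fitted search
object, the ``n_candidates_`` attribute reports the actual values.

.. code-block:: python

    def halving_schedule(n_candidates, factor):
        """Return the number of candidates evaluated at each iteration.

        Illustrative only: assumes resources never limit the iteration count.
        """
        # Number of iterations needed to get down to at most `factor`
        # candidates, i.e. 1 + floor(log(n_candidates, factor)), computed
        # with integers to avoid floating-point rounding.
        n_iterations = 1
        while factor**n_iterations <= n_candidates:
            n_iterations += 1

        schedule = []
        for _ in range(n_iterations):
            schedule.append(n_candidates)
            # Keep the best 1 / factor of the candidates, rounding up.
            n_candidates = -(-n_candidates // factor)
        return schedule


    schedule = halving_schedule(n_candidates=42, factor=2)
    for iteration, n in enumerate(schedule):
        print(f"iteration {iteration}: {n} candidates")

Under these assumptions the search runs six iterations, numbered ``0`` to
``5``, and only a couple of combinations reach the last one.
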

We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
class is able to find parameter combinations that are just as accurate as
those found by :class:`~sklearn.model_selection.GridSearchCV`, in much less
time (about 1.4 s versus 5.6 s in the run shown above).


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 7.134 seconds)


.. _sphx_glr_download_auto_examples_model_selection_plot_successive_halving_heatmap.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.3.X?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_successive_halving_heatmap.ipynb
        :alt: Launch binder
        :width: 150 px



    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/model_selection/plot_successive_halving_heatmap.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_successive_halving_heatmap.py <plot_successive_halving_heatmap.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_successive_halving_heatmap.ipynb <plot_successive_halving_heatmap.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_