.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ensemble/plot_gradient_boosting_quantile.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_ensemble_plot_gradient_boosting_quantile.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ensemble_plot_gradient_boosting_quantile.py:


=====================================================
Prediction Intervals for Gradient Boosting Regression
=====================================================

This example shows how quantile regression can be used to create prediction
intervals.

.. GENERATED FROM PYTHON SOURCE LINES 12-14

Generate some data for a synthetic regression problem by applying the
function f to uniformly sampled random inputs.

.. GENERATED FROM PYTHON SOURCE LINES 14-28

.. code-block:: default

    import numpy as np

    from sklearn.model_selection import train_test_split


    def f(x):
        """The function to predict."""
        return x * np.sin(x)


    rng = np.random.RandomState(42)
    X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
    expected_y = f(X).ravel()








.. GENERATED FROM PYTHON SOURCE LINES 29-38

To make the problem interesting, we generate observations of the target y as
the sum of a deterministic term computed by the function f and a random noise
term that follows a centered `log-normal
<https://en.wikipedia.org/wiki/Log-normal_distribution>`_ distribution. To
make this even more interesting, we consider the case where the amplitude of
the noise depends on the input variable x (heteroscedastic noise).

The log-normal distribution is asymmetric and long-tailed: observing large
outliers is likely, but it is impossible to observe equally extreme low
outliers since the centered noise is bounded below.

.. GENERATED FROM PYTHON SOURCE LINES 38-42

.. code-block:: default

    sigma = 0.5 + X.ravel() / 10
    noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2 / 2)
    y = expected_y + noise
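
As a quick sanity check (a small sketch added here, not part of the original
script), we can verify that subtracting `np.exp(sigma**2 / 2)`, the mean of a
log-normal with scale `sigma`, indeed centers the noise, and that the noise
spread grows with x:

.. code-block:: default

    # Sketch: the empirical mean of the centered noise should be close to 0,
    # and the noise should be more spread out for larger values of x.
    print(f"mean of noise: {noise.mean():.3f}")
    print(f"std of noise for x < 5:  {noise[X.ravel() < 5].std():.3f}")
    print(f"std of noise for x >= 5: {noise[X.ravel() >= 5].std():.3f}")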








.. GENERATED FROM PYTHON SOURCE LINES 43-44

Split the data into train and test datasets:

.. GENERATED FROM PYTHON SOURCE LINES 44-46

.. code-block:: default

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)








.. GENERATED FROM PYTHON SOURCE LINES 47-59

Fitting non-linear quantile and least squares regressors
--------------------------------------------------------

Fit gradient boosting models trained with the quantile loss and
alpha=0.05, 0.5, 0.95.

The models obtained for alpha=0.05 and alpha=0.95 produce a 90% prediction
interval (95% - 5% = 90%).

The model trained with alpha=0.5 produces a regression of the median: on
average, there should be the same number of target observations above and
below the predicted values.
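
For reference, the quantile loss (also known as the pinball loss) at level
alpha penalizes under- and over-predictions asymmetrically:

.. math::

   L_\alpha(y, \hat{y}) =
   \begin{cases}
   \alpha \, (y - \hat{y}) & \text{if } y \geq \hat{y}, \\
   (1 - \alpha) \, (\hat{y} - y) & \text{otherwise,}
   \end{cases}

and its expected value is minimized by the alpha-quantile of the conditional
distribution of y given x, which is why alpha=0.5 recovers the conditional
median.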

.. GENERATED FROM PYTHON SOURCE LINES 59-74

.. code-block:: default

    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.metrics import mean_pinball_loss, mean_squared_error

    all_models = {}
    common_params = dict(
        learning_rate=0.05,
        n_estimators=200,
        max_depth=2,
        min_samples_leaf=9,
        min_samples_split=9,
    )
    for alpha in [0.05, 0.5, 0.95]:
        gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, **common_params)
        all_models["q %1.2f" % alpha] = gbr.fit(X_train, y_train)








.. GENERATED FROM PYTHON SOURCE LINES 75-82

Notice that :class:`~sklearn.ensemble.HistGradientBoostingRegressor` is much
faster than :class:`~sklearn.ensemble.GradientBoostingRegressor` starting with
intermediate datasets (`n_samples >= 10_000`), which is not the case in the
present example.
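
On such larger datasets, an equivalent set of quantile models could be fit
with the histogram-based estimator instead (a sketch assuming
scikit-learn >= 1.1, where the quantile level is set through the `quantile`
parameter rather than `alpha`):

.. code-block:: default

    from sklearn.ensemble import HistGradientBoostingRegressor

    # Sketch (assumes scikit-learn >= 1.1): same three quantile levels, fit
    # with the faster histogram-based gradient boosting estimator.
    hist_models = {
        "q %1.2f" % alpha: HistGradientBoostingRegressor(
            loss="quantile", quantile=alpha
        ).fit(X_train, y_train)
        for alpha in [0.05, 0.5, 0.95]
    }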

For the sake of comparison, we also fit a baseline model trained with the
usual (mean) squared error (MSE).

.. GENERATED FROM PYTHON SOURCE LINES 82-85

.. code-block:: default

    gbr_ls = GradientBoostingRegressor(loss="squared_error", **common_params)
    all_models["mse"] = gbr_ls.fit(X_train, y_train)








.. GENERATED FROM PYTHON SOURCE LINES 86-88

Create an evenly spaced evaluation set of input values spanning the [0, 10]
range.

.. GENERATED FROM PYTHON SOURCE LINES 88-90

.. code-block:: default

    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T








.. GENERATED FROM PYTHON SOURCE LINES 91-94

Plot the true conditional mean function f, the predictions of the conditional
mean (loss equals squared error), the conditional median and the conditional
90% interval (from 5th to 95th conditional percentiles).

.. GENERATED FROM PYTHON SOURCE LINES 94-117

.. code-block:: default

    import matplotlib.pyplot as plt

    y_pred = all_models["mse"].predict(xx)
    y_lower = all_models["q 0.05"].predict(xx)
    y_upper = all_models["q 0.95"].predict(xx)
    y_med = all_models["q 0.50"].predict(xx)

    fig = plt.figure(figsize=(10, 10))
    plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
    plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
    plt.plot(xx, y_med, "r-", label="Predicted median")
    plt.plot(xx, y_pred, "r-", label="Predicted mean")
    plt.plot(xx, y_upper, "k-")
    plt.plot(xx, y_lower, "k-")
    plt.fill_between(
        xx.ravel(), y_lower, y_upper, alpha=0.4, label="Predicted 90% interval"
    )
    plt.xlabel("$x$")
    plt.ylabel("$f(x)$")
    plt.ylim(-10, 25)
    plt.legend(loc="upper left")
    plt.show()




.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_quantile_001.png
   :alt: plot gradient boosting quantile
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_quantile_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 118-133

Comparing the predicted median with the predicted mean, we note that the
median is on average below the mean as the noise is skewed towards high
values (large outliers). The median estimate also seems to be smoother
because of its natural robustness to outliers.

Also observe that the inductive bias of gradient boosting trees is
unfortunately preventing the 0.05 quantile from fully capturing the sinusoidal
shape of the signal, in particular around x=8. Tuning the hyper-parameters can
reduce this effect, as shown in the last part of this notebook.

Analysis of the error metrics
-----------------------------

Measure the models with :func:`~sklearn.metrics.mean_squared_error` and
:func:`~sklearn.metrics.mean_pinball_loss` metrics on the training dataset.

.. GENERATED FROM PYTHON SOURCE LINES 133-152

.. code-block:: default

    import pandas as pd


    def highlight_min(x):
        x_min = x.min()
        return ["font-weight: bold" if v == x_min else "" for v in x]


    results = []
    for name, gbr in sorted(all_models.items()):
        metrics = {"model": name}
        y_pred = gbr.predict(X_train)
        for alpha in [0.05, 0.5, 0.95]:
            metrics["pbl=%1.2f" % alpha] = mean_pinball_loss(y_train, y_pred, alpha=alpha)
        metrics["MSE"] = mean_squared_error(y_train, y_pred)
        results.append(metrics)

    pd.DataFrame(results).set_index("model").style.apply(highlight_min)






.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <style type="text/css">
    #T_97857_row0_col3, #T_97857_row1_col0, #T_97857_row2_col1, #T_97857_row3_col2 {
      font-weight: bold;
    }
    </style>
    <table id="T_97857">
      <thead>
        <tr>
          <th class="blank level0" >&nbsp;</th>
          <th id="T_97857_level0_col0" class="col_heading level0 col0" >pbl=0.05</th>
          <th id="T_97857_level0_col1" class="col_heading level0 col1" >pbl=0.50</th>
          <th id="T_97857_level0_col2" class="col_heading level0 col2" >pbl=0.95</th>
          <th id="T_97857_level0_col3" class="col_heading level0 col3" >MSE</th>
        </tr>
        <tr>
          <th class="index_name level0" >model</th>
          <th class="blank col0" >&nbsp;</th>
          <th class="blank col1" >&nbsp;</th>
          <th class="blank col2" >&nbsp;</th>
          <th class="blank col3" >&nbsp;</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th id="T_97857_level0_row0" class="row_heading level0 row0" >mse</th>
          <td id="T_97857_row0_col0" class="data row0 col0" >0.715413</td>
          <td id="T_97857_row0_col1" class="data row0 col1" >0.715413</td>
          <td id="T_97857_row0_col2" class="data row0 col2" >0.715413</td>
          <td id="T_97857_row0_col3" class="data row0 col3" >7.750348</td>
        </tr>
        <tr>
          <th id="T_97857_level0_row1" class="row_heading level0 row1" >q 0.05</th>
          <td id="T_97857_row1_col0" class="data row1 col0" >0.127128</td>
          <td id="T_97857_row1_col1" class="data row1 col1" >1.253445</td>
          <td id="T_97857_row1_col2" class="data row1 col2" >2.379763</td>
          <td id="T_97857_row1_col3" class="data row1 col3" >18.933253</td>
        </tr>
        <tr>
          <th id="T_97857_level0_row2" class="row_heading level0 row2" >q 0.50</th>
          <td id="T_97857_row2_col0" class="data row2 col0" >0.305438</td>
          <td id="T_97857_row2_col1" class="data row2 col1" >0.622811</td>
          <td id="T_97857_row2_col2" class="data row2 col2" >0.940184</td>
          <td id="T_97857_row2_col3" class="data row2 col3" >9.827917</td>
        </tr>
        <tr>
          <th id="T_97857_level0_row3" class="row_heading level0 row3" >q 0.95</th>
          <td id="T_97857_row3_col0" class="data row3 col0" >3.909909</td>
          <td id="T_97857_row3_col1" class="data row3 col1" >2.145957</td>
          <td id="T_97857_row3_col2" class="data row3 col2" >0.382005</td>
          <td id="T_97857_row3_col3" class="data row3 col3" >28.667219</td>
        </tr>
      </tbody>
    </table>

    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 153-168

Each column shows all models evaluated with the same metric. The minimum
value in a column should be obtained when the model is trained and measured
with that same metric. This should always be the case on the training set if
the training converged.

Note that because the target distribution is asymmetric, the expected
conditional mean and conditional median are significantly different, and
therefore one cannot use the squared error model to get a good estimate of
the conditional median, nor the converse.

If the target distribution were symmetric and had no outliers (e.g. with
Gaussian noise), then the median estimator and the least squares estimator
would have yielded similar predictions.
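
This claim can be checked with a small side experiment (a sketch added here,
not part of the original example): refitting the median and mean models on a
Gaussian-noise version of the same problem should yield very similar
predictions.

.. code-block:: default

    # Sketch: with centered Gaussian noise, the conditional mean and median
    # coincide, so the two estimators should roughly agree.
    y_gauss = expected_y + rng.normal(scale=1.0, size=expected_y.shape)
    gbr_med_gauss = GradientBoostingRegressor(
        loss="quantile", alpha=0.5, **common_params
    ).fit(X, y_gauss)
    gbr_mean_gauss = GradientBoostingRegressor(
        loss="squared_error", **common_params
    ).fit(X, y_gauss)
    print(np.abs(gbr_med_gauss.predict(xx) - gbr_mean_gauss.predict(xx)).mean())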

Coming back to our heteroscedastic problem, we now compute the same metrics
on the test set.

.. GENERATED FROM PYTHON SOURCE LINES 168-180

.. code-block:: default

    results = []
    for name, gbr in sorted(all_models.items()):
        metrics = {"model": name}
        y_pred = gbr.predict(X_test)
        for alpha in [0.05, 0.5, 0.95]:
            metrics["pbl=%1.2f" % alpha] = mean_pinball_loss(y_test, y_pred, alpha=alpha)
        metrics["MSE"] = mean_squared_error(y_test, y_pred)
        results.append(metrics)

    pd.DataFrame(results).set_index("model").style.apply(highlight_min)







.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <style type="text/css">
    #T_81adb_row1_col0, #T_81adb_row2_col1, #T_81adb_row2_col3, #T_81adb_row3_col2 {
      font-weight: bold;
    }
    </style>
    <table id="T_81adb">
      <thead>
        <tr>
          <th class="blank level0" >&nbsp;</th>
          <th id="T_81adb_level0_col0" class="col_heading level0 col0" >pbl=0.05</th>
          <th id="T_81adb_level0_col1" class="col_heading level0 col1" >pbl=0.50</th>
          <th id="T_81adb_level0_col2" class="col_heading level0 col2" >pbl=0.95</th>
          <th id="T_81adb_level0_col3" class="col_heading level0 col3" >MSE</th>
        </tr>
        <tr>
          <th class="index_name level0" >model</th>
          <th class="blank col0" >&nbsp;</th>
          <th class="blank col1" >&nbsp;</th>
          <th class="blank col2" >&nbsp;</th>
          <th class="blank col3" >&nbsp;</th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th id="T_81adb_level0_row0" class="row_heading level0 row0" >mse</th>
          <td id="T_81adb_row0_col0" class="data row0 col0" >0.917281</td>
          <td id="T_81adb_row0_col1" class="data row0 col1" >0.767498</td>
          <td id="T_81adb_row0_col2" class="data row0 col2" >0.617715</td>
          <td id="T_81adb_row0_col3" class="data row0 col3" >6.692901</td>
        </tr>
        <tr>
          <th id="T_81adb_level0_row1" class="row_heading level0 row1" >q 0.05</th>
          <td id="T_81adb_row1_col0" class="data row1 col0" >0.144204</td>
          <td id="T_81adb_row1_col1" class="data row1 col1" >1.245961</td>
          <td id="T_81adb_row1_col2" class="data row1 col2" >2.347717</td>
          <td id="T_81adb_row1_col3" class="data row1 col3" >15.648026</td>
        </tr>
        <tr>
          <th id="T_81adb_level0_row2" class="row_heading level0 row2" >q 0.50</th>
          <td id="T_81adb_row2_col0" class="data row2 col0" >0.412021</td>
          <td id="T_81adb_row2_col1" class="data row2 col1" >0.607752</td>
          <td id="T_81adb_row2_col2" class="data row2 col2" >0.803483</td>
          <td id="T_81adb_row2_col3" class="data row2 col3" >5.874771</td>
        </tr>
        <tr>
          <th id="T_81adb_level0_row3" class="row_heading level0 row3" >q 0.95</th>
          <td id="T_81adb_row3_col0" class="data row3 col0" >4.354394</td>
          <td id="T_81adb_row3_col1" class="data row3 col1" >2.355445</td>
          <td id="T_81adb_row3_col2" class="data row3 col2" >0.356497</td>
          <td id="T_81adb_row3_col3" class="data row3 col3" >34.852774</td>
        </tr>
      </tbody>
    </table>

    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 181-201

Errors are higher on the test set, meaning that the models slightly overfit
the training data. The comparison still shows that the best test metric is
obtained when the model is trained by minimizing that same metric.

Note that the conditional median estimator is competitive with the squared
error estimator in terms of MSE on the test set: this can be explained by the
fact that the squared error estimator is very sensitive to large outliers,
which can cause significant overfitting. This can be seen on the right-hand
side of the previous plot. The conditional median estimator is biased
(underestimation for this asymmetric noise) but is also naturally robust to
outliers and overfits less.

Calibration of the confidence interval
--------------------------------------

We can also evaluate the ability of the two extreme quantile estimators to
produce a well-calibrated conditional 90% prediction interval.

To do this, we compute the fraction of observations that fall between the two
predicted bounds:

.. GENERATED FROM PYTHON SOURCE LINES 201-211

.. code-block:: default

    def coverage_fraction(y, y_low, y_high):
        """Fraction of observations with y_low <= y <= y_high."""
        return np.mean(np.logical_and(y >= y_low, y <= y_high))


    coverage_fraction(
        y_train,
        all_models["q 0.05"].predict(X_train),
        all_models["q 0.95"].predict(X_train),
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    0.9



.. GENERATED FROM PYTHON SOURCE LINES 212-214

On the training set, the calibration is very close to the expected coverage
value of a 90% prediction interval.

.. GENERATED FROM PYTHON SOURCE LINES 214-219

.. code-block:: default

    coverage_fraction(
        y_test, all_models["q 0.05"].predict(X_test), all_models["q 0.95"].predict(X_test)
    )






.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    0.868



.. GENERATED FROM PYTHON SOURCE LINES 220-237

On the test set, the estimated prediction interval is slightly too narrow.
Note, however, that we would need to wrap those metrics in a cross-validation
loop to assess their variability under data resampling.

Tuning the hyper-parameters of the quantile regressors
------------------------------------------------------

In the plot above, we observed that the 5th percentile regressor seems to
underfit and could not adapt to the sinusoidal shape of the signal.

The hyper-parameters of the model were approximately hand-tuned for the
median regressor, and there is no reason the same hyper-parameters should be
suitable for the 5th percentile regressor.

To confirm this hypothesis, we tune the hyper-parameters of a new 5th
percentile regressor by selecting the best model parameters via
cross-validation on the pinball loss with alpha=0.05:

.. GENERATED FROM PYTHON SOURCE LINES 239-269

.. code-block:: default

    from sklearn.experimental import enable_halving_search_cv  # noqa
    from sklearn.model_selection import HalvingRandomSearchCV
    from sklearn.metrics import make_scorer
    from pprint import pprint

    param_grid = dict(
        learning_rate=[0.05, 0.1, 0.2],
        max_depth=[2, 5, 10],
        min_samples_leaf=[1, 5, 10, 20],
        min_samples_split=[5, 10, 20, 30, 50],
    )
    alpha = 0.05
    neg_mean_pinball_loss_05p_scorer = make_scorer(
        mean_pinball_loss,
        alpha=alpha,
        greater_is_better=False,  # maximize the negative loss
    )
    gbr = GradientBoostingRegressor(loss="quantile", alpha=alpha, random_state=0)
    search_05p = HalvingRandomSearchCV(
        gbr,
        param_grid,
        resource="n_estimators",
        max_resources=250,
        min_resources=50,
        scoring=neg_mean_pinball_loss_05p_scorer,
        n_jobs=2,
        random_state=0,
    ).fit(X_train, y_train)
    pprint(search_05p.best_params_)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'learning_rate': 0.2,
     'max_depth': 2,
     'min_samples_leaf': 20,
     'min_samples_split': 10,
     'n_estimators': 150}




.. GENERATED FROM PYTHON SOURCE LINES 270-278

We observe that the hyper-parameters that were hand-tuned for the median
regressor are in the same range as the hyper-parameters suitable for the 5th
percentile regressor.

Let's now tune the hyper-parameters for the 95th percentile regressor. We
need to redefine the `scoring` metric used to select the best model, along
with adjusting the alpha parameter of the inner gradient boosting estimator
itself:

.. GENERATED FROM PYTHON SOURCE LINES 278-293

.. code-block:: default

    from sklearn.base import clone

    alpha = 0.95
    neg_mean_pinball_loss_95p_scorer = make_scorer(
        mean_pinball_loss,
        alpha=alpha,
        greater_is_better=False,  # maximize the negative loss
    )
    search_95p = clone(search_05p).set_params(
        estimator__alpha=alpha,
        scoring=neg_mean_pinball_loss_95p_scorer,
    )
    search_95p.fit(X_train, y_train)
    pprint(search_95p.best_params_)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'learning_rate': 0.05,
     'max_depth': 2,
     'min_samples_leaf': 5,
     'min_samples_split': 20,
     'n_estimators': 150}




.. GENERATED FROM PYTHON SOURCE LINES 294-302

The result shows that the hyper-parameters for the 95th percentile regressor
identified by the search procedure are roughly in the same range as the
hand-tuned hyper-parameters for the median regressor and the hyper-parameters
identified by the search procedure for the 5th percentile regressor. However,
the hyper-parameter searches did lead to an improved 90% prediction interval
composed of the predictions of those two tuned quantile regressors. Note that
the prediction of the upper 95th percentile has a much coarser shape than the
prediction of the lower 5th percentile because of the outliers:

.. GENERATED FROM PYTHON SOURCE LINES 302-320

.. code-block:: default

    y_lower = search_05p.predict(xx)
    y_upper = search_95p.predict(xx)

    fig = plt.figure(figsize=(10, 10))
    plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
    plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
    plt.plot(xx, y_upper, "k-")
    plt.plot(xx, y_lower, "k-")
    plt.fill_between(
        xx.ravel(), y_lower, y_upper, alpha=0.4, label="Predicted 90% interval"
    )
    plt.xlabel("$x$")
    plt.ylabel("$f(x)$")
    plt.ylim(-10, 25)
    plt.legend(loc="upper left")
    plt.title("Prediction with tuned hyper-parameters")
    plt.show()




.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_quantile_002.png
   :alt: Prediction with tuned hyper-parameters
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_quantile_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 321-326

The plot looks qualitatively better than for the untuned models, especially
for the shape of the lower quantile.

We now quantitatively evaluate the joint calibration of the pair of
estimators:

.. GENERATED FROM PYTHON SOURCE LINES 326-327

.. code-block:: default

    coverage_fraction(y_train, search_05p.predict(X_train), search_95p.predict(X_train))




.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    0.9026666666666666



.. GENERATED FROM PYTHON SOURCE LINES 328-329

.. code-block:: default

    coverage_fraction(y_test, search_05p.predict(X_test), search_95p.predict(X_test))




.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    0.796



.. GENERATED FROM PYTHON SOURCE LINES 330-335

The calibration of the tuned pair is sadly not better on the test set: the
estimated prediction interval is still too narrow.

Again, we would need to wrap this study in a cross-validation loop to
better assess the variability of those estimates.
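
A minimal sketch of such a loop (added here as an illustration, not part of
the original example) could refit the two tuned models on several train/test
splits and report the spread of the coverage:

.. code-block:: default

    from sklearn.model_selection import KFold

    # Sketch: estimate the variability of the coverage under resampling by
    # refitting the two tuned quantile models on 5 different splits.
    coverages = []
    for train_idx, test_idx in KFold(
        n_splits=5, shuffle=True, random_state=0
    ).split(X):
        lo = clone(search_05p.best_estimator_).fit(X[train_idx], y[train_idx])
        hi = clone(search_95p.best_estimator_).fit(X[train_idx], y[train_idx])
        coverages.append(
            coverage_fraction(
                y[test_idx], lo.predict(X[test_idx]), hi.predict(X[test_idx])
            )
        )
    print(f"coverage: {np.mean(coverages):.3f} +/- {np.std(coverages):.3f}")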


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.230 seconds)


.. _sphx_glr_download_auto_examples_ensemble_plot_gradient_boosting_quantile.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.3.X?urlpath=lab/tree/notebooks/auto_examples/ensemble/plot_gradient_boosting_quantile.ipynb
        :alt: Launch binder
        :width: 150 px



    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/ensemble/plot_gradient_boosting_quantile.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gradient_boosting_quantile.py <plot_gradient_boosting_quantile.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gradient_boosting_quantile.ipynb <plot_gradient_boosting_quantile.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_