.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/compose/plot_transformed_target.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_compose_plot_transformed_target.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_transformed_target.py:


========================================================
Effect of transforming the targets in a regression model
========================================================

In this example, we give an overview of
:class:`~sklearn.compose.TransformedTargetRegressor`. We use two examples
to illustrate the benefit of transforming the targets before learning a linear
regression model. The first example uses synthetic data while the second
example is based on the Ames housing data set.

.. GENERATED FROM PYTHON SOURCE LINES 13-19

.. code-block:: Python


    # Author: Guillaume Lemaitre <guillaume.lemaitre@inria.fr>
    # License: BSD 3 clause

    print(__doc__)








.. GENERATED FROM PYTHON SOURCE LINES 20-34

Synthetic example
##################

A synthetic random regression dataset is generated. The targets ``y`` are
modified by:

1. translating all targets so that all entries are non-negative (by adding
   the absolute value of the lowest ``y``) and
2. applying an exponential function to obtain non-linear targets that cannot
   be fitted with a simple linear model.

Therefore, a logarithmic function (``np.log1p``) is used to transform the
targets before training a linear regression model, and an exponential
function (``np.expm1``) is used to map its predictions back to the original
scale.
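
As a minimal sketch of the two steps above, the snippet below applies the same
shift-and-exponentiate recipe to a small hand-written array and checks that
``np.log1p`` undoes ``np.expm1``. The array values and the variable names are
illustrative assumptions and are not part of the example itself.

.. code-block:: Python

    import numpy as np

    # Hypothetical "linear" targets, some of them negative
    y_linear = np.array([-150.0, -20.0, 0.0, 75.0, 300.0])

    # Step 1: shift so that every entry is non-negative, then rescale
    y_shifted = (y_linear + abs(y_linear.min())) / 200

    # Step 2: apply the exponential to make the targets non-linear
    y_nonlinear = np.expm1(y_shifted)

    # log1p is the inverse of expm1, so it recovers the shifted targets
    y_recovered = np.log1p(y_nonlinear)
    print(np.allclose(y_recovered, y_shifted))

Using the ``log1p``/``expm1`` pair rather than ``log``/``exp`` keeps the
transformation well defined when a target is exactly zero.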

.. GENERATED FROM PYTHON SOURCE LINES 34-42

.. code-block:: Python

    import numpy as np

    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=10_000, noise=100, random_state=0)
    y = np.expm1((y + abs(y.min())) / 200)
    y_trans = np.log1p(y)








.. GENERATED FROM PYTHON SOURCE LINES 43-45

Below we plot the probability density functions of the target
before and after applying the logarithmic function.

.. GENERATED FROM PYTHON SOURCE LINES 45-67

.. code-block:: Python

    import matplotlib.pyplot as plt

    from sklearn.model_selection import train_test_split

    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, density=True)
    ax0.set_xlim([0, 2000])
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.set_title("Target distribution")

    ax1.hist(y_trans, bins=100, density=True)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.set_title("Transformed target distribution")

    f.suptitle("Synthetic data", y=1.05)
    plt.tight_layout()

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)




.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :alt: Synthetic data, Target distribution, Transformed target distribution
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 68-73

First, a linear model is fitted on the original targets. Because of the
non-linearity, this model is not accurate at prediction time. Then, a
logarithmic function is used to linearize the targets, which allows better
predictions with the same kind of linear model, as reported by the median
absolute error (MedAE).
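
To make the mechanism concrete, the sketch below compares
:class:`~sklearn.compose.TransformedTargetRegressor` with the manual
equivalent: fit on ``np.log1p(y)`` and apply ``np.expm1`` to the predictions.
It uses a smaller dataset and a plain ``LinearRegression`` for brevity; these
choices and the variable names are assumptions made for illustration only.

.. code-block:: Python

    import numpy as np

    from sklearn.compose import TransformedTargetRegressor
    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression

    # Small synthetic dataset built with the same recipe as above
    X_demo, y_demo = make_regression(
        n_samples=200, n_features=5, noise=10, random_state=0
    )
    y_demo = np.expm1((y_demo + abs(y_demo.min())) / 200)

    # Manual approach: transform the targets, fit, then invert the predictions
    manual = LinearRegression().fit(X_demo, np.log1p(y_demo))
    y_pred_manual = np.expm1(manual.predict(X_demo))

    # Same steps delegated to TransformedTargetRegressor
    ttr = TransformedTargetRegressor(
        regressor=LinearRegression(), func=np.log1p, inverse_func=np.expm1
    ).fit(X_demo, y_demo)
    y_pred_ttr = ttr.predict(X_demo)

    print(np.allclose(y_pred_manual, y_pred_ttr))

The meta-estimator mainly saves bookkeeping: the transformation and its
inverse travel with the model, so callers always work in the original target
units.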

.. GENERATED FROM PYTHON SOURCE LINES 73-83

.. code-block:: Python

    from sklearn.metrics import median_absolute_error, r2_score


    def compute_score(y_true, y_pred):
        return {
            "R2": f"{r2_score(y_true, y_pred):.3f}",
            "MedAE": f"{median_absolute_error(y_true, y_pred):.3f}",
        }









.. GENERATED FROM PYTHON SOURCE LINES 84-124

.. code-block:: Python

    from sklearn.compose import TransformedTargetRegressor
    from sklearn.linear_model import RidgeCV
    from sklearn.metrics import PredictionErrorDisplay

    f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)

    ridge_cv = RidgeCV().fit(X_train, y_train)
    y_pred_ridge = ridge_cv.predict(X_test)

    ridge_cv_with_trans_target = TransformedTargetRegressor(
        regressor=RidgeCV(), func=np.log1p, inverse_func=np.expm1
    ).fit(X_train, y_train)
    y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test)

    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="actual_vs_predicted",
        ax=ax0,
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="actual_vs_predicted",
        ax=ax1,
        scatter_kwargs={"alpha": 0.5},
    )

    # Add the score in the legend of each axis
    for ax, y_pred in zip([ax0, ax1], [y_pred_ridge, y_pred_ridge_with_trans_target]):
        for name, score in compute_score(y_test, y_pred).items():
            ax.plot([], [], " ", label=f"{name}={score}")
        ax.legend(loc="upper left")

    ax0.set_title("Ridge regression \n without target transformation")
    ax1.set_title("Ridge regression \n with target transformation")
    f.suptitle("Synthetic data", y=1.05)
    plt.tight_layout()




.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :alt: Synthetic data, Ridge regression   without target transformation, Ridge regression   with target transformation
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 125-131

Real-world data set
####################

In a similar manner, the Ames housing data set is used to show the impact
of transforming the targets before learning a model. In this example, the
target to be predicted is the selling price of each house.

.. GENERATED FROM PYTHON SOURCE LINES 131-145

.. code-block:: Python

    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import quantile_transform

    ames = fetch_openml(name="house_prices", as_frame=True)
    # Keep only numeric columns
    X = ames.data.select_dtypes(np.number)
    # Remove columns with NaN or Inf values
    X = X.drop(columns=["LotFrontage", "GarageYrBlt", "MasVnrArea"])
    # Let the price be in k$
    y = ames.target / 1000
    y_trans = quantile_transform(
        y.to_frame(), n_quantiles=900, output_distribution="normal", copy=True
    ).squeeze()








.. GENERATED FROM PYTHON SOURCE LINES 146-149

A :class:`~sklearn.preprocessing.QuantileTransformer` is used to normalize
the target distribution before applying a
:class:`~sklearn.linear_model.RidgeCV` model.
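
The following sketch illustrates what this normalization does on a small,
artificially skewed sample. The log-normal data, the ``n_quantiles`` value and
the variable names are assumptions chosen for illustration; they do not come
from the Ames data.

.. code-block:: Python

    import numpy as np

    from sklearn.preprocessing import QuantileTransformer

    # Hypothetical right-skewed targets, shaped as a single column
    rng = np.random.RandomState(0)
    y_skewed = rng.lognormal(mean=0.0, sigma=1.0, size=(1000, 1))

    qt = QuantileTransformer(
        n_quantiles=200, output_distribution="normal", random_state=0
    )

    # Map the empirical quantiles onto those of a standard normal distribution
    y_normal = qt.fit_transform(y_skewed)

    # The fitted transformer can map values back to the original scale
    y_restored = qt.inverse_transform(y_normal)

    print(f"mean={y_normal.mean():.2f}, std={y_normal.std():.2f}")
    print(np.allclose(y_restored, y_skewed))

Because the transformer is invertible on the training range, predictions made
in the transformed space can be expressed again in the original units.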

.. GENERATED FROM PYTHON SOURCE LINES 149-164

.. code-block:: Python

    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, density=True)
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.set_title("Target distribution")

    ax1.hist(y_trans, bins=100, density=True)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.set_title("Transformed target distribution")

    f.suptitle("Ames housing data: selling price", y=1.05)
    plt.tight_layout()




.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_003.png
   :alt: Ames housing data: selling price, Target distribution, Transformed target distribution
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_003.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 165-167

.. code-block:: Python

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)








.. GENERATED FROM PYTHON SOURCE LINES 168-175

The effect of the transformer is weaker than on the synthetic data. However,
the transformation results in an increase in :math:`R^2` and a large decrease
in the MedAE. The residual plot (predicted target - true target vs. predicted
target) without target transformation takes on a curved, 'reverse smile'
shape because the residuals vary with the value of the predicted target. With
target transformation, the shape is more linear, indicating a better model
fit.
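
For reference, the two scores discussed above can be reproduced by hand. The
sketch below uses a tiny made-up pair of arrays (an assumption for
illustration, unrelated to the Ames data) and checks the manual formulas
against the scikit-learn functions.

.. code-block:: Python

    import numpy as np

    from sklearn.metrics import median_absolute_error, r2_score

    # Hypothetical true and predicted targets
    y_true_demo = np.array([3.0, -0.5, 2.0, 7.0])
    y_pred_demo = np.array([2.5, 0.0, 2.0, 8.0])

    # MedAE: median of the absolute errors, robust to a few large errors
    medae_manual = np.median(np.abs(y_true_demo - y_pred_demo))

    # R^2: one minus the ratio of residual to total sum of squares
    ss_res = np.sum((y_true_demo - y_pred_demo) ** 2)
    ss_tot = np.sum((y_true_demo - y_true_demo.mean()) ** 2)
    r2_manual = 1 - ss_res / ss_tot

    print(np.isclose(medae_manual, median_absolute_error(y_true_demo, y_pred_demo)))
    print(np.isclose(r2_manual, r2_score(y_true_demo, y_pred_demo)))

Since the MedAE is expressed in the units of the target, it reads here as a
typical error in thousands of dollars.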

.. GENERATED FROM PYTHON SOURCE LINES 175-234

.. code-block:: Python

    from sklearn.preprocessing import QuantileTransformer

    f, (ax0, ax1) = plt.subplots(2, 2, sharey="row", figsize=(6.5, 8))

    ridge_cv = RidgeCV().fit(X_train, y_train)
    y_pred_ridge = ridge_cv.predict(X_test)

    ridge_cv_with_trans_target = TransformedTargetRegressor(
        regressor=RidgeCV(),
        transformer=QuantileTransformer(n_quantiles=900, output_distribution="normal"),
    ).fit(X_train, y_train)
    y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test)

    # plot the actual vs predicted values
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="actual_vs_predicted",
        ax=ax0[0],
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="actual_vs_predicted",
        ax=ax0[1],
        scatter_kwargs={"alpha": 0.5},
    )

    # Add the score in the legend of each axis
    for ax, y_pred in zip([ax0[0], ax0[1]], [y_pred_ridge, y_pred_ridge_with_trans_target]):
        for name, score in compute_score(y_test, y_pred).items():
            ax.plot([], [], " ", label=f"{name}={score}")
        ax.legend(loc="upper left")

    ax0[0].set_title("Ridge regression \n without target transformation")
    ax0[1].set_title("Ridge regression \n with target transformation")

    # plot the residuals vs the predicted values
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="residual_vs_predicted",
        ax=ax1[0],
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="residual_vs_predicted",
        ax=ax1[1],
        scatter_kwargs={"alpha": 0.5},
    )
    ax1[0].set_title("Ridge regression \n without target transformation")
    ax1[1].set_title("Ridge regression \n with target transformation")

    f.suptitle("Ames housing data: selling price", y=1.05)
    plt.tight_layout()
    plt.show()



.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_004.png
   :alt: Ames housing data: selling price, Ridge regression   without target transformation, Ridge regression   with target transformation, Ridge regression   without target transformation, Ridge regression   with target transformation
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_004.png
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.402 seconds)


.. _sphx_glr_download_auto_examples_compose_plot_transformed_target.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.4.X?urlpath=lab/tree/notebooks/auto_examples/compose/plot_transformed_target.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/?path=auto_examples/compose/plot_transformed_target.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_transformed_target.ipynb <plot_transformed_target.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_transformed_target.py <plot_transformed_target.py>`


.. include:: plot_transformed_target.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_