.. _sphx_glr_auto_examples_miscellaneous_plot_set_output.py:

================================
Introducing the `set_output` API
================================

.. currentmodule:: sklearn

This example demonstrates the `set_output` API, which configures transformers to
output pandas DataFrames. `set_output` can be configured per estimator by calling
the `set_output` method, or globally by setting
`set_config(transform_output="pandas")`. For details, see
`SLEP018 `__.
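
The snippet below is a minimal sketch of the two configuration routes. It uses a toy DataFrame invented for illustration and assumes scikit-learn 1.2 or later. It configures one scaler locally, then all transformers globally. It also suggests that an explicit per-estimator choice, including `transform="default"`, takes precedence over the global setting.

.. code-block:: Python

    import numpy as np
    import pandas as pd
    from sklearn import set_config
    from sklearn.preprocessing import StandardScaler

    # Hypothetical toy data, used only to inspect the output container type.
    X = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 40.0]})

    # Route 1: configure a single estimator with its `set_output` method.
    local_out = StandardScaler().set_output(transform="pandas").fit_transform(X)

    # Route 2: configure all transformers globally, then restore the default.
    set_config(transform_output="pandas")
    try:
        global_out = StandardScaler().fit_transform(X)
        # An explicit per-estimator choice overrides the global setting.
        opted_out = StandardScaler().set_output(transform="default").fit_transform(X)
    finally:
        set_config(transform_output="default")

    # With no configuration at all, transformers return NumPy arrays.
    plain_out = StandardScaler().fit_transform(X)

    print("local:", type(local_out).__name__)      # local: DataFrame
    print("global:", type(global_out).__name__)    # global: DataFrame
    print("opted out:", type(opted_out).__name__)  # opted out: ndarray
    print("plain:", type(plain_out).__name__)      # plain: ndarray

The `try`/`finally` block keeps the global setting from leaking into code that runs afterwards.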

.. GENERATED FROM PYTHON SOURCE LINES 16-17

First, we load the iris dataset as a DataFrame to demonstrate the `set_output`
API.

.. GENERATED FROM PYTHON SOURCE LINES 17-24

.. code-block:: Python

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    X, y = load_iris(as_frame=True, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
    X_train.head()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

         sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
    60                 5.0               2.0                3.5               1.0
    1                  4.9               3.0                1.4               0.2
    8                  4.4               2.9                1.4               0.2
    93                 5.0               2.3                3.3               1.0
    106                4.9               2.5                4.5               1.7

.. GENERATED FROM PYTHON SOURCE LINES 25-27

To configure an estimator such as :class:`preprocessing.StandardScaler` to return
DataFrames, call `set_output`. This feature requires pandas to be installed.

.. GENERATED FROM PYTHON SOURCE LINES 27-36

.. code-block:: Python

    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler().set_output(transform="pandas")

    scaler.fit(X_train)
    X_test_scaled = scaler.transform(X_test)
    X_test_scaled.head()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

        sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
    39          -0.894264          0.798301          -1.271411         -1.327605
    12          -1.244466         -0.086944          -1.327407         -1.459074
    48          -0.660797          1.462234          -1.271411         -1.327605
    23          -0.894264          0.576989          -1.159419         -0.933197
    81          -0.427329         -1.414810          -0.039497         -0.275851

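As far as we can tell, the returned DataFrame takes its column names from `get_feature_names_out()` and its index from the input DataFrame. This is why the output above keeps the row labels of `X_test`. The following check is a minimal sketch on a small, hypothetical DataFrame with a non-default index:

.. code-block:: Python

    import numpy as np
    import pandas as pd
    from sklearn.preprocessing import StandardScaler

    # Hypothetical data with a non-default index, so index handling is visible.
    X = pd.DataFrame(
        {"height": [1.5, 1.7, 1.9], "weight": [50.0, 70.0, 90.0]},
        index=[10, 20, 30],
    )

    scaler = StandardScaler().set_output(transform="pandas")
    X_scaled = scaler.fit_transform(X)

    # Column names match get_feature_names_out(); the index is taken from X.
    print(list(X_scaled.columns))                # ['height', 'weight']
    print(list(scaler.get_feature_names_out()))  # ['height', 'weight']
    print(list(X_scaled.index))                  # [10, 20, 30]

Because the index is preserved, the scaled rows can be aligned with other pandas objects that share the same index.
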
.. GENERATED FROM PYTHON SOURCE LINES 37-38

`set_output` can be called after `fit` to configure `transform` after the fact.

.. GENERATED FROM PYTHON SOURCE LINES 38-48

.. code-block:: Python

    scaler2 = StandardScaler()

    scaler2.fit(X_train)
    X_test_np = scaler2.transform(X_test)
    print(f"Default output type: {type(X_test_np).__name__}")

    scaler2.set_output(transform="pandas")
    X_test_df = scaler2.transform(X_test)
    print(f"Configured pandas output type: {type(X_test_df).__name__}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Default output type: ndarray
    Configured pandas output type: DataFrame

.. GENERATED FROM PYTHON SOURCE LINES 49-51

In a :class:`pipeline.Pipeline`, `set_output` configures all steps to output
DataFrames.

.. GENERATED FROM PYTHON SOURCE LINES 51-61

.. code-block:: Python

    from sklearn.feature_selection import SelectPercentile
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline

    clf = make_pipeline(
        StandardScaler(), SelectPercentile(percentile=75), LogisticRegression()
    )
    clf.set_output(transform="pandas")
    clf.fit(X_train, y_train)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('standardscaler', StandardScaler()),
                    ('selectpercentile', SelectPercentile(percentile=75)),
                    ('logisticregression', LogisticRegression())])

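
The following is a minimal sketch of what the pipeline-level call does. It assumes scikit-learn 1.2 or later, uses the full iris data instead of the train split, and the variable name `pipe` is ours. `Pipeline.set_output` forwards the setting to the steps that can transform, so `pipe[:-1]`, a slice of the fitted pipeline, returns the intermediate DataFrame. The last lines suggest that calling `set_output` on the pipeline again after `set_params` also configures the newly inserted step.

.. code-block:: Python

    import pandas as pd
    from sklearn.datasets import load_iris
    from sklearn.feature_selection import SelectPercentile
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X, y = load_iris(as_frame=True, return_X_y=True)

    pipe = make_pipeline(
        StandardScaler(), SelectPercentile(percentile=75), LogisticRegression()
    )
    # Forwarded to every step that can transform; LogisticRegression is skipped.
    pipe.set_output(transform="pandas")
    pipe.fit(X, y)

    # Slicing a fitted pipeline reuses the fitted steps: scaler + selector only.
    X_selected = pipe[:-1].transform(X)
    print(type(X_selected).__name__, X_selected.shape)  # DataFrame (150, 3)

    # A step swapped in with set_params starts with the default output format;
    # calling set_output on the pipeline again re-applies the pandas setting.
    pipe.set_params(standardscaler=StandardScaler())
    pipe.set_output(transform="pandas")
    pipe.fit(X, y)
    print(list(pipe[-1].feature_names_in_))
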
.. GENERATED FROM PYTHON SOURCE LINES 62-64

Each transformer in the pipeline is configured to return DataFrames. This means
that the final logistic regression step records the names of its input features
in `feature_names_in_`.

.. GENERATED FROM PYTHON SOURCE LINES 64-66

.. code-block:: Python

    clf[-1].feature_names_in_


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array(['sepal length (cm)', 'petal length (cm)', 'petal width (cm)'],
          dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 67-69

.. note::
    If one uses the method `set_params`, the transformer will be
    replaced by a new one with the default output format.

.. GENERATED FROM PYTHON SOURCE LINES 69-73

.. code-block:: Python

    clf.set_params(standardscaler=StandardScaler())
    clf.fit(X_train, y_train)
    clf[-1].feature_names_in_


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array(['x0', 'x2', 'x3'], dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 74-76

To keep the intended behavior, call `set_output` on the new transformer before
passing it to `set_params`:

.. GENERATED FROM PYTHON SOURCE LINES 76-81

.. code-block:: Python

    scaler = StandardScaler().set_output(transform="pandas")
    clf.set_params(standardscaler=scaler)
    clf.fit(X_train, y_train)
    clf[-1].feature_names_in_


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array(['sepal length (cm)', 'petal length (cm)', 'petal width (cm)'],
          dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 82-84

Next, we load the Titanic dataset to demonstrate `set_output` with
:class:`compose.ColumnTransformer` and heterogeneous data.

.. GENERATED FROM PYTHON SOURCE LINES 84-89

.. code-block:: Python

    from sklearn.datasets import fetch_openml

    X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

.. GENERATED FROM PYTHON SOURCE LINES 90-92

The `set_output` API can be configured globally by using :func:`set_config` and
setting `transform_output` to `"pandas"`.

.. GENERATED FROM PYTHON SOURCE LINES 92-118

.. code-block:: Python

    from sklearn import set_config
    from sklearn.compose import ColumnTransformer
    from sklearn.impute import SimpleImputer
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    set_config(transform_output="pandas")

    num_pipe = make_pipeline(SimpleImputer(), StandardScaler())
    num_cols = ["age", "fare"]
    ct = ColumnTransformer(
        (
            ("numerical", num_pipe, num_cols),
            (
                "categorical",
                OneHotEncoder(
                    sparse_output=False, drop="if_binary", handle_unknown="ignore"
                ),
                ["embarked", "sex", "pclass"],
            ),
        ),
        verbose_feature_names_out=False,
    )
    clf = make_pipeline(ct, SelectPercentile(percentile=50), LogisticRegression())
    clf.fit(X_train, y_train)
    clf.score(X_test, y_test)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.7621951219512195

.. GENERATED FROM PYTHON SOURCE LINES 119-121

With the global configuration, all transformers output DataFrames. This allows us
to easily plot the logistic regression coefficients with the corresponding feature
names.

.. GENERATED FROM PYTHON SOURCE LINES 121-127

.. code-block:: Python

    import pandas as pd

    log_reg = clf[-1]
    coef = pd.Series(log_reg.coef_.ravel(), index=log_reg.feature_names_in_)
    _ = coef.sort_values().plot.barh()


.. image-sg:: /auto_examples/miscellaneous/images/sphx_glr_plot_set_output_001.png
   :alt: plot set output
   :srcset: /auto_examples/miscellaneous/images/sphx_glr_plot_set_output_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 128-130

To demonstrate the :func:`config_context` functionality below, we first reset
`transform_output` to its default value.

.. GENERATED FROM PYTHON SOURCE LINES 130-132

.. code-block:: Python

    set_config(transform_output="default")

.. GENERATED FROM PYTHON SOURCE LINES 133-137

When configuring the output type with :func:`config_context`, what counts is the
configuration in effect when `transform` or `fit_transform` is called. Setting it
only when you construct or fit the transformer has no effect.

.. GENERATED FROM PYTHON SOURCE LINES 137-142

.. code-block:: Python

    from sklearn import config_context

    scaler = StandardScaler()
    scaler.fit(X_train[num_cols])

.. GENERATED FROM PYTHON SOURCE LINES 143-148

.. code-block:: Python

    with config_context(transform_output="pandas"):
        # the output of transform will be a pandas DataFrame
        X_test_scaled = scaler.transform(X_test[num_cols])
    X_test_scaled.head()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

                    age       fare
    1088   0.151101  -0.479229
    1001        NaN  -0.188153
    660   -0.393297  -0.263234
    657   -1.975455  -0.263234
    285    2.532843   3.546068

.. GENERATED FROM PYTHON SOURCE LINES 149-150

Outside of the context manager, the output is a NumPy array:

.. GENERATED FROM PYTHON SOURCE LINES 150-152

.. code-block:: Python

    X_test_scaled = scaler.transform(X_test[num_cols])
    X_test_scaled[:5]




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([[ 0.1511007 , -0.47922861],
           [        nan, -0.18815268],
           [-0.39329747, -0.26323428],
           [-1.97545464, -0.26323428],
           [ 2.53284267,  3.54606834]])
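
The following is a small sketch that checks the point made above: fitting inside `config_context` is not enough on its own. It assumes scikit-learn 1.2 or later. The two-column DataFrame is hypothetical and only stands in for `X_train[num_cols]`.

.. code-block:: Python

    import numpy as np
    import pandas as pd
    from sklearn import config_context
    from sklearn.preprocessing import StandardScaler

    # Hypothetical numeric columns, standing in for X_train[num_cols].
    X = pd.DataFrame({"age": [20.0, 30.0, 40.0], "fare": [10.0, 20.0, 60.0]})

    # Fitting inside the context does not store the pandas setting on the scaler.
    with config_context(transform_output="pandas"):
        scaler = StandardScaler().fit(X)

    after_context = scaler.transform(X)  # called outside the context

    with config_context(transform_output="pandas"):
        inside_context = scaler.transform(X)  # called inside the context

    print(type(after_context).__name__, type(inside_context).__name__)  # ndarray DataFrame
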