.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_preprocessing_plot_function_transformer.py>`
        to download the full example code or to run this example in your browser via Binder

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_preprocessing_plot_function_transformer.py:


=========================================================
Using FunctionTransformer to select columns
=========================================================

Shows how to use a function transformer in a pipeline. If you know your
dataset's first principal component is irrelevant for a classification task,
you can use the FunctionTransformer to select all but the first column of the
PCA-transformed data.


.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/preprocessing/images/sphx_glr_plot_function_transformer_001.png
          :alt: plot function transformer
          :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/preprocessing/images/sphx_glr_plot_function_transformer_002.png
          :alt: plot function transformer
          :class: sphx-glr-multi-img


.. code-block:: default

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.model_selection import train_test_split
    from sklearn.decomposition import PCA
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import FunctionTransformer


    def _generate_vector(shift=0.5, noise=15):
        # A ramp 0..999 plus uniform noise centred on zero when shift=0.5.
        return np.arange(1000) + (np.random.rand(1000) - shift) * noise


    def generate_dataset():
        """
        This dataset is two lines with a slope ~ 1, where one has
        a y offset of ~100
        """
        # First 1000 samples (label 0) lie on the offset line,
        # the next 1000 samples (label 1) on the line through the origin.
        return np.vstack((
            np.vstack((
                _generate_vector(),
                _generate_vector() + 100,
            )).T,
            np.vstack((
                _generate_vector(),
                _generate_vector(),
            )).T,
        )), np.hstack((np.zeros(1000), np.ones(1000)))


    def all_but_first_column(X):
        # Column selector: drop column 0 and keep the rest.
        return X[:, 1:]


    def drop_first_component(X, y):
        """
        Create a pipeline with PCA and the column selector and use it to
        transform the dataset.
        """
        pipeline = make_pipeline(
            PCA(), FunctionTransformer(all_but_first_column),
        )
        X_train, X_test, y_train, y_test = train_test_split(X, y)
        pipeline.fit(X_train, y_train)
        return pipeline.transform(X_test), y_test


    if __name__ == '__main__':
        # First figure: the raw two-dimensional dataset, coloured by class.
        X, y = generate_dataset()
        lw = 0
        plt.figure()
        plt.scatter(X[:, 0], X[:, 1], c=y, lw=lw)

        # Second figure: the remaining component after the first principal
        # component has been dropped, plotted along a single axis.
        plt.figure()
        X_transformed, y_transformed = drop_first_component(*generate_dataset())
        plt.scatter(
            X_transformed[:, 0],
            np.zeros(len(X_transformed)),
            c=y_transformed,
            lw=lw,
            s=60
        )
        plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.117 seconds)


.. _sphx_glr_download_auto_examples_preprocessing_plot_function_transformer.py:


.. only :: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example


        .. container:: binder-badge

            .. image:: https://mybinder.org/badge_logo.svg
                :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/0.23.X?urlpath=lab/tree/notebooks/auto_examples/preprocessing/plot_function_transformer.ipynb
                :width: 150 px


        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_function_transformer.py <plot_function_transformer.py>`


        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_function_transformer.ipynb <plot_function_transformer.ipynb>`


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_
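
The column selection used in the example can be checked on a much smaller
array. The following is a minimal sketch, not part of the original script:
the toy data ``X_toy`` and the names ``selector``, ``X_selected`` and
``X_reduced`` are assumptions introduced here for illustration. It shows that
the selector on its own drops the first input column, and that after ``PCA``
in a pipeline the dropped column is the first principal component.

.. code-block:: python

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import FunctionTransformer


    def all_but_first_column(X):
        # Same selector as in the example: keep every column except the first.
        return X[:, 1:]


    # Hypothetical toy data (not from the example): 6 samples, 3 features.
    rng = np.random.RandomState(0)
    X_toy = rng.rand(6, 3)

    # Used alone, the transformer just applies the wrapped function,
    # so the result is X_toy without its first column.
    selector = FunctionTransformer(all_but_first_column)
    X_selected = selector.fit_transform(X_toy)
    print(X_selected.shape)

    # Inside a pipeline the selector receives the output of PCA, so the
    # column it drops is the first principal component.
    pipeline = make_pipeline(PCA(), FunctionTransformer(all_but_first_column))
    X_reduced = pipeline.fit_transform(X_toy)
    print(X_reduced.shape)

With three input features, ``PCA()`` keeps three components, so both printed
shapes have two columns left after the selection step.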