.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_column_transformer.py:


==================================================
Column Transformer with Heterogeneous Data Sources
==================================================

Datasets can often contain components that require different feature
extraction and processing pipelines. This scenario might occur when:

1. your dataset consists of heterogeneous data types (e.g. raster images and
   text captions),
2. your dataset is stored in a :class:`pandas.DataFrame` and different columns
   require different processing pipelines.

This example demonstrates how to use
:class:`~sklearn.compose.ColumnTransformer` on a dataset containing
different types of features. The choice of features is not particularly
helpful, but serves to illustrate the technique.

.. code-block:: Python


    # Author: Matt Terry
    #
    # License: BSD 3 clause

    import numpy as np

    from sklearn.compose import ColumnTransformer
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.decomposition import PCA
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics import classification_report
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer
    from sklearn.svm import LinearSVC

20 newsgroups dataset
---------------------

We will use the :ref:`20 newsgroups dataset <20newsgroups_dataset>`, which
comprises posts from newsgroups on 20 topics.
This dataset is split into train and test subsets based on messages posted
before and after a specific date. We will only use posts from 2 categories
to speed up running time.

.. code-block:: Python


    categories = ["sci.med", "sci.space"]
    X_train, y_train = fetch_20newsgroups(
        random_state=1,
        subset="train",
        categories=categories,
        remove=("footers", "quotes"),
        return_X_y=True,
    )

    X_test, y_test = fetch_20newsgroups(
        random_state=1,
        subset="test",
        categories=categories,
        remove=("footers", "quotes"),
        return_X_y=True,
    )

Each sample is the raw text of one post: meta information (headers such as
the subject) followed by the body of the news post.

.. code-block:: Python


    print(X_train[0])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    From: mccall@mksol.dseg.ti.com (fred j mccall 575-3539)
    Subject: Re: Metric vs English
    Article-I.D.: mksol.1993Apr6.131900.8407
    Organization: Texas Instruments Inc
    Lines: 31

    American, perhaps, but nothing military about it. I learned (mostly)
    slugs when we talked English units in high school physics and while
    the teacher was an ex-Navy fighter jock the book certainly wasn't
    produced by the military.

    [Poundals were just too flinking small and made the math come out
    funny; sort of the same reason proponents of SI give for using that.]

    --
    "Insisting on perfect safety is for people who don't have the balls to live
    in the real world." -- Mary Shafer, NASA Ames Dryden

Creating transformers
---------------------

First, we would like a transformer that extracts the subject and body of each
post. Since this is a stateless transformation (it does not require any
information learned from the training data), we can define a function that
performs the data transformation and then use
:class:`~sklearn.preprocessing.FunctionTransformer` to create a scikit-learn
transformer.

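
To make the statelessness concrete, here is a minimal sketch on made-up posts.
It wraps a hypothetical helper, ``first_line`` (not part of this example), in
``FunctionTransformer`` and checks that fitting on one set of posts does not
change how another set is transformed.

.. code-block:: Python

    import numpy as np
    from sklearn.preprocessing import FunctionTransformer


    def first_line(posts):
        # Hypothetical helper, not part of this example: keep only the first
        # line of each post. It uses nothing learned from training data.
        return np.array([text.split("\n", 1)[0] for text in posts], dtype=object)


    # Made-up posts: two to fit on and one to transform.
    toy_train = ["Subject: orbit question\n\nHow high is LEO?", "Subject: flu\n\nFever."]
    toy_test = ["Subject: Metric vs English\n\nSlugs, not poundals."]

    # With the default validate=False, the list of strings is passed to
    # first_line unchanged.
    transformer = FunctionTransformer(first_line)

    # fit() learns no parameters from toy_train, so transform() returns exactly
    # what first_line() returns on toy_test.
    out = transformer.fit(toy_train).transform(toy_test)
    print(out)

Because nothing is learned in ``fit``, the same wrapped function can be placed
anywhere in a :class:`~sklearn.pipeline.Pipeline` without risking leakage of
training-set statistics.
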

.. code-block:: Python


    def subject_body_extractor(posts):
        # construct object dtype array with two columns
        # first column = 'subject' and second column = 'body'
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            # temporary variable `_` stores '\n\n'
            headers, _, body = text.partition("\n\n")
            # store body text in second column
            features[i, 1] = body

            prefix = "Subject:"
            sub = ""
            # save text after 'Subject:' in first column
            for line in headers.split("\n"):
                if line.startswith(prefix):
                    sub = line[len(prefix) :]
                    break
            features[i, 0] = sub

        return features


    subject_body_transformer = FunctionTransformer(subject_body_extractor)

We will also create a transformer that extracts the length of the text and
the number of sentences (approximated by counting periods).

.. code-block:: Python


    def text_stats(posts):
        return [{"length": len(text), "num_sentences": text.count(".")} for text in posts]


    text_stats_transformer = FunctionTransformer(text_stats)

Classification pipeline
-----------------------

The pipeline below extracts the subject and body from each post using
``subject_body_transformer``, producing a (n_samples, 2) array. This array is
then passed to a ``ColumnTransformer``, which computes TF-IDF bag-of-words
features for the subject and for the body (the latter reduced to 50
components with PCA), as well as the text length and number of sentences of
the body. These features are combined with per-branch weights
(``transformer_weights``), and a classifier is then trained on the combined
set of features.

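
Before building the full pipeline, it may help to see how ``ColumnTransformer``
routes columns on a toy input. The sketch below uses three made-up rows shaped
like the output of ``subject_body_extractor`` and only two TF-IDF branches
instead of the three used in the pipeline. Setting ``sparse_threshold=0`` is an
assumption made only so that the result is a dense array that is easy to
inspect.

.. code-block:: Python

    import numpy as np
    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction.text import TfidfVectorizer

    # Made-up rows shaped like the output of subject_body_extractor:
    # column 0 holds the subject, column 1 holds the body (object dtype).
    X_toy = np.array(
        [
            ["metric units", "slugs and poundals in physics class"],
            ["shuttle launch", "the launch was delayed by weather"],
            ["doctor visit", "the doctor prescribed rest"],
        ],
        dtype=object,
    )

    ct = ColumnTransformer(
        [
            # A scalar column index (0, not [0]) passes a 1-D array of strings,
            # which is the input TfidfVectorizer expects.
            ("subject", TfidfVectorizer(), 0),
            ("body", TfidfVectorizer(), 1),
        ],
        # Each branch's output is multiplied by its weight before stacking.
        transformer_weights={"subject": 0.8, "body": 0.5},
        # Assumption for this sketch only: always return a dense array.
        sparse_threshold=0,
    )

    Xt = ct.fit_transform(X_toy)
    n_subject = len(ct.named_transformers_["subject"].vocabulary_)
    n_body = len(ct.named_transformers_["body"].vocabulary_)

    # Output columns: subject features first, then body features.
    print(Xt.shape, n_subject, n_body)

    # TfidfVectorizer L2-normalizes each row, so after weighting, the rows of
    # each block have a norm equal to that block's weight.
    subject_norms = np.linalg.norm(Xt[:, :n_subject], axis=1)
    body_norms = np.linalg.norm(Xt[:, n_subject:], axis=1)
    print(subject_norms.round(3), body_norms.round(3))

The same column index may be listed in several branches: in the pipeline, column
1 (the body) feeds both ``body_bow`` and ``body_stats``.
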

.. code-block:: Python


    pipeline = Pipeline(
        [
            # Extract subject & body
            ("subjectbody", subject_body_transformer),
            # Use ColumnTransformer to combine the subject and body features
            (
                "union",
                ColumnTransformer(
                    [
                        # bag-of-words for subject (col 0)
                        ("subject", TfidfVectorizer(min_df=50), 0),
                        # bag-of-words with decomposition for body (col 1)
                        (
                            "body_bow",
                            Pipeline(
                                [
                                    ("tfidf", TfidfVectorizer()),
                                    ("best", PCA(n_components=50, svd_solver="arpack")),
                                ]
                            ),
                            1,
                        ),
                        # Pipeline for pulling text stats from post's body
                        (
                            "body_stats",
                            Pipeline(
                                [
                                    (
                                        "stats",
                                        text_stats_transformer,
                                    ),  # returns a list of dicts
                                    (
                                        "vect",
                                        DictVectorizer(),
                                    ),  # list of dicts -> feature matrix
                                ]
                            ),
                            1,
                        ),
                    ],
                    # weight the ColumnTransformer features defined above
                    transformer_weights={
                        "subject": 0.8,
                        "body_bow": 0.5,
                        "body_stats": 1.0,
                    },
                ),
            ),
            # Use a linear SVM classifier (LinearSVC) on the combined features
            ("svc", LinearSVC(dual=False)),
        ],
        verbose=True,
    )

Finally, we fit our pipeline on the training data and use it to predict topics
for ``X_test``. Performance metrics of our pipeline are then printed.

.. code-block:: Python


    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    print("Classification report:\n\n{}".format(classification_report(y_test, y_pred)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [Pipeline] ....... (step 1 of 3) Processing subjectbody, total=   0.0s
    [Pipeline] ............. (step 2 of 3) Processing union, total=   0.4s
    [Pipeline] ............... (step 3 of 3) Processing svc, total=   0.0s
    Classification report:

                  precision    recall  f1-score   support

               0       0.84      0.87      0.86       396
               1       0.87      0.84      0.85       394

        accuracy                           0.86       790
       macro avg       0.86      0.86      0.86       790
    weighted avg       0.86      0.86      0.86       790


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 2.717 seconds)