.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/text/plot_document_classification_20newsgroups.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_text_plot_document_classification_20newsgroups.py:


======================================================
Classification of text documents using sparse features
======================================================

This example shows how scikit-learn can be used to classify documents by
topic using a bag-of-words approach. The features are stored in a
``scipy.sparse`` matrix, and the example benchmarks various classifiers that
can efficiently handle sparse matrices.

The dataset used in this example is the 20 newsgroups dataset. It is
downloaded automatically the first time the example runs, then cached.

.. code-block:: default


    # Author: Peter Prettenhofer
    #         Olivier Grisel
    #         Mathieu Blondel
    #         Lars Buitinck
    # License: BSD 3 clause


Configuration options for the analysis
--------------------------------------

.. code-block:: default


    # If True, use `HashingVectorizer`; otherwise use `TfidfVectorizer`
    USE_HASHING = False

    # Number of features (columns) produced by `HashingVectorizer`
    N_FEATURES = 2**16

    # Optional feature selection: either False, or an integer giving the
    # number of features to keep with a chi-squared test
    SELECT_CHI2 = False
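
To make the effect of ``USE_HASHING = True`` concrete, here is a minimal
sketch that is not part of the original example; the two documents and the
deliberately tiny ``n_features=2**4`` are assumptions chosen for readability.
``HashingVectorizer`` stores no vocabulary, so ``transform`` works without a
prior ``fit``, and its output always has exactly ``n_features`` columns
(different tokens can collide in the same column).

.. code-block:: python

    import scipy.sparse as sp
    from sklearn.feature_extraction.text import HashingVectorizer

    # Two toy documents, made up for illustration.
    docs = ["the shuttle launch was delayed", "the image file is a tiff"]

    # No fit() is needed: tokens are mapped to columns by a hash function.
    # n_features=2**4 is a toy value; the example itself uses 2**16.
    hasher = HashingVectorizer(n_features=2**4, alternate_sign=False)
    X_hash = hasher.transform(docs)

    print(X_hash.shape)  # (2, 16)
    print(sp.issparse(X_hash))  # True

    # With alternate_sign=False every stored value is non-negative.
    print((X_hash.toarray() >= 0).all())  # True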

Load data from the training set
-------------------------------

We load data from the 20 newsgroups dataset, which comprises around 18,000
newsgroup posts on 20 topics split into two subsets: one for training (or
development) and one for testing (or performance evaluation). Both subsets
are loaded below, restricted to four of the 20 categories.

.. code-block:: default


    from sklearn.datasets import fetch_20newsgroups

    categories = [
        "alt.atheism",
        "talk.religion.misc",
        "comp.graphics",
        "sci.space",
    ]

    data_train = fetch_20newsgroups(
        subset="train", categories=categories, shuffle=True, random_state=42
    )

    data_test = fetch_20newsgroups(
        subset="test", categories=categories, shuffle=True, random_state=42
    )
    print("data loaded")

    # order of labels in `target_names` can be different from `categories`
    target_names = data_train.target_names


    def size_mb(docs):
        return sum(len(s.encode("utf-8")) for s in docs) / 1e6


    data_train_size_mb = size_mb(data_train.data)
    data_test_size_mb = size_mb(data_test.data)

    print(
        "%d documents - %0.3fMB (training set)" % (len(data_train.data), data_train_size_mb)
    )
    print("%d documents - %0.3fMB (test set)" % (len(data_test.data), data_test_size_mb))
    print("%d categories" % len(target_names))


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    data loaded
    2034 documents - 3.980MB (training set)
    1353 documents - 2.867MB (test set)
    4 categories


Vectorize the training and test data
------------------------------------

Extract the target labels of the training and test sets.

.. code-block:: default


    y_train, y_test = data_train.target, data_test.target

Extract features from the training data using a sparse vectorizer.
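
The ``TfidfVectorizer`` settings used in the code that follows can be seen on
a toy corpus. This is a minimal sketch with three made-up documents, not part
of the original example: ``stop_words="english"`` removes ``the``,
``max_df=0.5`` discards any term that occurs in more than half of the
documents (here ``space``, which occurs in all three), and
``sublinear_tf=True`` replaces each raw term frequency ``tf`` with
``1 + log(tf)``.

.. code-block:: python

    from sklearn.feature_extraction.text import TfidfVectorizer

    # Toy corpus, made up for illustration: "space" occurs in every document.
    toy_docs = [
        "space shuttle launch",
        "space station orbit",
        "the space telescope image",
    ]

    # Same settings as the TfidfVectorizer in this example.
    toy_vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5, stop_words="english")
    X_toy = toy_vectorizer.fit_transform(toy_docs)

    toy_vocab = list(toy_vectorizer.get_feature_names_out())
    print(toy_vocab)  # neither "the" nor "space" is in the vocabulary
    print(X_toy.shape)  # (3, 6)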

.. code-block:: default


    from time import time
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.feature_extraction.text import HashingVectorizer

    t0 = time()

    if USE_HASHING:
        vectorizer = HashingVectorizer(
            stop_words="english", alternate_sign=False, n_features=N_FEATURES
        )
        X_train = vectorizer.transform(data_train.data)
    else:
        vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5, stop_words="english")
        X_train = vectorizer.fit_transform(data_train.data)
    duration = time() - t0
    print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration))
    print("n_samples: %d, n_features: %d" % X_train.shape)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    done in 0.447097s at 8.901MB/s
    n_samples: 2034, n_features: 33809


Extract features from the test data using the same vectorizer. Only
``transform`` is called here; with ``TfidfVectorizer``, the test documents
are mapped onto the vocabulary and IDF weights learned from the training set,
which is why both matrices have the same number of features.

.. code-block:: default


    t0 = time()
    X_test = vectorizer.transform(data_test.data)
    duration = time() - t0
    print("done in %fs at %0.3fMB/s" % (duration, data_test_size_mb / duration))
    print("n_samples: %d, n_features: %d" % X_test.shape)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    done in 0.262186s at 10.937MB/s
    n_samples: 1353, n_features: 33809


Map integer feature indices back to the original token strings. This is not
possible with ``HashingVectorizer``, which does not store a vocabulary.

.. code-block:: default


    if USE_HASHING:
        feature_names = None
    else:
        feature_names = vectorizer.get_feature_names_out()

Optionally, keep only the best features according to a chi-squared test;
this step runs only when ``SELECT_CHI2`` is set to an integer.
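
A minimal sketch of what ``SelectKBest(chi2, k=...)`` does, on a made-up
4 x 3 count matrix (not part of the original example): each feature is scored
independently against the class labels, and the ``k`` highest-scoring
features are kept. ``chi2`` requires non-negative inputs, a condition that
the TF-IDF features, and the hashed features with ``alternate_sign=False``,
both satisfy.

.. code-block:: python

    import numpy as np
    from sklearn.feature_selection import SelectKBest, chi2

    # Made-up term counts: 4 documents x 3 terms. Term 0 appears only in
    # class 0, term 1 only in class 1, and term 2 equally in both classes.
    X_counts = np.array(
        [
            [3, 0, 1],
            [2, 0, 1],
            [0, 4, 1],
            [0, 3, 1],
        ]
    )
    y_toy = np.array([0, 0, 1, 1])

    selector = SelectKBest(chi2, k=2)
    X_reduced = selector.fit_transform(X_counts, y_toy)

    # Term 2 scores 0 because it carries no information about the class.
    print(selector.scores_)
    print(selector.get_support())  # [ True  True False]
    print(X_reduced.shape)  # (4, 2)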

.. code-block:: default


    from sklearn.feature_selection import SelectKBest, chi2

    if SELECT_CHI2:
        print("Extracting %d best features by a chi-squared test" % SELECT_CHI2)
        t0 = time()
        ch2 = SelectKBest(chi2, k=SELECT_CHI2)
        X_train = ch2.fit_transform(X_train, y_train)
        X_test = ch2.transform(X_test)
        if feature_names is not None:
            # keep selected feature names
            feature_names = feature_names[ch2.get_support()]
        print("done in %fs" % (time() - t0))
        print()


Benchmark classifiers
---------------------

First, we define small benchmarking utilities. ``benchmark`` fits a
classifier on the training set, times training and prediction, and prints
the test accuracy, a classification report and a confusion matrix. For
models that expose a ``coef_`` attribute, it also prints the dimensionality
and density of ``coef_`` and, when feature names are available, the 10
highest-weighted keywords per class.

.. code-block:: default


    import numpy as np
    from sklearn import metrics
    from sklearn.utils.extmath import density


    def trim(s):
        """Trim string to fit on terminal (assuming 80-column display)"""
        return s if len(s) <= 80 else s[:77] + "..."


    def benchmark(clf):
        print("_" * 80)
        print("Training: ")
        print(clf)
        t0 = time()
        clf.fit(X_train, y_train)
        train_time = time() - t0
        print("train time: %0.3fs" % train_time)

        t0 = time()
        pred = clf.predict(X_test)
        test_time = time() - t0
        print("test time: %0.3fs" % test_time)

        score = metrics.accuracy_score(y_test, pred)
        print("accuracy: %0.3f" % score)

        if hasattr(clf, "coef_"):
            print("dimensionality: %d" % clf.coef_.shape[1])
            print("density: %f" % density(clf.coef_))

            if feature_names is not None:
                print("top 10 keywords per class:")
                for i, label in enumerate(target_names):
                    top10 = np.argsort(clf.coef_[i])[-10:]
                    print(trim("%s: %s" % (label, " ".join(feature_names[top10]))))
            print()

        print("classification report:")
        print(metrics.classification_report(y_test, pred, target_names=target_names))

        print("confusion matrix:")
        print(metrics.confusion_matrix(y_test, pred))

        print()
        clf_descr = str(clf).split("(")[0]
        return clf_descr, score, train_time, test_time
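
The two less obvious helpers in ``benchmark`` can be checked on a tiny
hypothetical coefficient matrix standing in for ``clf.coef_`` (this sketch is
not part of the original example, and ``coef`` and ``toy_names`` are made-up
values). ``density`` returns the fraction of non-zero coefficients, and
``np.argsort(...)[-k:]`` picks the ``k`` largest weights in ascending order,
so the strongest keyword of each class is printed last.

.. code-block:: python

    import numpy as np
    from sklearn.utils.extmath import density

    # Hypothetical coef_ for 2 classes x 5 features (made-up values).
    coef = np.array(
        [
            [0.0, 2.0, 0.0, -1.0, 0.5],
            [1.5, 0.0, 0.0, 0.0, 3.0],
        ]
    )
    toy_names = np.array(["alpha", "beta", "gamma", "delta", "epsilon"])

    # 5 of the 10 coefficients are non-zero.
    print(density(coef))  # 0.5

    # argsort sorts ascending, so the last k indices are the k largest
    # weights, with the largest one last.
    k = 2
    top_keywords = []
    for row in coef:
        top_k = np.argsort(row)[-k:]
        top_keywords.append(" ".join(toy_names[top_k]))

    print(top_keywords)  # ['epsilon beta', 'alpha epsilon']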

We now train 15 different classification models on the training set,
evaluate each of them on the test set, and collect the performance results
of each model.

.. code-block:: default


    from sklearn.feature_selection import SelectFromModel
    from sklearn.linear_model import RidgeClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC
    from sklearn.linear_model import SGDClassifier
    from sklearn.linear_model import Perceptron
    from sklearn.linear_model import PassiveAggressiveClassifier
    from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.neighbors import NearestCentroid
    from sklearn.ensemble import RandomForestClassifier

    results = []
    for clf, name in (
        (RidgeClassifier(tol=1e-2, solver="sag"), "Ridge Classifier"),
        (Perceptron(max_iter=50), "Perceptron"),
        (PassiveAggressiveClassifier(max_iter=50), "Passive-Aggressive"),
        (KNeighborsClassifier(n_neighbors=10), "kNN"),
        (RandomForestClassifier(), "Random forest"),
    ):
        print("=" * 80)
        print(name)
        results.append(benchmark(clf))

    for penalty in ["l2", "l1"]:
        print("=" * 80)
        print("%s penalty" % penalty.upper())
        # Train Liblinear model
        results.append(benchmark(LinearSVC(penalty=penalty, dual=False, tol=1e-3)))

        # Train SGD model
        results.append(benchmark(SGDClassifier(alpha=0.0001, max_iter=50, penalty=penalty)))

    # Train SGD with Elastic Net penalty
    print("=" * 80)
    print("Elastic-Net penalty")
    results.append(
        benchmark(SGDClassifier(alpha=0.0001, max_iter=50, penalty="elasticnet"))
    )

    # Train NearestCentroid without threshold
    print("=" * 80)
    print("NearestCentroid (aka Rocchio classifier)")
    results.append(benchmark(NearestCentroid()))

    # Train sparse Naive Bayes classifiers
    print("=" * 80)
    print("Naive Bayes")
    results.append(benchmark(MultinomialNB(alpha=0.01)))
    results.append(benchmark(BernoulliNB(alpha=0.01)))
    results.append(benchmark(ComplementNB(alpha=0.1)))

    print("=" * 80)
    print("LinearSVC with L1-based feature selection")
    # The smaller C, the stronger the regularization.
    # The more regularization, the more sparsity.
    results.append(
        benchmark(
            Pipeline(
                [
                    (
                        "feature_selection",
                        SelectFromModel(LinearSVC(penalty="l1", dual=False, tol=1e-3)),
                    ),
                    ("classification", LinearSVC(penalty="l2")),
                ]
            )
        )
    )


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    ================================================================================
    Ridge Classifier
    ________________________________________________________________________________
    Training:
    RidgeClassifier(solver='sag', tol=0.01)
    /home/circleci/project/sklearn/linear_model/_ridge.py:830: UserWarning: "sag" solver requires many iterations to fit an intercept with sparse inputs. Either set the solver to "auto" or "sparse_cg", or set a low "tol" and a high "max_iter" (especially if inputs are not standardized).
      warnings.warn(
    train time: 0.180s
    test time: 0.001s
    accuracy: 0.897
    dimensionality: 33809
    density: 1.000000
    top 10 keywords per class:
    alt.atheism: atheist osrhe wingate god okcforum caltech islamic atheism keith...
    comp.graphics: animation video looking card hi 3d thanks file image graphics
    sci.space: dc flight shuttle launch pat moon sci orbit nasa space
    talk.religion.misc: jesus mitre hudson morality biblical 2000 beast mr fbi ch...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.87      0.83      0.85       319
         comp.graphics       0.90      0.98      0.94       389
             sci.space       0.96      0.94      0.95       394
    talk.religion.misc       0.83      0.78      0.80       251

              accuracy                           0.90      1353
             macro avg       0.89      0.88      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[266   9   7  37]
     [  1 381   4   3]
     [  0  22 372   0]
     [ 40  10   6 195]]

    ================================================================================
    Perceptron
    ________________________________________________________________________________
    Training:
    Perceptron(max_iter=50)
    train time: 0.016s
    test time: 0.002s
    accuracy: 0.888
    dimensionality: 33809
    density: 0.255302
    top 10 keywords per class:
    alt.atheism: wingate osrhe freedom lippard alt thing cobb atheists atheism keith
    comp.graphics: siggraph code fractal comp mpeg library pc animation sphere gr...
    sci.space: bruce wpi solar sci funding moon orbit planets dc space
    talk.religion.misc: god morality hudson beast sword fbi 2000 order mr christian

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.86      0.80      0.83       319
         comp.graphics       0.90      0.97      0.94       389
             sci.space       0.95      0.93      0.94       394
    talk.religion.misc       0.79      0.80      0.79       251

              accuracy                           0.89      1353
             macro avg       0.88      0.88      0.88      1353
          weighted avg       0.89      0.89      0.89      1353

    confusion matrix:
    [[256   7   8  48]
     [  0 379   4   6]
     [  7  21 366   0]
     [ 33  12   6 200]]

    ================================================================================
    Passive-Aggressive
    ________________________________________________________________________________
    Training:
    PassiveAggressiveClassifier(max_iter=50)
    train time: 0.030s
    test time: 0.002s
    accuracy: 0.903
    dimensionality: 33809
    density: 0.708428
    top 10 keywords per class:
    alt.atheism: charley cobb osrhe okcforum caltech atheist islamic keith atheis...
    comp.graphics: 3d 42 computer code animation windows tiff file image graphics
    sci.space: pat alaska shuttle launch sci dc nasa moon orbit space
    talk.religion.misc: abortion 666 hudson christians fbi mr morality beast 2000...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.86      0.84      0.85       319
         comp.graphics       0.93      0.98      0.95       389
             sci.space       0.95      0.95      0.95       394
    talk.religion.misc       0.83      0.80      0.81       251

              accuracy                           0.90      1353
             macro avg       0.89      0.89      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[269   5   9  36]
     [  0 380   4   5]
     [  3  17 373   1]
     [ 39   7   5 200]]

    ================================================================================
    kNN
    ________________________________________________________________________________
    Training:
    KNeighborsClassifier(n_neighbors=10)
    train time: 0.001s
    test time: 0.158s
    accuracy: 0.858
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.78      0.90      0.84       319
         comp.graphics       0.89      0.89      0.89       389
             sci.space       0.90      0.91      0.90       394
    talk.religion.misc       0.86      0.67      0.75       251

              accuracy                           0.86      1353
             macro avg       0.86      0.84      0.85      1353
          weighted avg       0.86      0.86      0.86      1353

    confusion matrix:
    [[287   3  11  18]
     [ 14 348  19   8]
     [  7  26 359   2]
     [ 59  13  12 167]]

    ================================================================================
    Random forest
    ________________________________________________________________________________
    Training:
    RandomForestClassifier()
    train time: 1.277s
    test time: 0.079s
    accuracy: 0.838
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.83      0.73      0.78       319
         comp.graphics       0.80      0.97      0.88       389
             sci.space       0.93      0.89      0.91       394
    talk.religion.misc       0.78      0.69      0.73       251

              accuracy                           0.84      1353
             macro avg       0.83      0.82      0.82      1353
          weighted avg       0.84      0.84      0.84      1353

    confusion matrix:
    [[234  29   9  47]
     [  1 377  10   1]
     [  3  41 349   1]
     [ 43  25   9 174]]

    ================================================================================
    L2 penalty
    ________________________________________________________________________________
    Training:
    LinearSVC(dual=False, tol=0.001)
    train time: 0.077s
    test time: 0.001s
    accuracy: 0.900
    dimensionality: 33809
    density: 1.000000
    top 10 keywords per class:
    alt.atheism: rushdie osrhe atheist wingate okcforum caltech islamic atheism k...
    comp.graphics: code 42 video hi animation thanks 3d file image graphics
    sci.space: planets dc pat shuttle launch sci moon nasa orbit space
    talk.religion.misc: abortion hudson 666 biblical 2000 morality mr beast fbi c...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.87      0.83      0.85       319
         comp.graphics       0.91      0.98      0.95       389
             sci.space       0.96      0.95      0.95       394
    talk.religion.misc       0.83      0.79      0.81       251

              accuracy                           0.90      1353
             macro avg       0.89      0.89      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[266   7   8  38]
     [  2 381   3   3]
     [  1  20 373   0]
     [ 38   9   6 198]]

    ________________________________________________________________________________
    Training:
    SGDClassifier(max_iter=50)
    train time: 0.023s
    test time: 0.002s
    accuracy: 0.899
    dimensionality: 33809
    density: 0.570514
    top 10 keywords per class:
    alt.atheism: charley rushdie cobb caltech okcforum wingate islamic keith athe...
    comp.graphics: tiff video hi 42 code 3d animation file image graphics
    sci.space: shuttle pat planets launch sci dc moon nasa orbit space
    talk.religion.misc: hudson abortion 666 biblical beast mr morality fbi 2000 c...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.87      0.83      0.85       319
         comp.graphics       0.92      0.97      0.94       389
             sci.space       0.95      0.94      0.95       394
    talk.religion.misc       0.81      0.81      0.81       251

              accuracy                           0.90      1353
             macro avg       0.89      0.89      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[264   5   9  41]
     [  2 377   5   5]
     [  2  20 371   1]
     [ 34   8   5 204]]

    ================================================================================
    L1 penalty
    ________________________________________________________________________________
    Training:
    LinearSVC(dual=False, penalty='l1', tol=0.001)
    train time: 0.205s
    test time: 0.001s
    accuracy: 0.873
    dimensionality: 33809
    density: 0.005561
    top 10 keywords per class:
    alt.atheism: benedikt rice rushdie wingate islamic bmd atheism wwc keith athe...
    comp.graphics: sphere virtual 42 files windows hi image 3d 3do graphics
    sci.space: pat henry sunrise rockets dc launch flight moon orbit space
    talk.religion.misc: hudson thyagi biblical 2000 abortion kendig hare mitre ch...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.85      0.75      0.80       319
         comp.graphics       0.89      0.97      0.93       389
             sci.space       0.94      0.94      0.94       394
    talk.religion.misc       0.76      0.78      0.77       251

              accuracy                           0.87      1353
             macro avg       0.86      0.86      0.86      1353
          weighted avg       0.87      0.87      0.87      1353

    confusion matrix:
    [[238  14  11  56]
     [  0 378   7   4]
     [  2  22 369   1]
     [ 39  12   4 196]]

    ________________________________________________________________________________
    Training:
    SGDClassifier(max_iter=50, penalty='l1')
    train time: 0.087s
    test time: 0.002s
    accuracy: 0.885
    dimensionality: 33809
    density: 0.023019
    top 10 keywords per class:
    alt.atheism: charley atheist rice psilink wingate rushdie keith islamic athei...
    comp.graphics: windows 3d video animation files hi pov 3do image graphics
    sci.space: sunrise sci nasa flight launch dc pat moon orbit space
    talk.religion.misc: 2000 abortion 666 homosexuality biblical hudson mr beast ...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.86      0.79      0.82       319
         comp.graphics       0.93      0.96      0.95       389
             sci.space       0.93      0.96      0.94       394
    talk.religion.misc       0.77      0.77      0.77       251

              accuracy                           0.89      1353
             macro avg       0.87      0.87      0.87      1353
          weighted avg       0.88      0.89      0.88      1353

    confusion matrix:
    [[252   6  12  49]
     [  1 375   8   5]
     [  2  11 378   3]
     [ 39  11   8 193]]

    ================================================================================
    Elastic-Net penalty
    ________________________________________________________________________________
    Training:
    SGDClassifier(max_iter=50, penalty='elasticnet')
    train time: 0.129s
    test time: 0.001s
    accuracy: 0.902
    dimensionality: 33809
    density: 0.187731
    top 10 keywords per class:
    alt.atheism: okcforum caltech cobb atheist rushdie wingate islamic keith athe...
    comp.graphics: video points hi 3do 42 3d file animation image graphics
    sci.space: shuttle pat launch planets sci dc nasa moon orbit space
    talk.religion.misc: order abortion 666 biblical morality mr 2000 fbi beast ch...

    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.86      0.85      0.85       319
         comp.graphics       0.92      0.98      0.95       389
             sci.space       0.96      0.94      0.95       394
    talk.religion.misc       0.83      0.78      0.81       251

              accuracy                           0.90      1353
             macro avg       0.89      0.89      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[270   6   9  34]
     [  0 381   3   5]
     [  1  20 372   1]
     [ 42   7   5 197]]

    ================================================================================
    NearestCentroid (aka Rocchio classifier)
    ________________________________________________________________________________
    Training:
    NearestCentroid()
    train time: 0.005s
    test time: 0.002s
    accuracy: 0.855
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.88      0.69      0.77       319
         comp.graphics       0.84      0.97      0.90       389
             sci.space       0.96      0.92      0.94       394
    talk.religion.misc       0.72      0.79      0.75       251

              accuracy                           0.86      1353
             macro avg       0.85      0.84      0.84      1353
          weighted avg       0.86      0.86      0.85      1353

    confusion matrix:
    [[219  25   5  70]
     [  1 379   5   4]
     [  1  30 361   2]
     [ 29  19   5 198]]

    ================================================================================
    Naive Bayes
    ________________________________________________________________________________
    Training:
    MultinomialNB(alpha=0.01)
    train time: 0.004s
    test time: 0.001s
    accuracy: 0.899
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.85      0.87      0.86       319
         comp.graphics       0.95      0.95      0.95       389
             sci.space       0.92      0.95      0.94       394
    talk.religion.misc       0.86      0.77      0.81       251

              accuracy                           0.90      1353
             macro avg       0.89      0.89      0.89      1353
          weighted avg       0.90      0.90      0.90      1353

    confusion matrix:
    [[279   2   8  30]
     [  2 369  16   2]
     [  3  15 376   0]
     [ 45   4   9 193]]

    ________________________________________________________________________________
    Training:
    BernoulliNB(alpha=0.01)
    train time: 0.005s
    test time: 0.004s
    accuracy: 0.884
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.83      0.88      0.86       319
         comp.graphics       0.88      0.96      0.92       389
             sci.space       0.94      0.91      0.92       394
    talk.religion.misc       0.87      0.73      0.79       251

              accuracy                           0.88      1353
             macro avg       0.88      0.87      0.87      1353
          weighted avg       0.88      0.88      0.88      1353

    confusion matrix:
    [[282   9   3  25]
     [  1 373  13   2]
     [  5  31 358   0]
     [ 50  10   8 183]]

    ________________________________________________________________________________
    Training:
    ComplementNB(alpha=0.1)
    train time: 0.004s
    test time: 0.001s
    accuracy: 0.911
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.85      0.89      0.87       319
         comp.graphics       0.95      0.97      0.96       389
             sci.space       0.94      0.97      0.95       394
    talk.religion.misc       0.88      0.75      0.81       251

              accuracy                           0.91      1353
             macro avg       0.90      0.90      0.90      1353
          weighted avg       0.91      0.91      0.91      1353

    confusion matrix:
    [[284   2   8  25]
     [  2 379   7   1]
     [  0  13 381   0]
     [ 49   5   9 188]]

    ================================================================================
    LinearSVC with L1-based feature selection
    ________________________________________________________________________________
    Training:
    Pipeline(steps=[('feature_selection', SelectFromModel(estimator=LinearSVC(dual=False, penalty='l1', tol=0.001))), ('classification', LinearSVC())])
    train time: 0.205s
    test time: 0.003s
    accuracy: 0.880
    classification report:
                        precision    recall  f1-score   support

           alt.atheism       0.84      0.80      0.82       319
         comp.graphics       0.91      0.96      0.93       389
             sci.space       0.93      0.95      0.94       394
    talk.religion.misc       0.81      0.76      0.78       251

              accuracy                           0.88      1353
             macro avg       0.87      0.87      0.87      1353
          weighted avg       0.88      0.88      0.88      1353

    confusion matrix:
    [[254  11  13  41]
     [  2 374   9   4]
     [  2  18 373   1]
     [ 44   9   8 190]]


Add plots
---------

The bar plot shows the accuracy, the training time and the test time of each
classifier. Both times are normalized by dividing by the largest time
measured across all classifiers, so the slowest classifier gets a bar of
length 1; the accuracy is plotted unscaled.
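
The transpose-and-normalize step at the top of the plotting code can be tried
in isolation. This minimal sketch is not part of the original example, and
the ``toy_results`` tuples are hypothetical placeholders shaped like the
``(name, score, train_time, test_time)`` values returned by ``benchmark``,
not measured results.

.. code-block:: python

    import numpy as np

    # Hypothetical (name, score, train_time, test_time) tuples in the format
    # returned by benchmark(); the values are placeholders, not measurements.
    toy_results = [
        ("model_a", 0.90, 0.5, 1.0),
        ("model_b", 0.85, 4.0, 2.0),
        ("model_c", 0.88, 2.0, 0.5),
    ]

    # Transpose the list of 4-tuples into four parallel sequences. This is
    # equivalent to [[x[i] for x in results] for i in range(4)], except that
    # zip() yields tuples instead of lists.
    clf_names, score, training_time, test_time = zip(*toy_results)

    # Scale each time by its maximum so the slowest classifier gets 1.0.
    training_time = np.array(training_time) / np.max(training_time)
    test_time = np.array(test_time) / np.max(test_time)

    print(clf_names)  # ('model_a', 'model_b', 'model_c')
    print(training_time.tolist())  # [0.125, 1.0, 0.5]
    print(test_time.tolist())  # [0.5, 1.0, 0.25]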

.. code-block:: default


    import matplotlib.pyplot as plt

    indices = np.arange(len(results))

    results = [[x[i] for x in results] for i in range(4)]

    clf_names, score, training_time, test_time = results
    training_time = np.array(training_time) / np.max(training_time)
    test_time = np.array(test_time) / np.max(test_time)

    plt.figure(figsize=(12, 8))
    plt.title("Score")
    plt.barh(indices, score, 0.2, label="score", color="navy")
    plt.barh(indices + 0.3, training_time, 0.2, label="training time", color="c")
    plt.barh(indices + 0.6, test_time, 0.2, label="test time", color="darkorange")
    plt.yticks(())
    plt.legend(loc="best")
    plt.subplots_adjust(left=0.25)
    plt.subplots_adjust(top=0.95)
    plt.subplots_adjust(bottom=0.05)

    for i, c in zip(indices, clf_names):
        plt.text(-0.3, i, c)

    plt.show()



.. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_001.png
   :alt: Score
   :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 3.995 seconds)