Classification of text documents using sparse features

This is an example showing how scikit-learn can be used to classify documents by topics using a bag-of-words approach. This example uses a scipy.sparse matrix to store the features and demonstrates various classifiers that can efficiently handle sparse matrices.

The dataset used in this example is the 20 newsgroups dataset. It will be automatically downloaded, then cached.

# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
#         Olivier Grisel <olivier.grisel@ensta.org>
#         Mathieu Blondel <mathieu@mblondel.org>
#         Lars Buitinck
# License: BSD 3 clause

Configuration options for the analysis

# If True, use `HashingVectorizer` (only `transform` is called on it and no
# token strings are recovered, so the "top keywords" report is skipped);
# otherwise fit a `TfidfVectorizer` on the training documents.
USE_HASHING = False

# Number of features (hash buckets) used by `HashingVectorizer`; it has no
# effect when USE_HASHING is False.
N_FEATURES = 2**16

# Optional chi-squared feature selection: either False (keep every feature),
# or an integer k, the number of best-scoring features to keep.
SELECT_CHI2 = False

Load data from the training set

Let’s load data from the 20 newsgroups dataset, which comprises around 18,000 newsgroup posts on 20 topics split into two subsets: one for training (or development) and the other for testing (or performance evaluation). Only the four categories listed below are loaded here.

from sklearn.datasets import fetch_20newsgroups

categories = [
    "alt.atheism",
    "talk.religion.misc",
    "comp.graphics",
    "sci.space",
]

# Both subsets are fetched with the same category filter and the same seeded
# shuffle; the data is downloaded on first use and cached afterwards.
fetch_options = dict(categories=categories, shuffle=True, random_state=42)
data_train = fetch_20newsgroups(subset="train", **fetch_options)
data_test = fetch_20newsgroups(subset="test", **fetch_options)
print("data loaded")

# The order of the labels in `target_names` can differ from `categories`.
target_names = data_train.target_names


def size_mb(docs):
    """Return the combined UTF-8 encoded size of ``docs`` in megabytes (10**6 bytes)."""
    total_bytes = 0
    for doc in docs:
        total_bytes += len(doc.encode("utf-8"))
    return total_bytes / 1e6


data_train_size_mb = size_mb(data_train.data)
data_test_size_mb = size_mb(data_test.data)

# Summarize how many documents (and how many megabytes of text) each subset has.
n_train_docs = len(data_train.data)
n_test_docs = len(data_test.data)
print(f"{n_train_docs:d} documents - {data_train_size_mb:0.3f}MB (training set)")
print(f"{n_test_docs:d} documents - {data_test_size_mb:0.3f}MB (test set)")
print(f"{len(target_names):d} categories")

Out:

data loaded
2034 documents - 3.980MB (training set)
1353 documents - 2.867MB (test set)
4 categories

Vectorize the training and test data

Extract the target labels of the predefined training and test subsets

# The train/test split is predefined by the dataset; only the labels are pulled out.
y_train = data_train.target
y_test = data_test.target

Extracting features from the training data using a sparse vectorizer

from time import time

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import HashingVectorizer

t0 = time()
if not USE_HASHING:
    # Learn the vocabulary and IDF weights from the training documents.
    vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5, stop_words="english")
    X_train = vectorizer.fit_transform(data_train.data)
else:
    # The hashing vectorizer is only ever transformed: nothing is fitted.
    vectorizer = HashingVectorizer(
        stop_words="english", alternate_sign=False, n_features=N_FEATURES
    )
    X_train = vectorizer.transform(data_train.data)
duration = time() - t0
throughput = data_train_size_mb / duration
print("done in %fs at %0.3fMB/s" % (duration, throughput))
print("n_samples: %d, n_features: %d" % X_train.shape)

Out:

done in 0.390181s at 10.199MB/s
n_samples: 2034, n_features: 33809

Extracting features from the test data using the same vectorizer

t0 = time()
# Reuse the vectorizer set up on the training data so the test matrix has the
# same columns as X_train.
X_test = vectorizer.transform(data_test.data)
duration = time() - t0
throughput = data_test_size_mb / duration
print("done in %fs at %0.3fMB/s" % (duration, throughput))
print("n_samples: %d, n_features: %d" % X_test.shape)

Out:

done in 0.244998s at 11.704MB/s
n_samples: 1353, n_features: 33809

Map integer feature indices back to the original token strings (not available with `HashingVectorizer`, which keeps no vocabulary)

# Only the fitted TF-IDF vectorizer can report its token strings; with hashing
# there is no vocabulary, so the names are left as None.
feature_names = None if USE_HASHING else vectorizer.get_feature_names_out()

Optionally keeping only the best features (chi-squared test)

from sklearn.feature_selection import SelectKBest, chi2

if SELECT_CHI2:
    print("Extracting %d best features by a chi-squared test" % SELECT_CHI2)
    t0 = time()
    # Score features on the training data only, then apply the same column
    # selection to the test data so both matrices stay aligned.
    ch2 = SelectKBest(chi2, k=SELECT_CHI2)
    X_train = ch2.fit_transform(X_train, y_train)
    X_test = ch2.transform(X_test)
    if feature_names is not None:
        # Keep the token names in step with the retained columns.
        kept_mask = ch2.get_support()
        feature_names = feature_names[kept_mask]
    elapsed = time() - t0
    print("done in %fs" % elapsed)
    print()

Benchmark classifiers

First, we define small benchmarking utilities.

import numpy as np
from sklearn import metrics
from sklearn.utils.extmath import density


def trim(s):
    """Shorten ``s`` to at most 80 characters (one terminal line), ending any cut with "..."."""
    max_width = 80
    if len(s) > max_width:
        # Reserve three characters for the ellipsis so the result is exactly max_width long.
        return s[: max_width - 3] + "..."
    return s


def benchmark(clf):
    """Fit ``clf`` on the training split, evaluate it on the test split and print a report.

    Reads the module-level ``X_train``, ``y_train``, ``X_test``, ``y_test``,
    ``feature_names`` and ``target_names`` set up earlier in the script.

    Parameters
    ----------
    clf : estimator
        An unfitted classifier (or pipeline) exposing ``fit`` and ``predict``.

    Returns
    -------
    clf_descr : str
        The text of ``str(clf)`` before the first "(" (the estimator's class name).
    score : float
        Accuracy on the test split.
    train_time, test_time : float
        Seconds spent in ``fit`` and ``predict`` respectively.
    """
    print("_" * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    pred = clf.predict(X_test)
    test_time = time() - t0
    print("test time:  %0.3fs" % test_time)

    score = metrics.accuracy_score(y_test, pred)
    print("accuracy:   %0.3f" % score)

    if hasattr(clf, "coef_"):
        coef = clf.coef_
        print("dimensionality: %d" % coef.shape[1])
        print("density: %f" % density(coef))

        if feature_names is not None:
            print("top 10 keywords per class:")
            for i, label in enumerate(target_names):
                if coef.shape[0] == 1:
                    # Binary problem: linear classifiers keep a single weight
                    # row for the positive class (label 1); negating it ranks
                    # the features that favour the negative class (label 0).
                    # Indexing coef[i] here would raise IndexError for i == 1.
                    weights = coef[0] if i == 1 else -coef[0]
                else:
                    weights = coef[i]
                # argsort is ascending, so the last 10 indices are the largest weights.
                top10 = np.argsort(weights)[-10:]
                print(trim("%s: %s" % (label, " ".join(feature_names[top10]))))
        print()

    print("classification report:")
    print(metrics.classification_report(y_test, pred, target_names=target_names))

    print("confusion matrix:")
    print(metrics.confusion_matrix(y_test, pred))

    print()
    clf_descr = str(clf).split("(")[0]
    return clf_descr, score, train_time, test_time

We now train and test 15 different classification models on the dataset and collect performance results for each model.

from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import RidgeClassifier
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestCentroid
from sklearn.ensemble import RandomForestClassifier


# Each benchmark() call appends a (name, accuracy, train_time, test_time) tuple.
results = []
# RidgeClassifier uses "sparse_cg": with "sag" on this sparse TF-IDF input,
# scikit-learn warns that fitting the intercept needs many iterations and
# recommends "auto" or "sparse_cg" instead.
for clf, name in (
    (RidgeClassifier(tol=1e-2, solver="sparse_cg"), "Ridge Classifier"),
    (Perceptron(max_iter=50), "Perceptron"),
    (PassiveAggressiveClassifier(max_iter=50), "Passive-Aggressive"),
    (KNeighborsClassifier(n_neighbors=10), "kNN"),
    (RandomForestClassifier(), "Random forest"),
):
    print("=" * 80)
    print(name)
    results.append(benchmark(clf))

for penalty in ["l2", "l1"]:
    print("=" * 80)
    print("%s penalty" % penalty.upper())
    # Train Liblinear model
    results.append(benchmark(LinearSVC(penalty=penalty, dual=False, tol=1e-3)))

    # Train SGD model
    results.append(benchmark(SGDClassifier(alpha=0.0001, max_iter=50, penalty=penalty)))

# Train SGD with Elastic Net penalty
print("=" * 80)
print("Elastic-Net penalty")
results.append(
    benchmark(SGDClassifier(alpha=0.0001, max_iter=50, penalty="elasticnet"))
)

# Train NearestCentroid without threshold
print("=" * 80)
print("NearestCentroid (aka Rocchio classifier)")
results.append(benchmark(NearestCentroid()))

# Train sparse Naive Bayes classifiers
print("=" * 80)
print("Naive Bayes")
results.append(benchmark(MultinomialNB(alpha=0.01)))
results.append(benchmark(BernoulliNB(alpha=0.01)))
results.append(benchmark(ComplementNB(alpha=0.1)))

print("=" * 80)
print("LinearSVC with L1-based feature selection")
# The smaller C, the stronger the regularization.
# The more regularization, the more sparsity.
results.append(
    benchmark(
        Pipeline(
            [
                (
                    "feature_selection",
                    SelectFromModel(LinearSVC(penalty="l1", dual=False, tol=1e-3)),
                ),
                ("classification", LinearSVC(penalty="l2")),
            ]
        )
    )
)

Out:

================================================================================
Ridge Classifier
________________________________________________________________________________
Training:
RidgeClassifier(solver='sag', tol=0.01)
/home/circleci/project/sklearn/linear_model/_ridge.py:830: UserWarning: "sag" solver requires many iterations to fit an intercept with sparse inputs. Either set the solver to "auto" or "sparse_cg", or set a low "tol" and a high "max_iter" (especially if inputs are not standardized).
  warnings.warn(
train time: 0.166s
test time:  0.001s
accuracy:   0.897
dimensionality: 33809
density: 1.000000
top 10 keywords per class:
alt.atheism: atheist osrhe wingate god okcforum caltech islamic atheism keith...
comp.graphics: animation video looking card hi 3d thanks file image graphics
sci.space: dc flight shuttle launch pat moon sci orbit nasa space
talk.religion.misc: jesus mitre hudson morality biblical 2000 beast mr fbi ch...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.87      0.83      0.85       319
     comp.graphics       0.90      0.98      0.94       389
         sci.space       0.96      0.94      0.95       394
talk.religion.misc       0.83      0.78      0.80       251

          accuracy                           0.90      1353
         macro avg       0.89      0.88      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[266   9   7  37]
 [  1 381   4   3]
 [  0  22 372   0]
 [ 40  10   6 195]]

================================================================================
Perceptron
________________________________________________________________________________
Training:
Perceptron(max_iter=50)
train time: 0.015s
test time:  0.001s
accuracy:   0.888
dimensionality: 33809
density: 0.255302
top 10 keywords per class:
alt.atheism: wingate osrhe freedom lippard alt thing cobb atheists atheism keith
comp.graphics: siggraph code fractal comp mpeg library pc animation sphere gr...
sci.space: bruce wpi solar sci funding moon orbit planets dc space
talk.religion.misc: god morality hudson beast sword fbi 2000 order mr christian

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.86      0.80      0.83       319
     comp.graphics       0.90      0.97      0.94       389
         sci.space       0.95      0.93      0.94       394
talk.religion.misc       0.79      0.80      0.79       251

          accuracy                           0.89      1353
         macro avg       0.88      0.88      0.88      1353
      weighted avg       0.89      0.89      0.89      1353

confusion matrix:
[[256   7   8  48]
 [  0 379   4   6]
 [  7  21 366   0]
 [ 33  12   6 200]]

================================================================================
Passive-Aggressive
________________________________________________________________________________
Training:
PassiveAggressiveClassifier(max_iter=50)
train time: 0.028s
test time:  0.001s
accuracy:   0.903
dimensionality: 33809
density: 0.708428
top 10 keywords per class:
alt.atheism: charley cobb osrhe okcforum caltech atheist islamic keith atheis...
comp.graphics: 3d 42 computer code animation windows tiff file image graphics
sci.space: pat alaska shuttle launch sci dc nasa moon orbit space
talk.religion.misc: abortion 666 hudson christians fbi mr morality beast 2000...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.86      0.84      0.85       319
     comp.graphics       0.93      0.98      0.95       389
         sci.space       0.95      0.95      0.95       394
talk.religion.misc       0.83      0.80      0.81       251

          accuracy                           0.90      1353
         macro avg       0.89      0.89      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[269   5   9  36]
 [  0 380   4   5]
 [  3  17 373   1]
 [ 39   7   5 200]]

================================================================================
kNN
________________________________________________________________________________
Training:
KNeighborsClassifier(n_neighbors=10)
train time: 0.001s
test time:  0.147s
accuracy:   0.858
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.78      0.90      0.84       319
     comp.graphics       0.89      0.89      0.89       389
         sci.space       0.90      0.91      0.90       394
talk.religion.misc       0.86      0.67      0.75       251

          accuracy                           0.86      1353
         macro avg       0.86      0.84      0.85      1353
      weighted avg       0.86      0.86      0.86      1353

confusion matrix:
[[287   3  11  18]
 [ 14 348  19   8]
 [  7  26 359   2]
 [ 59  13  12 167]]

================================================================================
Random forest
________________________________________________________________________________
Training:
RandomForestClassifier()
train time: 1.163s
test time:  0.078s
accuracy:   0.838
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.83      0.73      0.78       319
     comp.graphics       0.80      0.97      0.88       389
         sci.space       0.93      0.89      0.91       394
talk.religion.misc       0.78      0.69      0.73       251

          accuracy                           0.84      1353
         macro avg       0.83      0.82      0.82      1353
      weighted avg       0.84      0.84      0.84      1353

confusion matrix:
[[234  29   9  47]
 [  1 377  10   1]
 [  3  41 349   1]
 [ 43  25   9 174]]

================================================================================
L2 penalty
________________________________________________________________________________
Training:
LinearSVC(dual=False, tol=0.001)
train time: 0.074s
test time:  0.001s
accuracy:   0.900
dimensionality: 33809
density: 1.000000
top 10 keywords per class:
alt.atheism: rushdie osrhe atheist wingate okcforum caltech islamic atheism k...
comp.graphics: code 42 video hi animation thanks 3d file image graphics
sci.space: planets dc pat shuttle launch sci moon nasa orbit space
talk.religion.misc: abortion hudson 666 biblical 2000 morality mr beast fbi c...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.87      0.83      0.85       319
     comp.graphics       0.91      0.98      0.95       389
         sci.space       0.96      0.95      0.95       394
talk.religion.misc       0.83      0.79      0.81       251

          accuracy                           0.90      1353
         macro avg       0.89      0.89      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[266   7   8  38]
 [  2 381   3   3]
 [  1  20 373   0]
 [ 38   9   6 198]]

________________________________________________________________________________
Training:
SGDClassifier(max_iter=50)
train time: 0.022s
test time:  0.001s
accuracy:   0.899
dimensionality: 33809
density: 0.570514
top 10 keywords per class:
alt.atheism: charley rushdie cobb caltech okcforum wingate islamic keith athe...
comp.graphics: tiff video hi 42 code 3d animation file image graphics
sci.space: shuttle pat planets launch sci dc moon nasa orbit space
talk.religion.misc: hudson abortion 666 biblical beast mr morality fbi 2000 c...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.87      0.83      0.85       319
     comp.graphics       0.92      0.97      0.94       389
         sci.space       0.95      0.94      0.95       394
talk.religion.misc       0.81      0.81      0.81       251

          accuracy                           0.90      1353
         macro avg       0.89      0.89      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[264   5   9  41]
 [  2 377   5   5]
 [  2  20 371   1]
 [ 34   8   5 204]]

================================================================================
L1 penalty
________________________________________________________________________________
Training:
LinearSVC(dual=False, penalty='l1', tol=0.001)
train time: 0.187s
test time:  0.001s
accuracy:   0.873
dimensionality: 33809
density: 0.005561
top 10 keywords per class:
alt.atheism: benedikt rice rushdie wingate islamic bmd atheism wwc keith athe...
comp.graphics: sphere virtual 42 files windows hi image 3d 3do graphics
sci.space: pat henry sunrise rockets dc launch flight moon orbit space
talk.religion.misc: hudson thyagi biblical 2000 abortion kendig hare mitre ch...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.85      0.75      0.80       319
     comp.graphics       0.89      0.97      0.93       389
         sci.space       0.94      0.94      0.94       394
talk.religion.misc       0.76      0.78      0.77       251

          accuracy                           0.87      1353
         macro avg       0.86      0.86      0.86      1353
      weighted avg       0.87      0.87      0.87      1353

confusion matrix:
[[238  14  11  56]
 [  0 378   7   4]
 [  2  22 369   1]
 [ 39  12   4 196]]

________________________________________________________________________________
Training:
SGDClassifier(max_iter=50, penalty='l1')
train time: 0.084s
test time:  0.001s
accuracy:   0.885
dimensionality: 33809
density: 0.023019
top 10 keywords per class:
alt.atheism: charley atheist rice psilink wingate rushdie keith islamic athei...
comp.graphics: windows 3d video animation files hi pov 3do image graphics
sci.space: sunrise sci nasa flight launch dc pat moon orbit space
talk.religion.misc: 2000 abortion 666 homosexuality biblical hudson mr beast ...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.86      0.79      0.82       319
     comp.graphics       0.93      0.96      0.95       389
         sci.space       0.93      0.96      0.94       394
talk.religion.misc       0.77      0.77      0.77       251

          accuracy                           0.89      1353
         macro avg       0.87      0.87      0.87      1353
      weighted avg       0.88      0.89      0.88      1353

confusion matrix:
[[252   6  12  49]
 [  1 375   8   5]
 [  2  11 378   3]
 [ 39  11   8 193]]

================================================================================
Elastic-Net penalty
________________________________________________________________________________
Training:
SGDClassifier(max_iter=50, penalty='elasticnet')
train time: 0.125s
test time:  0.002s
accuracy:   0.902
dimensionality: 33809
density: 0.187731
top 10 keywords per class:
alt.atheism: okcforum caltech cobb atheist rushdie wingate islamic keith athe...
comp.graphics: video points hi 3do 42 3d file animation image graphics
sci.space: shuttle pat launch planets sci dc nasa moon orbit space
talk.religion.misc: order abortion 666 biblical morality mr 2000 fbi beast ch...

classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.86      0.85      0.85       319
     comp.graphics       0.92      0.98      0.95       389
         sci.space       0.96      0.94      0.95       394
talk.religion.misc       0.83      0.78      0.81       251

          accuracy                           0.90      1353
         macro avg       0.89      0.89      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[270   6   9  34]
 [  0 381   3   5]
 [  1  20 372   1]
 [ 42   7   5 197]]

================================================================================
NearestCentroid (aka Rocchio classifier)
________________________________________________________________________________
Training:
NearestCentroid()
train time: 0.004s
test time:  0.002s
accuracy:   0.855
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.88      0.69      0.77       319
     comp.graphics       0.84      0.97      0.90       389
         sci.space       0.96      0.92      0.94       394
talk.religion.misc       0.72      0.79      0.75       251

          accuracy                           0.86      1353
         macro avg       0.85      0.84      0.84      1353
      weighted avg       0.86      0.86      0.85      1353

confusion matrix:
[[219  25   5  70]
 [  1 379   5   4]
 [  1  30 361   2]
 [ 29  19   5 198]]

================================================================================
Naive Bayes
________________________________________________________________________________
Training:
MultinomialNB(alpha=0.01)
train time: 0.003s
test time:  0.001s
accuracy:   0.899
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.85      0.87      0.86       319
     comp.graphics       0.95      0.95      0.95       389
         sci.space       0.92      0.95      0.94       394
talk.religion.misc       0.86      0.77      0.81       251

          accuracy                           0.90      1353
         macro avg       0.89      0.89      0.89      1353
      weighted avg       0.90      0.90      0.90      1353

confusion matrix:
[[279   2   8  30]
 [  2 369  16   2]
 [  3  15 376   0]
 [ 45   4   9 193]]

________________________________________________________________________________
Training:
BernoulliNB(alpha=0.01)
train time: 0.004s
test time:  0.003s
accuracy:   0.884
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.83      0.88      0.86       319
     comp.graphics       0.88      0.96      0.92       389
         sci.space       0.94      0.91      0.92       394
talk.religion.misc       0.87      0.73      0.79       251

          accuracy                           0.88      1353
         macro avg       0.88      0.87      0.87      1353
      weighted avg       0.88      0.88      0.88      1353

confusion matrix:
[[282   9   3  25]
 [  1 373  13   2]
 [  5  31 358   0]
 [ 50  10   8 183]]

________________________________________________________________________________
Training:
ComplementNB(alpha=0.1)
train time: 0.003s
test time:  0.001s
accuracy:   0.911
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.85      0.89      0.87       319
     comp.graphics       0.95      0.97      0.96       389
         sci.space       0.94      0.97      0.95       394
talk.religion.misc       0.88      0.75      0.81       251

          accuracy                           0.91      1353
         macro avg       0.90      0.90      0.90      1353
      weighted avg       0.91      0.91      0.91      1353

confusion matrix:
[[284   2   8  25]
 [  2 379   7   1]
 [  0  13 381   0]
 [ 49   5   9 188]]

================================================================================
LinearSVC with L1-based feature selection
________________________________________________________________________________
Training:
Pipeline(steps=[('feature_selection',
                 SelectFromModel(estimator=LinearSVC(dual=False, penalty='l1',
                                                     tol=0.001))),
                ('classification', LinearSVC())])
train time: 0.190s
test time:  0.002s
accuracy:   0.880
classification report:
                    precision    recall  f1-score   support

       alt.atheism       0.84      0.80      0.82       319
     comp.graphics       0.91      0.96      0.93       389
         sci.space       0.93      0.95      0.94       394
talk.religion.misc       0.81      0.76      0.78       251

          accuracy                           0.88      1353
         macro avg       0.87      0.87      0.87      1353
      weighted avg       0.88      0.88      0.88      1353

confusion matrix:
[[254  11  13  41]
 [  2 374   9   4]
 [  2  18 373   1]
 [ 44   9   8 190]]

Plot the results

The bar plot shows the accuracy, training time and test time of each classifier; both times are normalized by the largest value across all classifiers.

import matplotlib.pyplot as plt

indices = np.arange(len(results))

# Transpose the per-classifier (name, score, train_time, test_time) tuples into
# four parallel lists, one per quantity.
results = [list(column) for column in zip(*results)]

clf_names, score, training_time, test_time = results
# Scale both timings by their maximum so every bar lies in [0, 1].
training_time = np.array(training_time) / np.max(training_time)
test_time = np.array(test_time) / np.max(test_time)

plt.figure(figsize=(12, 8))
plt.title("Score")
bar_height = 0.2
for offset, values, label, color in (
    (0.0, score, "score", "navy"),
    (0.3, training_time, "training time", "c"),
    (0.6, test_time, "test time", "darkorange"),
):
    plt.barh(indices + offset, values, bar_height, label=label, color=color)
plt.yticks(())
plt.legend(loc="best")
plt.subplots_adjust(left=0.25, top=0.95, bottom=0.05)

# Write each classifier's name to the left of its group of bars.
for i, clf_name in zip(indices, clf_names):
    plt.text(-0.3, i, clf_name)

plt.show()
Score

Total running time of the script: ( 0 minutes 3.679 seconds)

Gallery generated by Sphinx-Gallery