Plot multinomial and One-vs-Rest Logistic Regression

Plot the decision surfaces of multinomial and One-vs-Rest (OvR) logistic regression. The dashed lines represent the hyperplanes of the three OvR classifiers. In the multinomial plot, they are drawn the same way from each class's row of `coef_` and `intercept_`.
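Both models predict through one linear score per class, and each dashed line is where one of those scores is zero. Here is a sketch of the geometry behind the plotting code. The notation is ours, not the library's: $w_{c,0}, w_{c,1}$ stand for row $c$ of `coef_` and $b_c$ for `intercept_[c]`.

$$
f_c(\mathbf{x}) = w_{c,0}\,x_0 + w_{c,1}\,x_1 + b_c,
\qquad
\hat{y}(\mathbf{x}) = \arg\max_{c} f_c(\mathbf{x})
$$

$$
f_c(\mathbf{x}) = 0
\iff
x_1 = \frac{-w_{c,0}\,x_0 - b_c}{w_{c,1}}
\qquad (w_{c,1} \neq 0)
$$

The second relation is what the `line(x0)` helper inside `plot_hyperplane` evaluates at both ends of the x-axis. The two models share the argmax prediction rule. They differ in how the coefficients are fitted. The multinomial model fits them jointly under a softmax (cross-entropy) loss. OvR fits three independent binary problems. The colored regions are bounded where two class scores tie, so a dashed line need not coincide with a region boundary.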

```
training score : 0.995 (multinomial)
training score : 0.976 (ovr)
```
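`clf.score(X, y)` on a classifier returns the mean accuracy, which is the fraction of samples whose predicted label equals the true one. Below is a minimal sketch of that computation with NumPy. It uses small hypothetical label arrays rather than the example's data.

```python
import numpy as np


def mean_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label matches the true label."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


# Hypothetical labels: 9 of the 10 predictions are correct
y_true = [0, 0, 1, 1, 2, 2, 2, 1, 0, 2]
y_pred = [0, 0, 1, 1, 2, 2, 1, 1, 0, 2]
print("training score : %.3f" % mean_accuracy(y_true, y_pred))  # prints: training score : 0.900
```

The example has 1000 training samples. The scores above therefore correspond to 995 (multinomial) and 976 (OvR) correctly classified training points.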

```python
# Authors: Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# make 3-class dataset for classification
centers = [[-5, 0], [0, 1.5], [5, -1]]
X, y = make_blobs(n_samples=1000, centers=centers, random_state=40)
# Apply a linear transformation so the clusters become stretched and sheared
transformation = [[0.4, 0.2], [-0.4, 1.2]]
X = np.dot(X, transformation)

for multi_class in ("multinomial", "ovr"):
    # With solver="sag" and more than two classes, LogisticRegression fits a
    # multinomial (softmax) model; wrapping it in OneVsRestClassifier instead
    # fits one binary logistic regression per class.
    clf = LogisticRegression(solver="sag", max_iter=100, random_state=42)
    if multi_class == "ovr":
        clf = OneVsRestClassifier(clf)
    clf.fit(X, y)

    # print the training scores
    print("training score : %.3f (%s)" % (clf.score(X, y), multi_class))

    _, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf, X, response_method="predict", cmap=plt.cm.Paired, ax=ax
    )
    plt.title("Decision surface of LogisticRegression (%s)" % multi_class)
    plt.axis("tight")

    # Plot also the training points, one color per class
    colors = "bry"
    for i, color in zip(clf.classes_, colors):
        mask = y == i
        plt.scatter(X[mask, 0], X[mask, 1], c=color, edgecolor="black", s=20)

    # Plot one dashed hyperplane per class: for OvR these are the three
    # one-against-all classifiers, for multinomial the rows of coef_/intercept_
    xmin, xmax = plt.xlim()
    if multi_class == "ovr":
        coef = np.concatenate([est.coef_ for est in clf.estimators_])
        intercept = np.concatenate([est.intercept_ for est in clf.estimators_])
    else:
        coef = clf.coef_
        intercept = clf.intercept_

    def plot_hyperplane(c, color):
        # x1 coordinate where coef[c] . x + intercept[c] == 0
        def line(x0):
            return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]

        plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color)

    for i, color in zip(clf.classes_, colors):
        plot_hyperplane(i, color)

plt.show()
```
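Once the coefficients are fitted, both models score a point the same way. The difference between the two plots therefore comes from training, not from the prediction rule. The sketch below illustrates this with made-up coefficients (hypothetical values, not those fitted above). It compares two ways of turning class scores into probabilities. The first is a softmax over the scores, as in the multinomial model. The second applies a sigmoid per class and renormalizes to sum to one, which is our understanding of how `OneVsRestClassifier.predict_proba` combines binary logistic models for a multiclass target. Both pick the class with the largest score. The sketch also checks that the `line` formula from `plot_hyperplane` returns points whose score for that class is zero.

```python
import numpy as np

# Hypothetical coefficients for 3 classes and 2 features (made up, not fitted)
coef = np.array([[-2.0, 0.5], [0.3, 1.5], [1.8, -1.0]])
intercept = np.array([0.2, 0.0, 0.1])

# A few arbitrary 2-D query points
points = np.array([[-2.0, -1.0], [-0.6, 1.8], [2.4, -0.2], [0.0, 0.0]])

# One linear score per class: f_c(x) = coef[c] . x + intercept[c]
scores = points @ coef.T + intercept


def softmax(z):
    # Multinomial-style probabilities; subtract the row max for stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def ovr_proba(z):
    # OvR-style probabilities: one sigmoid per class, renormalized per row
    p = 1.0 / (1.0 + np.exp(-z))
    return p / p.sum(axis=1, keepdims=True)


p_multinomial = softmax(scores)
p_ovr = ovr_proba(scores)

# Both transforms preserve the ordering of the scores, so all three agree
pred = scores.argmax(axis=1)
print(pred, p_multinomial.argmax(axis=1), p_ovr.argmax(axis=1))


def line(x0, c):
    # x1 on the dashed line f_c(x) = 0 (same formula as plot_hyperplane)
    return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]


# Every point produced by line() should have a zero score for its class
residuals = np.array(
    [
        [np.array([x0, line(x0, c)]) @ coef[c] + intercept[c] for x0 in (-3.0, 3.0)]
        for c in range(len(coef))
    ]
)
print(np.abs(residuals).max())
```

Since softmax and the renormalized sigmoids are both monotone in each class score, swapping one for the other would not change any predicted label for fixed coefficients. The visible differences between the multinomial and OvR surfaces come from the coefficients each training procedure produces.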

Related examples

Plot multi-class SGD on the iris dataset

Multiclass sparse logistic regression on 20newgroups

SVM Tie Breaking Example

Plot classification probability
