Compare BIRCH and MiniBatchKMeans

This example compares the timing of BIRCH (with and without the global clustering step) and MiniBatchKMeans on a synthetic dataset having 25,000 samples and 2 features generated using make_blobs.

Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on hundreds of thousands or even millions of datapoints. We chose to limit the dataset size of this example to keep our Continuous Integration resource usage reasonable, but the interested reader might enjoy editing this script to rerun it with a larger value for n_samples.

If n_clusters is set to None, BIRCH reduces the data from 25,000 samples to a set of 158 subclusters. This can be viewed as a preprocessing step before the final (global) clustering step, which further reduces these 158 subclusters to 100 clusters.

BIRCH without global clustering, BIRCH with global clustering, MiniBatchKMeans
/home/runner/work/scikit-learn/scikit-learn/sklearn/cluster/_birch.py:677: UserWarning:

Some metric_kwargs have been passed ({'Y_norm_squared': array([143.50454616, 440.28785945, 531.67097269, 647.35149727,
       968.33594414, 515.20083068, 770.95496401,  68.6666749 ,
       346.97724523, 302.99906981,  57.42460786, 591.09408586,
       165.45688258, 298.73641304, 431.77978904, 150.67636202,
       776.99102443, 344.14701893, 635.16564778, 298.75076245,
       201.60356162,  11.74547698, 529.47972857, 106.99412216,
       185.73155616, 200.8668397 , 492.50690482,  39.04899534,
        11.10443208, 489.79942035, 537.74789609, 456.23192575,
       528.53923395, 340.93019585, 489.79923515, 639.36987808,
       537.09472195, 301.87332381,  58.22884994, 492.7319812 ,
       583.88355467, 285.19998592, 108.89612064, 157.38162385,
       200.41776247, 761.07185361, 999.55716426, 322.80600392,
       781.523434  , 448.86981855, 441.58274076, 297.49217172,
       613.76930293, 157.59700141, 879.94256221, 636.39290568,
        13.03826574, 197.50130281,  61.43734061,   5.83103503,
        22.51215929, 878.72068449, 296.63233986, 536.4261413 ,
       490.63055753, 334.03596486, 156.10106179, 587.86995096,
       315.63809387, 105.08988492, 450.722817  , 974.98834856,
       358.47558868, 762.51429959, 785.35442707, 633.15310054,
       445.83845118, 188.46824669, 293.883547  , 199.55700317,
       278.4276957 , 655.87570702,  63.43294819, 505.52222923,
        58.38454793, 774.22624472, 164.48332596, 583.46117289,
       109.97680564,  13.39847235,  23.4317738 , 635.2136768 ,
       305.91572402, 298.1590872 , 451.45584429, 527.24914771,
       442.03226164,   7.52031107, 781.96123567, 974.72961985,
       631.81200861, 203.71231759,  60.1959924 , 155.82739196,
       339.60605579, 553.49430467, 438.52050773, 207.29748009,
       213.28620389, 468.86221306,  58.52169268, 106.77029661,
       351.1630945 , 531.02555852, 296.733965  , 154.20663434,
       489.29496317, 296.80643931, 756.85395733, 442.89064367,
       444.05913519, 225.53874781, 671.42246884, 492.64758942,
       541.69572073,  11.08201729, 104.36140765, 462.49997432,
       155.51581727, 421.87117195, 814.81066407, 368.58659446,
       220.56152005, 548.93306499, 641.688439  , 586.47667732,
        58.02516019, 298.4997049 , 323.1330857 , 197.36617526,
       972.45548202, 348.21080096,  62.46501099, 490.00254817,
       783.95781475, 351.09137395, 304.35278   , 156.14431967,
       485.9247294 ,  58.14137327, 158.06635335, 629.55977946,
       301.69806493, 612.80974738, 538.57255339, 175.4864581 ,
       791.29761213, 185.46960351])}) but aren't usable for this case (EuclideanArgKmin64) and will be ignored.

BIRCH without global clustering as the final step took 0.65 seconds
n_clusters : 158
/home/runner/work/scikit-learn/scikit-learn/sklearn/cluster/_birch.py:677: UserWarning:

Some metric_kwargs have been passed ({'Y_norm_squared': array([143.50454616, 440.28785945, 531.67097269, 647.35149727,
       968.33594414, 515.20083068, 770.95496401,  68.6666749 ,
       346.97724523, 302.99906981,  57.42460786, 591.09408586,
       165.45688258, 298.73641304, 431.77978904, 150.67636202,
       776.99102443, 344.14701893, 635.16564778, 298.75076245,
       201.60356162,  11.74547698, 529.47972857, 106.99412216,
       185.73155616, 200.8668397 , 492.50690482,  39.04899534,
        11.10443208, 489.79942035, 537.74789609, 456.23192575,
       528.53923395, 340.93019585, 489.79923515, 639.36987808,
       537.09472195, 301.87332381,  58.22884994, 492.7319812 ,
       583.88355467, 285.19998592, 108.89612064, 157.38162385,
       200.41776247, 761.07185361, 999.55716426, 322.80600392,
       781.523434  , 448.86981855, 441.58274076, 297.49217172,
       613.76930293, 157.59700141, 879.94256221, 636.39290568,
        13.03826574, 197.50130281,  61.43734061,   5.83103503,
        22.51215929, 878.72068449, 296.63233986, 536.4261413 ,
       490.63055753, 334.03596486, 156.10106179, 587.86995096,
       315.63809387, 105.08988492, 450.722817  , 974.98834856,
       358.47558868, 762.51429959, 785.35442707, 633.15310054,
       445.83845118, 188.46824669, 293.883547  , 199.55700317,
       278.4276957 , 655.87570702,  63.43294819, 505.52222923,
        58.38454793, 774.22624472, 164.48332596, 583.46117289,
       109.97680564,  13.39847235,  23.4317738 , 635.2136768 ,
       305.91572402, 298.1590872 , 451.45584429, 527.24914771,
       442.03226164,   7.52031107, 781.96123567, 974.72961985,
       631.81200861, 203.71231759,  60.1959924 , 155.82739196,
       339.60605579, 553.49430467, 438.52050773, 207.29748009,
       213.28620389, 468.86221306,  58.52169268, 106.77029661,
       351.1630945 , 531.02555852, 296.733965  , 154.20663434,
       489.29496317, 296.80643931, 756.85395733, 442.89064367,
       444.05913519, 225.53874781, 671.42246884, 492.64758942,
       541.69572073,  11.08201729, 104.36140765, 462.49997432,
       155.51581727, 421.87117195, 814.81066407, 368.58659446,
       220.56152005, 548.93306499, 641.688439  , 586.47667732,
        58.02516019, 298.4997049 , 323.1330857 , 197.36617526,
       972.45548202, 348.21080096,  62.46501099, 490.00254817,
       783.95781475, 351.09137395, 304.35278   , 156.14431967,
       485.9247294 ,  58.14137327, 158.06635335, 629.55977946,
       301.69806493, 612.80974738, 538.57255339, 175.4864581 ,
       791.29761213, 185.46960351])}) but aren't usable for this case (EuclideanArgKmin64) and will be ignored.

BIRCH with global clustering as the final step took 0.65 seconds
n_clusters : 100
Time taken to run MiniBatchKMeans 0.21 seconds

# Authors: Manoj Kumar <manojkumarsivaraj334@gmail.com>
#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# License: BSD 3 clause

from joblib import cpu_count
from itertools import cycle
from time import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets import make_blobs


# Lay the blob centers out on a regular 10 x 10 grid; both axes share the
# same ten evenly spaced coordinates between -22 and 22.
grid_ticks = np.linspace(-22, 22, 10)
xx, yy = np.meshgrid(grid_ticks, grid_ticks)
# Pair every x coordinate with its y coordinate: one row per center.
n_centers = np.column_stack((xx.ravel(), yy.ravel()))

# Draw the synthetic dataset on which MiniBatchKMeans and BIRCH are compared.
X, y = make_blobs(centers=n_centers, n_samples=25000, random_state=0)

# Endlessly cycle through every named color matplotlib knows about
# (iterating the ``cnames`` mapping yields its keys, i.e. the color names).
colors_ = cycle(colors.cnames)

# A single wide figure; trim the margins so the subplots fill the canvas.
fig = plt.figure(figsize=(12, 4))
fig.subplots_adjust(bottom=0.1, top=0.9, left=0.04, right=0.98)

# Fit BIRCH twice -- once with ``n_clusters=None`` and once with
# ``n_clusters=100`` -- and draw each result in its own panel of the figure.
birch_models = [
    Birch(n_clusters=None, threshold=1.7),
    Birch(n_clusters=100, threshold=1.7),
]
final_step = ["without global clustering", "with global clustering"]

for panel, (birch_model, info) in enumerate(zip(birch_models, final_step), start=1):
    # Time only the call to fit().
    tic = time()
    birch_model.fit(X)
    elapsed = time() - tic
    print("BIRCH %s as the final step took %0.2f seconds" % (info, elapsed))

    # Report how many distinct labels the fitted model assigned.
    labels = birch_model.labels_
    n_clusters = len(np.unique(labels))
    print("n_clusters : %d" % n_clusters)

    # Plot result: one scatter call per label, each with its own edge color.
    ax = fig.add_subplot(1, 3, panel)
    centroids = birch_model.subcluster_centers_
    # Centroid markers are only drawn for the model built with n_clusters=None.
    show_centroids = birch_model.n_clusters is None
    for centroid, k, col in zip(centroids, range(n_clusters), colors_):
        members = X[labels == k]
        ax.scatter(
            members[:, 0],
            members[:, 1],
            c="w",
            edgecolor=col,
            marker=".",
            alpha=0.5,
        )
        if show_centroids:
            ax.scatter(centroid[0], centroid[1], marker="+", c="k", s=25)

    # Fix the visible window of the panel.
    ax.set_ylim([-25, 25])
    ax.set_xlim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title("BIRCH %s" % info)

# Compute clustering with MiniBatchKMeans.
mbk = MiniBatchKMeans(
    init="k-means++",
    n_clusters=100,
    # Scale the mini-batch size with the number of available cores.
    batch_size=256 * cpu_count(),
    n_init=10,
    max_no_improvement=10,
    verbose=0,
    random_state=0,
)
# Time only the call to fit().
t0 = time()
mbk.fit(X)
t_mini_batch = time() - t0
print("Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch)

# Plot result in the third panel of the figure.
ax = fig.add_subplot(1, 3, 3)
# Walk over the fitted centers themselves: label ``k`` in ``mbk.labels_`` is
# matched with row ``k`` of ``mbk.cluster_centers_``. This keeps the plot
# correct whatever cluster count the model was built with, instead of
# depending on a cluster count computed for a different model.
for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):
    mask = mbk.labels_ == k
    ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
    ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
ax.set_xlim([-25, 25])
ax.set_ylim([-25, 25])
ax.set_title("MiniBatchKMeans")
ax.set_autoscaley_on(False)
plt.show()

Total running time of the script: ( 0 minutes 3.708 seconds)

Gallery generated by Sphinx-Gallery