Note
Click here to download the full example code or to run this example in your browser via Binder
Approximate nearest neighbors in TSNE¶
This example presents how to chain KNeighborsTransformer and TSNE in a pipeline.
It also shows how to wrap the packages annoy
and nmslib
to replace
KNeighborsTransformer and perform approximate nearest neighbors. These packages
can be installed with pip install annoy nmslib
.
Note: In KNeighborsTransformer we use the definition which includes each
training point as its own neighbor in the count of n_neighbors
, and for
compatibility reasons, one extra neighbor is computed when mode == 'distance'
.
Please note that we do the same in the proposed wrappers.
Sample output:
Benchmarking on MNIST_2000:
---------------------------
AnnoyTransformer: 0.305 sec
NMSlibTransformer: 0.144 sec
KNeighborsTransformer: 0.090 sec
TSNE with AnnoyTransformer: 2.818 sec
TSNE with NMSlibTransformer: 2.592 sec
TSNE with KNeighborsTransformer: 2.338 sec
TSNE with internal NearestNeighbors: 2.364 sec
Benchmarking on MNIST_10000:
----------------------------
AnnoyTransformer: 2.874 sec
NMSlibTransformer: 1.098 sec
KNeighborsTransformer: 1.264 sec
TSNE with AnnoyTransformer: 16.118 sec
TSNE with NMSlibTransformer: 15.281 sec
TSNE with KNeighborsTransformer: 15.400 sec
TSNE with internal NearestNeighbors: 15.573 sec
Note that the prediction speed KNeighborsTransformer was optimized in scikit-learn 1.1 and therefore approximate methods are not necessarily faster because computing the index takes time and can nullify the gains obtained at prediction time.
# Author: Tom Dupre la Tour
#
# License: BSD 3 clause
import time
import sys
try:
import annoy
except ImportError:
print("The package 'annoy' is required to run this example.")
sys.exit()
try:
import nmslib
except ImportError:
print("The package 'nmslib' is required to run this example.")
sys.exit()
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.neighbors import KNeighborsTransformer
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.datasets import fetch_openml
from sklearn.pipeline import make_pipeline
from sklearn.manifold import TSNE
from sklearn.utils import shuffle
class NMSlibTransformer(TransformerMixin, BaseEstimator):
"""Wrapper for using nmslib as sklearn's KNeighborsTransformer"""
def __init__(self, n_neighbors=5, metric="euclidean", method="sw-graph", n_jobs=1):
self.n_neighbors = n_neighbors
self.method = method
self.metric = metric
self.n_jobs = n_jobs
def fit(self, X):
self.n_samples_fit_ = X.shape[0]
# see more metric in the manual
# https://github.com/nmslib/nmslib/tree/master/manual
space = {
"euclidean": "l2",
"cosine": "cosinesimil",
"l1": "l1",
"l2": "l2",
}[self.metric]
self.nmslib_ = nmslib.init(method=self.method, space=space)
self.nmslib_.addDataPointBatch(X)
self.nmslib_.createIndex()
return self
def transform(self, X):
n_samples_transform = X.shape[0]
# For compatibility reasons, as each sample is considered as its own
# neighbor, one extra neighbor will be computed.
n_neighbors = self.n_neighbors + 1
results = self.nmslib_.knnQueryBatch(X, k=n_neighbors, num_threads=self.n_jobs)
indices, distances = zip(*results)
indices, distances = np.vstack(indices), np.vstack(distances)
indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors)
kneighbors_graph = csr_matrix(
(distances.ravel(), indices.ravel(), indptr),
shape=(n_samples_transform, self.n_samples_fit_),
)
return kneighbors_graph
class AnnoyTransformer(TransformerMixin, BaseEstimator):
"""Wrapper for using annoy.AnnoyIndex as sklearn's KNeighborsTransformer"""
def __init__(self, n_neighbors=5, metric="euclidean", n_trees=10, search_k=-1):
self.n_neighbors = n_neighbors
self.n_trees = n_trees
self.search_k = search_k
self.metric = metric
def fit(self, X):
self.n_samples_fit_ = X.shape[0]
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=self.metric)
for i, x in enumerate(X):
self.annoy_.add_item(i, x.tolist())
self.annoy_.build(self.n_trees)
return self
def transform(self, X):
return self._transform(X)
def fit_transform(self, X, y=None):
return self.fit(X)._transform(X=None)
def _transform(self, X):
"""As `transform`, but handles X is None for faster `fit_transform`."""
n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0]
# For compatibility reasons, as each sample is considered as its own
# neighbor, one extra neighbor will be computed.
n_neighbors = self.n_neighbors + 1
indices = np.empty((n_samples_transform, n_neighbors), dtype=int)
distances = np.empty((n_samples_transform, n_neighbors))
if X is None:
for i in range(self.annoy_.get_n_items()):
ind, dist = self.annoy_.get_nns_by_item(
i, n_neighbors, self.search_k, include_distances=True
)
indices[i], distances[i] = ind, dist
else:
for i, x in enumerate(X):
indices[i], distances[i] = self.annoy_.get_nns_by_vector(
x.tolist(), n_neighbors, self.search_k, include_distances=True
)
indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors)
kneighbors_graph = csr_matrix(
(distances.ravel(), indices.ravel(), indptr),
shape=(n_samples_transform, self.n_samples_fit_),
)
return kneighbors_graph
def test_transformers():
"""Test that AnnoyTransformer and KNeighborsTransformer give same results"""
X = np.random.RandomState(42).randn(10, 2)
knn = KNeighborsTransformer()
Xt0 = knn.fit_transform(X)
ann = AnnoyTransformer()
Xt1 = ann.fit_transform(X)
nms = NMSlibTransformer()
Xt2 = nms.fit_transform(X)
assert_array_almost_equal(Xt0.toarray(), Xt1.toarray(), decimal=5)
assert_array_almost_equal(Xt0.toarray(), Xt2.toarray(), decimal=5)
def load_mnist(n_samples):
"""Load MNIST, shuffle the data, and return only n_samples."""
mnist = fetch_openml("mnist_784", as_frame=False)
X, y = shuffle(mnist.data, mnist.target, random_state=2)
return X[:n_samples] / 255, y[:n_samples]
def run_benchmark():
datasets = [
("MNIST_2000", load_mnist(n_samples=2000)),
("MNIST_10000", load_mnist(n_samples=10000)),
]
n_iter = 500
perplexity = 30
metric = "euclidean"
# TSNE requires a certain number of neighbors which depends on the
# perplexity parameter.
# Add one since we include each sample as its own neighbor.
n_neighbors = int(3.0 * perplexity + 1) + 1
tsne_params = dict(
init="random", # pca not supported for sparse matrices
perplexity=perplexity,
method="barnes_hut",
random_state=42,
n_iter=n_iter,
learning_rate="auto",
)
transformers = [
("AnnoyTransformer", AnnoyTransformer(n_neighbors=n_neighbors, metric=metric)),
(
"NMSlibTransformer",
NMSlibTransformer(n_neighbors=n_neighbors, metric=metric),
),
(
"KNeighborsTransformer",
KNeighborsTransformer(
n_neighbors=n_neighbors, mode="distance", metric=metric
),
),
(
"TSNE with AnnoyTransformer",
make_pipeline(
AnnoyTransformer(n_neighbors=n_neighbors, metric=metric),
TSNE(metric="precomputed", **tsne_params),
),
),
(
"TSNE with NMSlibTransformer",
make_pipeline(
NMSlibTransformer(n_neighbors=n_neighbors, metric=metric),
TSNE(metric="precomputed", **tsne_params),
),
),
(
"TSNE with KNeighborsTransformer",
make_pipeline(
KNeighborsTransformer(
n_neighbors=n_neighbors, mode="distance", metric=metric
),
TSNE(metric="precomputed", **tsne_params),
),
),
("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)),
]
# init the plot
nrows = len(datasets)
ncols = np.sum([1 for name, model in transformers if "TSNE" in name])
fig, axes = plt.subplots(
nrows=nrows, ncols=ncols, squeeze=False, figsize=(5 * ncols, 4 * nrows)
)
axes = axes.ravel()
i_ax = 0
for dataset_name, (X, y) in datasets:
msg = "Benchmarking on %s:" % dataset_name
print("\n%s\n%s" % (msg, "-" * len(msg)))
for transformer_name, transformer in transformers:
start = time.time()
Xt = transformer.fit_transform(X)
duration = time.time() - start
# print the duration report
longest = np.max([len(name) for name, model in transformers])
whitespaces = " " * (longest - len(transformer_name))
print("%s: %s%.3f sec" % (transformer_name, whitespaces, duration))
# plot TSNE embedding which should be very similar across methods
if "TSNE" in transformer_name:
axes[i_ax].set_title(transformer_name + "\non " + dataset_name)
axes[i_ax].scatter(
Xt[:, 0],
Xt[:, 1],
c=y.astype(np.int32),
alpha=0.2,
cmap=plt.cm.viridis,
)
axes[i_ax].xaxis.set_major_formatter(NullFormatter())
axes[i_ax].yaxis.set_major_formatter(NullFormatter())
axes[i_ax].axis("tight")
i_ax += 1
fig.tight_layout()
plt.show()
if __name__ == "__main__":
test_transformers()
run_benchmark()
Total running time of the script: ( 0 minutes 0.000 seconds)