Caching nearest neighbors

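This example demonstrates how to precompute the k nearest neighbors before using them in KNeighborsClassifier. KNeighborsClassifier can compute the nearest neighbors internally, but precomputing them has several benefits: finer control over the neighbor search parameters, caching the result for repeated use, or plugging in a custom implementation.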
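As a minimal sketch of the precomputation idea outside a pipeline, the snippet below builds the neighbors graph with KNeighborsTransformer and passes it to a KNeighborsClassifier created with metric="precomputed". The train/test split, the choice of 9 graph neighbors and 5 classifier neighbors, and the variable names are assumptions made for illustration only.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer

X, y = load_digits(return_X_y=True)
# Illustrative split; the split and random_state are assumptions.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Precompute a sparse distance graph holding more neighbors than the
# classifier will need (9 here, chosen for illustration).
transformer = KNeighborsTransformer(n_neighbors=9, mode="distance")
graph_train = transformer.fit_transform(X_train)
graph_test = transformer.transform(X_test)

# The classifier consumes the precomputed graph instead of raw features
# and keeps only the n_neighbors closest entries of each row.
clf_precomputed = KNeighborsClassifier(n_neighbors=5, metric="precomputed")
clf_precomputed.fit(graph_train, y_train)
score_precomputed = clf_precomputed.score(graph_test, y_test)

# Reference: let the classifier compute the neighbors internally.
clf_direct = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)
score_direct = clf_direct.score(X_test, y_test)

print(f"precomputed graph: {score_precomputed:.3f}")
print(f"internal search:   {score_direct:.3f}")
```

The two scores should be very close. Small differences are possible when several training points are at exactly the same distance from a query point, because ties may be broken differently.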

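Here we use the caching property of pipelines to cache the nearest neighbors graph between multiple fits of KNeighborsClassifier. The first call is slow because it computes the neighbors graph, while subsequent calls are faster because they do not need to recompute it. The durations here are small because the dataset is small, but the gain can be more substantial when the dataset grows larger or when the parameter grid to search is large.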
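The caching effect can be observed directly, without a grid search, by refitting the same pipeline with different values of n_neighbors. The sketch below is an illustration under assumptions: the three values of k, the timing with time.perf_counter, and the file count used to confirm that something was written to the cache directory are not part of the original example.

```python
import os
import time
from tempfile import TemporaryDirectory

from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)

with TemporaryDirectory(prefix="sklearn_graph_cache_") as tmpdir:
    pipe = Pipeline(
        steps=[
            ("graph", KNeighborsTransformer(n_neighbors=9, mode="distance")),
            ("classifier", KNeighborsClassifier(metric="precomputed")),
        ],
        memory=tmpdir,  # fitted transformers are cached in this directory
    )

    fit_times = {}
    for k in (1, 5, 9):  # illustrative values, all <= 9 graph neighbors
        pipe.set_params(classifier__n_neighbors=k)
        start = time.perf_counter()
        # Only the first fit computes the graph; later fits can reuse the
        # cached transformer output because the graph step and the data
        # are unchanged.
        pipe.fit(X, y)
        fit_times[k] = time.perf_counter() - start

    # Count the files written under the cache directory.
    n_cached_files = sum(len(files) for _, _, files in os.walk(tmpdir))

for k, seconds in fit_times.items():
    print(f"n_neighbors={k}: fit took {seconds:.4f} s")
print(f"files in cache directory: {n_cached_files}")
```

Exact timings depend on the machine, so no specific numbers are claimed here. The first fit is expected to be the slowest of the three.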

[Figure: two panels, "Classification accuracy" and "Fit time (with caching)", each plotted against n_neighbors]
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)
n_neighbors_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]

# The transformer computes the nearest neighbors graph using the maximum number
# of neighbors necessary in the grid search. The classifier model filters the
# nearest neighbors graph as required by its own n_neighbors parameter.
graph_model = KNeighborsTransformer(n_neighbors=max(n_neighbors_list), mode="distance")
classifier_model = KNeighborsClassifier(metric="precomputed")

# Note that we give `memory` a directory to cache the graph computation
# that will be used several times when tuning the hyperparameters of the
# classifier.
with TemporaryDirectory(prefix="sklearn_graph_cache_") as tmpdir:
    full_model = Pipeline(
        steps=[("graph", graph_model), ("classifier", classifier_model)], memory=tmpdir
    )

    param_grid = {"classifier__n_neighbors": n_neighbors_list}
    grid_model = GridSearchCV(full_model, param_grid)
    grid_model.fit(X, y)

# Plot the results of the grid search.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_test_score"],
    yerr=grid_model.cv_results_["std_test_score"],
)
axes[0].set(xlabel="n_neighbors", title="Classification accuracy")
axes[1].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_fit_time"],
    yerr=grid_model.cv_results_["std_fit_time"],
    color="r",
)
axes[1].set(xlabel="n_neighbors", title="Fit time (with caching)")
fig.tight_layout()
plt.show()

Total running time of the script: (0 minutes 1.137 seconds)

Related examples

Comparing Nearest Neighbors with and without Neighborhood Components Analysis

Nearest Neighbors Classification

Approximate nearest neighbors in TSNE

Agglomerative clustering with and without structure
