Caching nächster Nachbarn#

Dieses Beispiel zeigt, wie die k nächsten Nachbarn vorab berechnet werden können, bevor sie in KNeighborsClassifier verwendet werden. KNeighborsClassifier kann die nächsten Nachbarn intern berechnen, aber die Vorab-Berechnung kann mehrere Vorteile haben, wie z. B. feinere Parameterkontrolle, Caching für mehrfache Verwendung oder benutzerdefinierte Implementierungen.

Hier nutzen wir die Caching-Eigenschaft von Pipelines, um den Graphen der nächsten Nachbarn zwischen mehreren Fits von KNeighborsClassifier zu cachen. Der erste Aufruf ist langsam, da er den Nachbargraphen berechnet, während nachfolgende Aufrufe schneller sind, da sie den Graphen nicht neu berechnen müssen. Hier sind die Dauern gering, da der Datensatz klein ist, aber der Gewinn kann bei größeren Datensätzen oder bei einer großen Anzahl von zu durchsuchenden Parametern erheblicher sein.

Classification accuracy, Fit time (with caching)
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer
from sklearn.pipeline import Pipeline

X, y = load_digits(return_X_y=True)
n_neighbors_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]

# The transformer computes the nearest neighbors graph using the maximum number
# of neighbors necessary in the grid search. The classifier model filters the
# nearest neighbors graph as required by its own n_neighbors parameter.
graph_model = KNeighborsTransformer(n_neighbors=max(n_neighbors_list), mode="distance")
classifier_model = KNeighborsClassifier(metric="precomputed")

# Note that we give `memory` a directory to cache the graph computation
# that will be used several times when tuning the hyperparameters of the
# classifier.
with TemporaryDirectory(prefix="sklearn_graph_cache_") as tmpdir:
    full_model = Pipeline(
        steps=[("graph", graph_model), ("classifier", classifier_model)], memory=tmpdir
    )

    param_grid = {"classifier__n_neighbors": n_neighbors_list}
    grid_model = GridSearchCV(full_model, param_grid)
    grid_model.fit(X, y)

# Plot the results of the grid search.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_test_score"],
    yerr=grid_model.cv_results_["std_test_score"],
)
axes[0].set(xlabel="n_neighbors", title="Classification accuracy")
axes[1].errorbar(
    x=n_neighbors_list,
    y=grid_model.cv_results_["mean_fit_time"],
    yerr=grid_model.cv_results_["std_fit_time"],
    color="r",
)
axes[1].set(xlabel="n_neighbors", title="Fit time (with caching)")
fig.tight_layout()
plt.show()

Gesamtlaufzeit des Skripts: (0 Minuten 0,671 Sekunden)

Verwandte Beispiele

Vergleich von Nächsten Nachbarn mit und ohne Neighborhood Components Analysis

Vergleich von Nächsten Nachbarn mit und ohne Neighborhood Components Analysis

Nearest Neighbors Klassifikation

Nearest Neighbors Klassifikation

Annähernde nächste Nachbarn in TSNE

Annähernde nächste Nachbarn in TSNE

Release Highlights für scikit-learn 0.22

Release Highlights für scikit-learn 0.22

Galerie generiert von Sphinx-Gallery