Hinweis
Gehe zum Ende, um den vollständigen Beispielcode herunterzuladen oder dieses Beispiel über JupyterLite oder Binder in deinem Browser auszuführen.
Benutzerdefinierte Refit-Strategie für Grid Search mit Kreuzvalidierung#
Dieses Beispiel zeigt, wie ein Klassifikator mittels Kreuzvalidierung optimiert wird, was mit dem Objekt GridSearchCV auf einem Entwicklungssatz durchgeführt wird, der nur die Hälfte der verfügbaren gelabelten Daten umfasst.
Die Leistung der ausgewählten Hyperparameter und des trainierten Modells wird anschließend auf einem dedizierten Evaluationssatz gemessen, der während des Modellauswahlschritts nicht verwendet wurde.
Weitere Details zu Werkzeugen für die Modellauswahl finden Sie in den Abschnitten zur Kreuzvalidierung: Bewertung der Schätzleistung und zur Optimierung der Hyperparameter eines Schätzers.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
Der Datensatz#
Wir arbeiten mit dem digits Datensatz. Das Ziel ist die Klassifizierung von Bildern handschriftlicher Ziffern. Wir transformieren das Problem in eine binäre Klassifikation, um es einfacher verständlich zu machen: Das Ziel ist zu identifizieren, ob eine Ziffer eine 8 ist oder nicht.
from sklearn import datasets
digits = datasets.load_digits()
Um einen Klassifikator auf Bildern zu trainieren, müssen wir sie zu Vektoren abflachen. Jedes Bild mit 8 x 8 Pixeln muss in einen Vektor mit 64 Pixeln umgewandelt werden. Somit erhalten wir ein finales Datenarray der Form (n_images, n_pixels).
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target == 8
print(
f"The number of images is {X.shape[0]} and each image contains {X.shape[1]} pixels"
)
The number of images is 1797 and each image contains 64 pixels
Wie in der Einleitung dargestellt, werden die Daten in einen Trainings- und einen Testsatz gleicher Größe aufgeteilt.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
Definieren unserer Grid-Search-Strategie#
Wir wählen einen Klassifikator aus, indem wir die besten Hyperparameter auf den Folds des Trainingssatzes suchen. Dazu müssen wir die Scores definieren, um den besten Kandidaten auszuwählen.
scores = ["precision", "recall"]
Wir können auch eine Funktion definieren, die dem Parameter refit der GridSearchCV-Instanz übergeben wird. Diese implementiert die benutzerdefinierte Strategie zur Auswahl des besten Kandidaten aus dem Attribut cv_results_ des GridSearchCV. Sobald der Kandidat ausgewählt ist, wird er automatisch vom GridSearchCV-Objekt neu angepasst.
Hierbei ist die Strategie, die Modelle, die in Bezug auf Präzision und Recall am besten sind, zu einer engeren Auswahl zusammenzufassen. Aus den ausgewählten Modellen wählen wir schließlich das schnellste Modell bei der Vorhersage aus. Beachten Sie, dass diese benutzerdefinierten Auswahlmöglichkeiten rein willkürlich sind.
import pandas as pd
def print_dataframe(filtered_cv_results):
"""Pretty print for filtered dataframe"""
for mean_precision, std_precision, mean_recall, std_recall, params in zip(
filtered_cv_results["mean_test_precision"],
filtered_cv_results["std_test_precision"],
filtered_cv_results["mean_test_recall"],
filtered_cv_results["std_test_recall"],
filtered_cv_results["params"],
):
print(
f"precision: {mean_precision:0.3f} (±{std_precision:0.03f}),"
f" recall: {mean_recall:0.3f} (±{std_recall:0.03f}),"
f" for {params}"
)
print()
def refit_strategy(cv_results):
"""Define the strategy to select the best estimator.
The strategy defined here is to filter-out all results below a precision threshold
of 0.98, rank the remaining by recall and keep all models with one standard
deviation of the best by recall. Once these models are selected, we can select the
fastest model to predict.
Parameters
----------
cv_results : dict of numpy (masked) ndarrays
CV results as returned by the `GridSearchCV`.
Returns
-------
best_index : int
The index of the best estimator as it appears in `cv_results`.
"""
# print the info about the grid-search for the different scores
precision_threshold = 0.98
cv_results_ = pd.DataFrame(cv_results)
print("All grid-search results:")
print_dataframe(cv_results_)
# Filter-out all results below the threshold
high_precision_cv_results = cv_results_[
cv_results_["mean_test_precision"] > precision_threshold
]
print(f"Models with a precision higher than {precision_threshold}:")
print_dataframe(high_precision_cv_results)
high_precision_cv_results = high_precision_cv_results[
[
"mean_score_time",
"mean_test_recall",
"std_test_recall",
"mean_test_precision",
"std_test_precision",
"rank_test_recall",
"rank_test_precision",
"params",
]
]
# Select the most performant models in terms of recall
# (within 1 sigma from the best)
best_recall_std = high_precision_cv_results["mean_test_recall"].std()
best_recall = high_precision_cv_results["mean_test_recall"].max()
best_recall_threshold = best_recall - best_recall_std
high_recall_cv_results = high_precision_cv_results[
high_precision_cv_results["mean_test_recall"] > best_recall_threshold
]
print(
"Out of the previously selected high precision models, we keep all the\n"
"the models within one standard deviation of the highest recall model:"
)
print_dataframe(high_recall_cv_results)
# From the best candidates, select the fastest model to predict
fastest_top_recall_high_precision_index = high_recall_cv_results[
"mean_score_time"
].idxmin()
print(
"\nThe selected final model is the fastest to predict out of the previously\n"
"selected subset of best models based on precision and recall.\n"
"Its scoring time is:\n\n"
f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
)
return fastest_top_recall_high_precision_index
Optimieren von Hyperparametern#
Nachdem wir unsere Strategie zur Auswahl des besten Modells definiert haben, definieren wir die Werte der Hyperparameter und erstellen die Grid-Search-Instanz
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
tuned_parameters = [
{"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
{"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]
grid_search = GridSearchCV(
SVC(), tuned_parameters, scoring=scores, refit=refit_strategy
)
grid_search.fit(X_train, y_train)
All grid-search results:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.968 (±0.039), recall: 0.780 (±0.083), for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.905 (±0.058), recall: 0.889 (±0.074), for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.904 (±0.058), recall: 0.890 (±0.073), for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 0.695 (±0.073), recall: 0.743 (±0.065), for {'C': 1, 'kernel': 'linear'}
precision: 0.643 (±0.066), recall: 0.757 (±0.066), for {'C': 10, 'kernel': 'linear'}
precision: 0.611 (±0.028), recall: 0.744 (±0.044), for {'C': 100, 'kernel': 'linear'}
precision: 0.618 (±0.039), recall: 0.744 (±0.044), for {'C': 1000, 'kernel': 'linear'}
Models with a precision higher than 0.98:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
Out of the previously selected high precision models, we keep all the
the models within one standard deviation of the highest recall model:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
The selected final model is the fastest to predict out of the previously
selected subset of best models based on precision and recall.
Its scoring time is:
mean_score_time 0.005081
mean_test_recall 0.877206
std_test_recall 0.069196
mean_test_precision 1.0
std_test_precision 0.0
rank_test_recall 3
rank_test_precision 1
params {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
Name: 4, dtype: object
Die von der Grid-Search mit unserer benutzerdefinierten Strategie ausgewählten Parameter sind
grid_search.best_params_
{'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
Schließlich bewerten wir das feinabgestimmte Modell auf dem zurückbehaltenen Evaluationsdatensatz: Das Objekt grid_search **wurde automatisch** mit den Parametern, die von unserer benutzerdefinierten Refit-Strategie ausgewählt wurden, auf dem gesamten Trainingsdatensatz **neu angepasst**.
Wir können den Klassifikationsbericht verwenden, um Standard-Klassifikationsmetriken auf dem zurückbehaltenen Satz zu berechnen
from sklearn.metrics import classification_report
y_pred = grid_search.predict(X_test)
print(classification_report(y_test, y_pred))
precision recall f1-score support
False 0.99 1.00 0.99 807
True 1.00 0.87 0.93 92
accuracy 0.99 899
macro avg 0.99 0.93 0.96 899
weighted avg 0.99 0.99 0.99 899
Hinweis
Das Problem ist zu einfach: Das Plateau der Hyperparameter ist zu flach und das Ausgabemodell ist dasselbe für Präzision und Recall bei Gleichstand in der Qualität.
Gesamtlaufzeit des Skripts: (0 Minuten 10,200 Sekunden)
Verwandte Beispiele
Modellkomplexität und kreuzvalidierter Score ausbalancieren
Vergleich von zufälliger Suche und Gitter-Suche zur Hyperparameter-Schätzung
Rekursive Merkmalseliminierung mit Kreuzvalidierung