Hinweis
Gehen Sie zum Ende, um den vollständigen Beispielcode herunterzuladen oder dieses Beispiel über JupyterLite oder Binder in Ihrem Browser auszuführen.
ROC-Kurve mit Visualisierungs-API#
Scikit-learn definiert eine einfache API zum Erstellen von Visualisierungen für maschinelles Lernen. Die Hauptmerkmale dieser API sind die schnelle Anzeige und visuelle Anpassungen ohne Neuberechnung. In diesem Beispiel zeigen wir, wie die Visualisierungs-API durch den Vergleich von ROC-Kurven verwendet wird.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
Daten laden und einen SVC trainieren#
Zuerst laden wir den Wein-Datensatz und konvertieren ihn in ein binäres Klassifizierungsproblem. Dann trainieren wir einen Support Vector Classifier auf einem Trainingsdatensatz.
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = load_wine(return_X_y=True)
y = y == 2
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
ROC-Kurve plotten#
Als nächstes plotten wir die ROC-Kurve mit einem einzigen Aufruf von sklearn.metrics.RocCurveDisplay.from_estimator. Das zurückgegebene svc_disp Objekt ermöglicht uns die weitere Verwendung der bereits berechneten ROC-Kurve für den SVC in zukünftigen Plots.
svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
plt.show()

Random Forest trainieren und ROC-Kurve plotten#
Wir trainieren einen Random Forest Classifier und erstellen einen Plot, der ihn mit der SVC-ROC-Kurve vergleicht. Beachten Sie, wie svc_disp plot verwendet, um die SVC-ROC-Kurve zu plotten, ohne die Werte der ROC-Kurve selbst neu zu berechnen. Darüber hinaus übergeben wir alpha=0.8 an die Plot-Funktionen, um die Alpha-Werte der Kurven anzupassen.
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(
rfc, X_test, y_test, ax=ax, curve_kwargs=dict(alpha=0.8)
)
svc_disp.plot(ax=ax, curve_kwargs=dict(alpha=0.8))
plt.show()

Gesamtlaufzeit des Skripts: (0 Minuten 0,146 Sekunden)
Verwandte Beispiele
Receiver Operating Characteristic (ROC) mit Kreuzvalidierung
Multiklassen-Receiver Operating Characteristic (ROC)