ROC-Kurve mit Visualisierungs-API#

Scikit-learn definiert eine einfache API zum Erstellen von Visualisierungen für maschinelles Lernen. Die Hauptmerkmale dieser API sind die schnelle Anzeige und visuelle Anpassungen ohne Neuberechnung. In diesem Beispiel zeigen wir, wie die Visualisierungs-API durch den Vergleich von ROC-Kurven verwendet wird.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

Daten laden und einen SVC trainieren#

Zuerst laden wir den Wein-Datensatz und konvertieren ihn in ein binäres Klassifizierungsproblem. Dann trainieren wir einen Support Vector Classifier auf einem Trainingsdatensatz.

import matplotlib.pyplot as plt

from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
SVC(random_state=42)
In einer Jupyter-Umgebung führen Sie diese Zelle bitte erneut aus, um die HTML-Darstellung anzuzeigen, oder vertrauen Sie dem Notebook.
Auf GitHub kann die HTML-Darstellung nicht gerendert werden. Versuchen Sie bitte, diese Seite mit nbviewer.org zu laden.


ROC-Kurve plotten#

Als nächstes plotten wir die ROC-Kurve mit einem einzigen Aufruf von sklearn.metrics.RocCurveDisplay.from_estimator. Das zurückgegebene svc_disp Objekt ermöglicht uns die weitere Verwendung der bereits berechneten ROC-Kurve für den SVC in zukünftigen Plots.

svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
plt.show()
plot roc curve visualization api

Random Forest trainieren und ROC-Kurve plotten#

Wir trainieren einen Random Forest Classifier und erstellen einen Plot, der ihn mit der SVC-ROC-Kurve vergleicht. Beachten Sie, wie svc_disp plot verwendet, um die SVC-ROC-Kurve zu plotten, ohne die Werte der ROC-Kurve selbst neu zu berechnen. Darüber hinaus übergeben wir alpha=0.8 an die Plot-Funktionen, um die Alpha-Werte der Kurven anzupassen.

rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(
    rfc, X_test, y_test, ax=ax, curve_kwargs=dict(alpha=0.8)
)
svc_disp.plot(ax=ax, curve_kwargs=dict(alpha=0.8))
plt.show()
plot roc curve visualization api

Gesamtlaufzeit des Skripts: (0 Minuten 0,146 Sekunden)

Verwandte Beispiele

Receiver Operating Characteristic (ROC) mit Kreuzvalidierung

Receiver Operating Characteristic (ROC) mit Kreuzvalidierung

Release Highlights für scikit-learn 0.22

Release Highlights für scikit-learn 0.22

Multiklassen-Receiver Operating Characteristic (ROC)

Multiklassen-Receiver Operating Characteristic (ROC)

Detection Error Tradeoff (DET) Kurve

Detection Error Tradeoff (DET) Kurve

Galerie generiert von Sphinx-Gallery