Multiklassen-Receiver Operating Characteristic (ROC)#

Dieses Beispiel beschreibt die Verwendung der Metrik Receiver Operating Characteristic (ROC) zur Bewertung der Qualität von Multiklassen-Klassifikatoren.

ROC-Kurven zeigen typischerweise die True Positive Rate (TPR) auf der Y-Achse und die False Positive Rate (FPR) auf der X-Achse. Das bedeutet, dass die obere linke Ecke des Diagramms der "ideale" Punkt ist – eine FPR von Null und eine TPR von Eins. Dies ist nicht sehr realistisch, bedeutet aber, dass eine größere Fläche unter der Kurve (AUC) normalerweise besser ist. Die "Steilheit" von ROC-Kurven ist ebenfalls wichtig, da es ideal ist, die TPR zu maximieren und gleichzeitig die FPR zu minimieren.

ROC-Kurven werden typischerweise in der binären Klassifizierung verwendet, wo TPR und FPR eindeutig definiert werden können. Im Fall der Multiklassen-Klassifizierung wird ein Begriff von TPR oder FPR erst nach Binarisierung der Ausgabe erhalten. Dies kann auf 2 verschiedene Arten erfolgen:

  • das One-vs-Rest-Schema vergleicht jede Klasse mit allen anderen (angenommen als eine);

  • das One-vs-One-Schema vergleicht jede eindeutige paarweise Kombination von Klassen.

In diesem Beispiel untersuchen wir beide Schemata und demonstrieren die Konzepte von Mikro- und Makromittelung als verschiedene Wege, die Informationen der Multiklassen-ROC-Kurven zusammenzufassen.

Hinweis

Siehe Receiver Operating Characteristic (ROC) mit Kreuzvalidierung für eine Erweiterung dieses Beispiels, bei der die Varianz der ROC-Kurven und ihre jeweiligen AUC geschätzt werden.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

Daten laden und vorbereiten#

Wir importieren das Iris-Pflanzendatenset, das 3 Klassen enthält, die jeweils einem Typ von Irispflanze entsprechen. Eine Klasse ist linear von den anderen beiden trennbar; letztere sind **nicht** linear voneinander trennbar.

Hier binarisieren wir die Ausgabe und fügen verrauschte Merkmale hinzu, um das Problem zu erschweren.

import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
y = iris.target_names[y]

random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
n_classes = len(np.unique(y))
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
(
    X_train,
    X_test,
    y_train,
    y_test,
) = train_test_split(X, y, test_size=0.5, stratify=y, random_state=0)

Wir trainieren ein LogisticRegression-Modell, das dank der Verwendung der multinominalen Formulierung natürliche Unterstützung für Multiklassenprobleme bietet.

from sklearn.linear_model import LogisticRegression

classifier = LogisticRegression()
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

One-vs-Rest Multiklassen-ROC#

Die One-vs-the-Rest (OvR) Multiklassenstrategie, auch bekannt als one-vs-all, besteht darin, eine ROC-Kurve für jede der n_classes zu berechnen. In jedem Schritt wird eine gegebene Klasse als positive Klasse betrachtet und die restlichen Klassen werden als negative Klasse in einem Block behandelt.

Hinweis

Man sollte die OvR-Strategie, die zur **Bewertung** von Multiklassen-Klassifikatoren verwendet wird, nicht mit der OvR-Strategie verwechseln, die zum **Trainieren** eines Multiklassen-Klassifikators durch Anpassen einer Reihe von binären Klassifikatoren verwendet wird (zum Beispiel über den Meta-Schätzer OneVsRestClassifier). Die OvR-ROC-Bewertung kann verwendet werden, um beliebige Klassifikationsmodelle zu untersuchen, unabhängig davon, wie sie trainiert wurden (siehe Multiklassen- und Multi-Output-Algorithmen).

In diesem Abschnitt verwenden wir einen LabelBinarizer, um das Ziel durch One-Hot-Encoding im OvR-Verfahren zu binarisieren. Das bedeutet, dass das Ziel der Form (n_samples,) auf ein Ziel der Form (n_samples, n_classes) abgebildet wird.

from sklearn.preprocessing import LabelBinarizer

label_binarizer = LabelBinarizer().fit(y_train)
y_onehot_test = label_binarizer.transform(y_test)
y_onehot_test.shape  # (n_samples, n_classes)
(75, 3)

Wir können auch leicht die Kodierung einer bestimmten Klasse überprüfen

label_binarizer.transform(["virginica"])
array([[0, 0, 1]])

ROC-Kurve, die eine bestimmte Klasse zeigt#

Im folgenden Diagramm zeigen wir die resultierende ROC-Kurve, wenn die Iris-Blumen entweder als "virginica" (class_id=2) oder "nicht-virginica" (die restlichen) betrachtet werden.

class_of_interest = "virginica"
class_id = np.flatnonzero(label_binarizer.classes_ == class_of_interest)[0]
class_id
np.int64(2)
import matplotlib.pyplot as plt

from sklearn.metrics import RocCurveDisplay

display = RocCurveDisplay.from_predictions(
    y_onehot_test[:, class_id],
    y_score[:, class_id],
    name=f"{class_of_interest} vs the rest",
    curve_kwargs=dict(color="darkorange"),
    plot_chance_level=True,
    despine=True,
)
_ = display.ax_.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)",
)
One-vs-Rest ROC curves: Virginica vs (Setosa & Versicolor)

ROC-Kurve unter Verwendung der Mikro-gemittelten OvR#

Mikromittelung aggregiert die Beiträge aller Klassen (unter Verwendung von numpy.ravel), um die durchschnittlichen Metriken wie folgt zu berechnen:

\(TPR=\frac{\sum_{c}TP_c}{\sum_{c}(TP_c + FN_c)}\) ;

\(FPR=\frac{\sum_{c}FP_c}{\sum_{c}(FP_c + TN_c)}\) .

Wir können kurz die Auswirkung von numpy.ravel demonstrieren

print(f"y_score:\n{y_score[0:2, :]}")
print()
print(f"y_score.ravel():\n{y_score[0:2, :].ravel()}")
y_score:
[[0.38 0.05 0.57]
 [0.07 0.28 0.65]]

y_score.ravel():
[0.38 0.05 0.57 0.07 0.28 0.65]

In einem Multiklassen-Klassifizierungsaufbau mit stark unausgeglichenen Klassen ist die Mikromittelung der Makromittelung vorzuziehen. In solchen Fällen kann alternativ eine gewichtete Makromittelung verwendet werden, die hier nicht gezeigt wird.

display = RocCurveDisplay.from_predictions(
    y_onehot_test.ravel(),
    y_score.ravel(),
    name="micro-average OvR",
    curve_kwargs=dict(color="darkorange"),
    plot_chance_level=True,
    despine=True,
)
_ = display.ax_.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="Micro-averaged One-vs-Rest\nReceiver Operating Characteristic",
)
Micro-averaged One-vs-Rest Receiver Operating Characteristic

In dem Fall, in dem das Hauptinteresse nicht das Diagramm, sondern der ROC-AUC-Wert selbst ist, können wir den im Diagramm angezeigten Wert mit roc_auc_score reproduzieren.

from sklearn.metrics import roc_auc_score

micro_roc_auc_ovr = roc_auc_score(
    y_test,
    y_score,
    multi_class="ovr",
    average="micro",
)

print(f"Micro-averaged One-vs-Rest ROC AUC score:\n{micro_roc_auc_ovr:.2f}")
Micro-averaged One-vs-Rest ROC AUC score:
0.77

Dies ist äquivalent zur Berechnung der ROC-Kurve mit roc_curve und dann der Fläche unter der Kurve mit auc für die verflachten wahren und vorhergesagten Klassen.

from sklearn.metrics import auc, roc_curve

# store the fpr, tpr, and roc_auc for all averaging strategies
fpr, tpr, roc_auc = dict(), dict(), dict()
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_onehot_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

print(f"Micro-averaged One-vs-Rest ROC AUC score:\n{roc_auc['micro']:.2f}")
Micro-averaged One-vs-Rest ROC AUC score:
0.77

Hinweis

Standardmäßig fügt die Berechnung der ROC-Kurve einen einzelnen Punkt beim maximalen Falsch-Positiv-Rate hinzu, indem sie lineare Interpolation und die McClish-Korrektur verwendet [Analyzing a portion of the ROC curve Med Decis Making. 1989 Jul-Sep; 9(3):190-5.].

ROC-Kurve unter Verwendung der OvR-Makromittelung#

Das Erhalten der Makromittelung erfordert die unabhängige Berechnung der Metrik für jede Klasse und anschließende Mittelung über diese, wodurch allen Klassen a priori die gleiche Bedeutung beigemessen wird. Wir aggregieren zunächst die wahren/falschen Positivraten pro Klasse

\(TPR=\frac{1}{C}\sum_{c}\frac{TP_c}{TP_c + FN_c}\) ;

\(FPR=\frac{1}{C}\sum_{c}\frac{FP_c}{FP_c + TN_c}\) .

wobei C die Gesamtzahl der Klassen ist.

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_onehot_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

fpr_grid = np.linspace(0.0, 1.0, 1000)

# Interpolate all ROC curves at these points
mean_tpr = np.zeros_like(fpr_grid)

for i in range(n_classes):
    mean_tpr += np.interp(fpr_grid, fpr[i], tpr[i])  # linear interpolation

# Average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = fpr_grid
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

print(f"Macro-averaged One-vs-Rest ROC AUC score:\n{roc_auc['macro']:.2f}")
Macro-averaged One-vs-Rest ROC AUC score:
0.78

Diese Berechnung ist äquivalent zum einfachen Aufruf von

macro_roc_auc_ovr = roc_auc_score(
    y_test,
    y_score,
    multi_class="ovr",
    average="macro",
)

print(f"Macro-averaged One-vs-Rest ROC AUC score:\n{macro_roc_auc_ovr:.2f}")
Macro-averaged One-vs-Rest ROC AUC score:
0.78

Alle OvR-ROC-Kurven zusammen plotten#

from itertools import cycle

fig, ax = plt.subplots(figsize=(6, 6))

plt.plot(
    fpr["micro"],
    tpr["micro"],
    label=f"micro-average ROC curve (AUC = {roc_auc['micro']:.2f})",
    color="deeppink",
    linestyle=":",
    linewidth=4,
)

plt.plot(
    fpr["macro"],
    tpr["macro"],
    label=f"macro-average ROC curve (AUC = {roc_auc['macro']:.2f})",
    color="navy",
    linestyle=":",
    linewidth=4,
)

colors = cycle(["aqua", "darkorange", "cornflowerblue"])
for class_id, color in zip(range(n_classes), colors):
    RocCurveDisplay.from_predictions(
        y_onehot_test[:, class_id],
        y_score[:, class_id],
        name=f"ROC curve for {target_names[class_id]}",
        curve_kwargs=dict(color=color),
        ax=ax,
        plot_chance_level=(class_id == 2),
        despine=True,
    )

_ = ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="Extension of Receiver Operating Characteristic\nto One-vs-Rest multiclass",
)
Extension of Receiver Operating Characteristic to One-vs-Rest multiclass

One-vs-One Multiklassen-ROC#

Die One-vs-One (OvO) Multiklassenstrategie besteht darin, einen Klassifikator pro Klassenpaar anzupassen. Da dies das Trainieren von n_classes * (n_classes - 1) / 2 Klassifikatoren erfordert, ist diese Methode aufgrund ihrer O(n_classes ^2) Komplexität normalerweise langsamer als One-vs-Rest.

In diesem Abschnitt demonstrieren wir die makrogemittelte AUC mit dem OvO-Schema für die 3 möglichen Kombinationen im Iris-Pflanzendatenset: "setosa" vs "versicolor", "versicolor" vs "virginica" und "virginica" vs "setosa". Beachten Sie, dass die Mikromittelung für das OvO-Schema nicht definiert ist.

ROC-Kurve unter Verwendung der OvO-Makromittelung#

Im OvO-Schema besteht der erste Schritt darin, alle möglichen eindeutigen Paar-Kombinationen zu identifizieren. Die Berechnung der Scores erfolgt, indem eines der Elemente in einem gegebenen Paar als positive Klasse und das andere Element als negative Klasse behandelt wird, dann der Score durch Invertieren der Rollen und Mittelung beider Scores neu berechnet wird.

from itertools import combinations

pair_list = list(combinations(np.unique(y), 2))
print(pair_list)
[(np.str_('setosa'), np.str_('versicolor')), (np.str_('setosa'), np.str_('virginica')), (np.str_('versicolor'), np.str_('virginica'))]
pair_scores = []
mean_tpr = dict()

for ix, (label_a, label_b) in enumerate(pair_list):
    a_mask = y_test == label_a
    b_mask = y_test == label_b
    ab_mask = np.logical_or(a_mask, b_mask)

    a_true = a_mask[ab_mask]
    b_true = b_mask[ab_mask]

    idx_a = np.flatnonzero(label_binarizer.classes_ == label_a)[0]
    idx_b = np.flatnonzero(label_binarizer.classes_ == label_b)[0]

    fpr_a, tpr_a, _ = roc_curve(a_true, y_score[ab_mask, idx_a])
    fpr_b, tpr_b, _ = roc_curve(b_true, y_score[ab_mask, idx_b])

    mean_tpr[ix] = np.zeros_like(fpr_grid)
    mean_tpr[ix] += np.interp(fpr_grid, fpr_a, tpr_a)
    mean_tpr[ix] += np.interp(fpr_grid, fpr_b, tpr_b)
    mean_tpr[ix] /= 2
    mean_score = auc(fpr_grid, mean_tpr[ix])
    pair_scores.append(mean_score)

    fig, ax = plt.subplots(figsize=(6, 6))
    plt.plot(
        fpr_grid,
        mean_tpr[ix],
        label=f"Mean {label_a} vs {label_b} (AUC = {mean_score:.2f})",
        linestyle=":",
        linewidth=4,
    )
    RocCurveDisplay.from_predictions(
        a_true,
        y_score[ab_mask, idx_a],
        ax=ax,
        name=f"{label_a} as positive class",
    )
    RocCurveDisplay.from_predictions(
        b_true,
        y_score[ab_mask, idx_b],
        ax=ax,
        name=f"{label_b} as positive class",
        plot_chance_level=True,
        despine=True,
    )
    ax.set(
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title=f"{target_names[idx_a]} vs {label_b} ROC curves",
    )

print(f"Macro-averaged One-vs-One ROC AUC score:\n{np.average(pair_scores):.2f}")
  • setosa vs versicolor ROC curves
  • setosa vs virginica ROC curves
  • versicolor vs virginica ROC curves
Macro-averaged One-vs-One ROC AUC score:
0.78

Man kann auch bestätigen, dass die von uns "von Hand" berechnete Makromittelung mit der implementierten Option average="macro" der Funktion roc_auc_score äquivalent ist.

macro_roc_auc_ovo = roc_auc_score(
    y_test,
    y_score,
    multi_class="ovo",
    average="macro",
)

print(f"Macro-averaged One-vs-One ROC AUC score:\n{macro_roc_auc_ovo:.2f}")
Macro-averaged One-vs-One ROC AUC score:
0.78

Alle OvO-ROC-Kurven zusammen plotten#

ovo_tpr = np.zeros_like(fpr_grid)

fig, ax = plt.subplots(figsize=(6, 6))
for ix, (label_a, label_b) in enumerate(pair_list):
    ovo_tpr += mean_tpr[ix]
    ax.plot(
        fpr_grid,
        mean_tpr[ix],
        label=f"Mean {label_a} vs {label_b} (AUC = {pair_scores[ix]:.2f})",
    )

ovo_tpr /= sum(1 for pair in enumerate(pair_list))

ax.plot(
    fpr_grid,
    ovo_tpr,
    label=f"One-vs-One macro-average (AUC = {macro_roc_auc_ovo:.2f})",
    linestyle=":",
    linewidth=4,
)
ax.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)")
_ = ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="Extension of Receiver Operating Characteristic\nto One-vs-One multiclass",
    aspect="equal",
    xlim=(-0.01, 1.01),
    ylim=(-0.01, 1.01),
)
Extension of Receiver Operating Characteristic to One-vs-One multiclass

Wir bestätigen, dass die Klassen "versicolor" und "virginica" von einem linearen Klassifikator nicht gut unterschieden werden. Beachten Sie, dass die ROC-AUC-Punktzahl von "virginica" gegen den Rest (0,77) zwischen den OvO-ROC-AUC-Punktzahlen für "versicolor" gegen "virginica" (0,64) und "setosa" gegen "virginica" (0,90) liegt. Tatsächlich liefert die OvO-Strategie zusätzliche Informationen über die Verwechslung zwischen einem Klassenpaar, auf Kosten des Rechenaufwands bei einer großen Anzahl von Klassen.

Die OvO-Strategie wird empfohlen, wenn der Benutzer hauptsächlich daran interessiert ist, eine bestimmte Klasse oder eine Teilmenge von Klassen korrekt zu identifizieren, während die globale Leistung eines Klassifikators immer noch durch eine gegebene Durchschnittsstrategie zusammengefasst werden kann.

Bei der Arbeit mit unausgeglichenen Datensätzen ist die Wahl der geeigneten Metrik, die auf dem Geschäftskontext oder dem behandelten Problem basiert, entscheidend. Es ist auch wichtig, eine geeignete Mittelungsmethode (Mikro vs. Makro) entsprechend dem gewünschten Ergebnis auszuwählen

  • Die Mikromittelung aggregiert Metriken über alle Instanzen und behandelt jede einzelne Instanz gleich, unabhängig von ihrer Klasse. Dieser Ansatz ist nützlich bei der Bewertung der Gesamtleistung, aber beachten Sie, dass er bei unausgeglichenen Datensätzen von der Mehrheitsklasse dominiert werden kann.

  • Die Makromittelung berechnet Metriken für jede Klasse unabhängig und bildet dann den Durchschnitt, wodurch jeder Klasse die gleiche Gewichtung gegeben wird. Dies ist besonders nützlich, wenn Sie möchten, dass unterrepräsentierte Klassen genauso wichtig sind wie stark bevölkerte Klassen.

Gesamtlaufzeit des Skripts: (0 Minuten 0,577 Sekunden)

Verwandte Beispiele

Receiver Operating Characteristic (ROC) mit Kreuzvalidierung

Receiver Operating Characteristic (ROC) mit Kreuzvalidierung

Detection Error Tradeoff (DET) Kurve

Detection Error Tradeoff (DET) Kurve

ROC-Kurve mit Visualisierungs-API

ROC-Kurve mit Visualisierungs-API

Visualisierungen mit Display-Objekten

Visualisierungen mit Display-Objekten

Galerie generiert von Sphinx-Gallery