TunedThresholdClassifierCV#

class sklearn.model_selection.TunedThresholdClassifierCV(estimator, *, scoring='balanced_accuracy', response_method='auto', thresholds=100, cv=None, refit=True, n_jobs=None, random_state=None, store_cv_results=False)[Quelle]#

Klassifikator, der die Entscheidungsschwelle mittels Kreuzvalidierung nachjustiert.

Dieser Estimator optimiert nachträglich den Entscheidungsschwellenwert (Cut-off-Punkt), der zur Umwandlung von Wahrscheinlichkeitsschätzungen (d.h. der Ausgabe von predict_proba) oder Entscheidungswerten (d.h. der Ausgabe von decision_function) in eine Klassenbezeichnung verwendet wird. Die Optimierung erfolgt durch Maximierung einer binären Metrik, möglicherweise eingeschränkt durch eine andere Metrik.

Lesen Sie mehr im Benutzerhandbuch.

Hinzugefügt in Version 1.5.

Parameter:
estimatorSchätzer-Instanz

Der Klassifikator, trainiert oder nicht, für den wir den Entscheidungsschwellenwert optimieren möchten, der während predict verwendet wird.

scoringstr oder aufrufbar, Standard="balanced_accuracy"

Die zu optimierende Zielmetrik. Kann sein

  • str: Zeichenkette, die einer Bewertungsfunktion für binäre Klassifizierung zugeordnet ist, siehe String-Namen für Bewerter für Optionen.

  • callable: Ein Scorer-Callable-Objekt (z. B. Funktion) mit der Signatur scorer(estimator, X, y). Siehe Callable Scorer für Details.

response_method{“auto”, “decision_function”, “predict_proba”}, Standard="auto"

Methoden des Klassifikators estimator, die sich auf die Entscheidungsfunktion beziehen, für die wir einen Schwellenwert finden möchten. Es kann sein

  • wenn "auto", wird versucht, für jeden Klassifikator "predict_proba" oder "decision_function" in dieser Reihenfolge aufzurufen.

  • andernfalls eine der Optionen "predict_proba" oder "decision_function". Wenn die Methode nicht vom Klassifikator implementiert wird, wird ein Fehler ausgelöst.

thresholdsint oder array-ähnlich, Standard=100

Die Anzahl der Entscheidungsschwellenwerte, die bei der Diskretisierung der Ausgabe der Klassifikator-Methode verwendet werden sollen. Übergeben Sie ein Array-ähnliches Objekt, um die zu verwendenden Schwellenwerte manuell anzugeben.

cvint, float, Kreuzvalidierungsgenerator, iterierbar oder "prefit", Standard=None

Bestimmt die Kreuzvalidierungsstrategie zum Trainieren des Klassifikators. Mögliche Eingaben für cv sind

  • None, um die standardmäßige 5-fache stratifizierte K-Falt-Kreuzvalidierung zu verwenden;

  • Eine ganze Zahl zur Angabe der Anzahl von Faltungen in einer stratifizierten k-Faltung;

  • Eine Gleitkommazahl zur Angabe einer einzelnen Shuffle-Split. Die Gleitkommazahl sollte im Bereich (0, 1) liegen und die Größe des Validierungsdatensatzes darstellen;

  • Ein Objekt, das als Kreuzvalidierungsgenerator verwendet wird;

  • Ein Iterator, der Train/Test-Splits liefert;

  • "prefit", um die Kreuzvalidierung zu umgehen.

Siehe Benutzerhandbuch für die verschiedenen Kreuzvalidierungsstrategien, die hier verwendet werden können.

Warnung

Die Verwendung von cv="prefit" und die Übergabe desselben Datensatzes für das Training von estimator und die Optimierung des Cut-off-Punktes ist anfällig für unerwünschte Überanpassung. Siehe Überlegungen zur Modell-Nachbildung und Kreuzvalidierung für ein Beispiel.

Diese Option sollte nur verwendet werden, wenn der Satz, der zum Trainieren von estimator verwendet wird, sich von demjenigen unterscheidet, der zur Optimierung des Cut-off-Punktes verwendet wird (durch Aufruf von TunedThresholdClassifierCV.fit).

refitbool, Standard=True

Ob der Klassifikator nach dem Finden des Entscheidungsschwellenwerts auf dem gesamten Trainingsdatensatz neu trainiert werden soll oder nicht. Beachten Sie, dass das Erzwingen von refit=False bei einer Kreuzvalidierung mit mehr als einem Split einen Fehler auslöst. Ebenso löst refit=True in Verbindung mit cv="prefit" einen Fehler aus.

n_jobsint, default=None

Die Anzahl der parallel auszuführenden Jobs. Wenn cv eine Kreuzvalidierungsstrategie darstellt, erfolgt das Trainieren und Bewerten auf jeder Datenaufteilung parallel. None bedeutet 1, sofern nicht in einem joblib.parallel_backend-Kontext. -1 bedeutet Verwendung aller Prozessoren. Siehe Glossar für weitere Details.

random_stateint, RandomState-Instanz oder None, default=None

Steuert die Zufälligkeit der Kreuzvalidierung, wenn cv ein Gleitkommazahl ist. Siehe Glossar.

store_cv_resultsbool, Standardwert=False

Ob alle während des Kreuzvalidierungsprozesses berechneten Bewertungen und Schwellenwerte gespeichert werden sollen.

Attribute:
estimator_Instanz des Estimators

Der trainierte Klassifikator, der bei der Vorhersage verwendet wird.

best_threshold_float

Der neue Entscheidungsschwellenwert.

best_score_float oder None

Die optimale Bewertung der Zielmetrik, ausgewertet an best_threshold_.

cv_results_dict oder None

Ein Wörterbuch, das die während des Kreuzvalidierungsprozesses berechneten Bewertungen und Schwellenwerte enthält. Nur vorhanden, wenn store_cv_results=True. Die Schlüssel sind "thresholds" und "scores".

classes_ndarray der Form (n_classes,)

Klassenbezeichnungen.

n_features_in_int

Anzahl der während fit gesehenen Merkmale. Nur definiert, wenn der zugrunde liegende Schätzer ein solches Attribut nach dem Training bereitstellt.

feature_names_in_ndarray mit Form (n_features_in_,)

Namen von Features, die während fit gesehen wurden. Nur definiert, wenn der zugrunde liegende Estimator ein solches Attribut nach dem Anpassen exponiert.

Siehe auch

sklearn.model_selection.FixedThresholdClassifier

Klassifikator, der einen konstanten Schwellenwert verwendet.

sklearn.calibration.CalibratedClassifierCV

Estimator, der Wahrscheinlichkeiten kalibriert.

Beispiele

>>> from sklearn.datasets import make_classification
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.metrics import classification_report
>>> from sklearn.model_selection import TunedThresholdClassifierCV, train_test_split
>>> X, y = make_classification(
...     n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42
... )
>>> X_train, X_test, y_train, y_test = train_test_split(
...     X, y, stratify=y, random_state=42
... )
>>> classifier = RandomForestClassifier(random_state=0).fit(X_train, y_train)
>>> print(classification_report(y_test, classifier.predict(X_test)))
              precision    recall  f1-score   support

           0       0.94      0.99      0.96       224
           1       0.80      0.46      0.59        26

    accuracy                           0.93       250
   macro avg       0.87      0.72      0.77       250
weighted avg       0.93      0.93      0.92       250

>>> classifier_tuned = TunedThresholdClassifierCV(
...     classifier, scoring="balanced_accuracy"
... ).fit(X_train, y_train)
>>> print(
...     f"Cut-off point found at {classifier_tuned.best_threshold_:.3f}"
... )
Cut-off point found at 0.342
>>> print(classification_report(y_test, classifier_tuned.predict(X_test)))
              precision    recall  f1-score   support

           0       0.96      0.95      0.96       224
           1       0.61      0.65      0.63        26

    accuracy                           0.92       250
   macro avg       0.78      0.80      0.79       250
weighted avg       0.92      0.92      0.92       250
decision_function(X)[Quelle]#

Entscheidungsfunktion für Stichproben in X unter Verwendung des trainierten Estimators.

Parameter:
X{array-like, sparse matrix} der Form (n_samples, n_features)

Trainingsvektoren, wobei n_samples die Anzahl der Stichproben und n_features die Anzahl der Merkmale ist.

Gibt zurück:
decisionsndarray der Form (n_samples,)

Die vom trainierten Estimator berechnete Entscheidungsfunktion.

fit(X, y, **params)[Quelle]#

Trainiert den Klassifikator.

Parameter:
X{array-like, sparse matrix} der Form (n_samples, n_features)

Trainingsdaten.

yarray-like von Form (n_samples,)

Zielwerte.

**paramsdict

Parameter, die an die fit-Methode des zugrunde liegenden Klassifikators übergeben werden.

Gibt zurück:
selfobject

Gibt eine Instanz von self zurück.

get_metadata_routing()[Quelle]#

Holt das Metadaten-Routing dieses Objekts.

Bitte prüfen Sie im Benutzerhandbuch, wie der Routing-Mechanismus funktioniert.

Gibt zurück:
routingMetadataRouter

Ein MetadataRouter, der die Routing-Informationen kapselt.

get_params(deep=True)[Quelle]#

Holt Parameter für diesen Schätzer.

Parameter:
deepbool, default=True

Wenn True, werden die Parameter für diesen Schätzer und die enthaltenen Unterobjekte, die Schätzer sind, zurückgegeben.

Gibt zurück:
paramsdict

Parameternamen, zugeordnet ihren Werten.

predict(X)[Quelle]#

Sagen Sie das Ziel neuer Stichproben voraus.

Parameter:
X{array-like, sparse matrix} der Form (n_samples, n_features)

Die Stichproben, wie sie von estimator.predict akzeptiert werden.

Gibt zurück:
class_labelsndarray der Form (n_samples,)

Die vorhergesagte Klasse.

predict_log_proba(X)[Quelle]#

Vorhersagt der logarithmischen Klassenwahrscheinlichkeiten für X unter Verwendung des trainierten Estimators.

Parameter:
X{array-like, sparse matrix} der Form (n_samples, n_features)

Trainingsvektoren, wobei n_samples die Anzahl der Stichproben und n_features die Anzahl der Merkmale ist.

Gibt zurück:
log_probabilitiesndarray der Form (n_samples, n_classes)

Die logarithmischen Klassenwahrscheinlichkeiten der Eingabestichproben.

predict_proba(X)[Quelle]#

Vorhersagt der Klassenwahrscheinlichkeiten für X unter Verwendung des trainierten Estimators.

Parameter:
X{array-like, sparse matrix} der Form (n_samples, n_features)

Trainingsvektoren, wobei n_samples die Anzahl der Stichproben und n_features die Anzahl der Merkmale ist.

Gibt zurück:
probabilitiesndarray der Form (n_samples, n_classes)

Die Klassenwahrscheinlichkeiten der Eingabemuster.

score(X, y, sample_weight=None)[Quelle]#

Gibt die Genauigkeit für die bereitgestellten Daten und Bezeichnungen zurück.

Bei der Multi-Label-Klassifizierung ist dies die Subset-Genauigkeit, eine strenge Metrik, da für jede Stichprobe verlangt wird, dass jede Label-Menge korrekt vorhergesagt wird.

Parameter:
Xarray-like der Form (n_samples, n_features)

Teststichproben.

yarray-like der Form (n_samples,) oder (n_samples, n_outputs)

Wahre Bezeichnungen für X.

sample_weightarray-like der Form (n_samples,), Standardwert=None

Stichprobengewichte.

Gibt zurück:
scorefloat

Mittlere Genauigkeit von self.predict(X) in Bezug auf y.

set_params(**params)[Quelle]#

Setzt die Parameter dieses Schätzers.

Die Methode funktioniert sowohl bei einfachen Schätzern als auch bei verschachtelten Objekten (wie Pipeline). Letztere haben Parameter der Form <component>__<parameter>, so dass es möglich ist, jede Komponente eines verschachtelten Objekts zu aktualisieren.

Parameter:
**paramsdict

Schätzer-Parameter.

Gibt zurück:
selfestimator instance

Schätzer-Instanz.

set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') TunedThresholdClassifierCV[Quelle]#

Konfiguriert, ob Metadaten für die score-Methode angefordert werden sollen.

Beachten Sie, dass diese Methode nur relevant ist, wenn dieser Schätzer als Unter-Schätzer innerhalb eines Meta-Schätzers verwendet wird und Metadaten-Routing mit enable_metadata_routing=True aktiviert ist (siehe sklearn.set_config). Bitte lesen Sie das Benutzerhandbuch, um zu erfahren, wie der Routing-Mechanismus funktioniert.

Die Optionen für jeden Parameter sind

  • True: Metadaten werden angefordert und, falls vorhanden, an score übergeben. Die Anforderung wird ignoriert, wenn keine Metadaten vorhanden sind.

  • False: Metadaten werden nicht angefordert und der Meta-Schätzer übergibt sie nicht an score.

  • None: Metadaten werden nicht angefordert und der Meta-Schätzer löst einen Fehler aus, wenn der Benutzer sie bereitstellt.

  • str: Metadaten sollten mit diesem Alias an den Meta-Schätzer übergeben werden und nicht mit dem ursprünglichen Namen.

Der Standardwert (sklearn.utils.metadata_routing.UNCHANGED) behält die bestehende Anforderung bei. Dies ermöglicht es Ihnen, die Anforderung für einige Parameter zu ändern und für andere nicht.

Hinzugefügt in Version 1.3.

Parameter:
sample_weightstr, True, False, oder None, Standardwert=sklearn.utils.metadata_routing.UNCHANGED

Metadaten-Routing für den Parameter sample_weight in score.

Gibt zurück:
selfobject

Das aktualisierte Objekt.