TunedThresholdClassifierCV#
- class sklearn.model_selection.TunedThresholdClassifierCV(estimator, *, scoring='balanced_accuracy', response_method='auto', thresholds=100, cv=None, refit=True, n_jobs=None, random_state=None, store_cv_results=False)[Quelle]#
Klassifikator, der die Entscheidungsschwelle mittels Kreuzvalidierung nachjustiert.
Dieser Estimator optimiert nachträglich den Entscheidungsschwellenwert (Cut-off-Punkt), der zur Umwandlung von Wahrscheinlichkeitsschätzungen (d.h. der Ausgabe von
predict_proba) oder Entscheidungswerten (d.h. der Ausgabe vondecision_function) in eine Klassenbezeichnung verwendet wird. Die Optimierung erfolgt durch Maximierung einer binären Metrik, möglicherweise eingeschränkt durch eine andere Metrik.Lesen Sie mehr im Benutzerhandbuch.
Hinzugefügt in Version 1.5.
- Parameter:
- estimatorSchätzer-Instanz
Der Klassifikator, trainiert oder nicht, für den wir den Entscheidungsschwellenwert optimieren möchten, der während
predictverwendet wird.- scoringstr oder aufrufbar, Standard="balanced_accuracy"
Die zu optimierende Zielmetrik. Kann sein
str: Zeichenkette, die einer Bewertungsfunktion für binäre Klassifizierung zugeordnet ist, siehe String-Namen für Bewerter für Optionen.
callable: Ein Scorer-Callable-Objekt (z. B. Funktion) mit der Signatur
scorer(estimator, X, y). Siehe Callable Scorer für Details.
- response_method{“auto”, “decision_function”, “predict_proba”}, Standard="auto"
Methoden des Klassifikators
estimator, die sich auf die Entscheidungsfunktion beziehen, für die wir einen Schwellenwert finden möchten. Es kann seinwenn
"auto", wird versucht, für jeden Klassifikator"predict_proba"oder"decision_function"in dieser Reihenfolge aufzurufen.andernfalls eine der Optionen
"predict_proba"oder"decision_function". Wenn die Methode nicht vom Klassifikator implementiert wird, wird ein Fehler ausgelöst.
- thresholdsint oder array-ähnlich, Standard=100
Die Anzahl der Entscheidungsschwellenwerte, die bei der Diskretisierung der Ausgabe der Klassifikator-Methode verwendet werden sollen. Übergeben Sie ein Array-ähnliches Objekt, um die zu verwendenden Schwellenwerte manuell anzugeben.
- cvint, float, Kreuzvalidierungsgenerator, iterierbar oder "prefit", Standard=None
Bestimmt die Kreuzvalidierungsstrategie zum Trainieren des Klassifikators. Mögliche Eingaben für cv sind
None, um die standardmäßige 5-fache stratifizierte K-Falt-Kreuzvalidierung zu verwenden;Eine ganze Zahl zur Angabe der Anzahl von Faltungen in einer stratifizierten k-Faltung;
Eine Gleitkommazahl zur Angabe einer einzelnen Shuffle-Split. Die Gleitkommazahl sollte im Bereich (0, 1) liegen und die Größe des Validierungsdatensatzes darstellen;
Ein Objekt, das als Kreuzvalidierungsgenerator verwendet wird;
Ein Iterator, der Train/Test-Splits liefert;
"prefit", um die Kreuzvalidierung zu umgehen.
Siehe Benutzerhandbuch für die verschiedenen Kreuzvalidierungsstrategien, die hier verwendet werden können.
Warnung
Die Verwendung von
cv="prefit"und die Übergabe desselben Datensatzes für das Training vonestimatorund die Optimierung des Cut-off-Punktes ist anfällig für unerwünschte Überanpassung. Siehe Überlegungen zur Modell-Nachbildung und Kreuzvalidierung für ein Beispiel.Diese Option sollte nur verwendet werden, wenn der Satz, der zum Trainieren von
estimatorverwendet wird, sich von demjenigen unterscheidet, der zur Optimierung des Cut-off-Punktes verwendet wird (durch Aufruf vonTunedThresholdClassifierCV.fit).- refitbool, Standard=True
Ob der Klassifikator nach dem Finden des Entscheidungsschwellenwerts auf dem gesamten Trainingsdatensatz neu trainiert werden soll oder nicht. Beachten Sie, dass das Erzwingen von
refit=Falsebei einer Kreuzvalidierung mit mehr als einem Split einen Fehler auslöst. Ebenso löstrefit=Truein Verbindung mitcv="prefit"einen Fehler aus.- n_jobsint, default=None
Die Anzahl der parallel auszuführenden Jobs. Wenn
cveine Kreuzvalidierungsstrategie darstellt, erfolgt das Trainieren und Bewerten auf jeder Datenaufteilung parallel.Nonebedeutet 1, sofern nicht in einemjoblib.parallel_backend-Kontext.-1bedeutet Verwendung aller Prozessoren. Siehe Glossar für weitere Details.- random_stateint, RandomState-Instanz oder None, default=None
Steuert die Zufälligkeit der Kreuzvalidierung, wenn
cvein Gleitkommazahl ist. Siehe Glossar.- store_cv_resultsbool, Standardwert=False
Ob alle während des Kreuzvalidierungsprozesses berechneten Bewertungen und Schwellenwerte gespeichert werden sollen.
- Attribute:
- estimator_Instanz des Estimators
Der trainierte Klassifikator, der bei der Vorhersage verwendet wird.
- best_threshold_float
Der neue Entscheidungsschwellenwert.
- best_score_float oder None
Die optimale Bewertung der Zielmetrik, ausgewertet an
best_threshold_.- cv_results_dict oder None
Ein Wörterbuch, das die während des Kreuzvalidierungsprozesses berechneten Bewertungen und Schwellenwerte enthält. Nur vorhanden, wenn
store_cv_results=True. Die Schlüssel sind"thresholds"und"scores".classes_ndarray der Form (n_classes,)Klassenbezeichnungen.
- n_features_in_int
Anzahl der während fit gesehenen Merkmale. Nur definiert, wenn der zugrunde liegende Schätzer ein solches Attribut nach dem Training bereitstellt.
- feature_names_in_ndarray mit Form (
n_features_in_,) Namen von Features, die während fit gesehen wurden. Nur definiert, wenn der zugrunde liegende Estimator ein solches Attribut nach dem Anpassen exponiert.
Siehe auch
sklearn.model_selection.FixedThresholdClassifierKlassifikator, der einen konstanten Schwellenwert verwendet.
sklearn.calibration.CalibratedClassifierCVEstimator, der Wahrscheinlichkeiten kalibriert.
Beispiele
>>> from sklearn.datasets import make_classification >>> from sklearn.ensemble import RandomForestClassifier >>> from sklearn.metrics import classification_report >>> from sklearn.model_selection import TunedThresholdClassifierCV, train_test_split >>> X, y = make_classification( ... n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42 ... ) >>> X_train, X_test, y_train, y_test = train_test_split( ... X, y, stratify=y, random_state=42 ... ) >>> classifier = RandomForestClassifier(random_state=0).fit(X_train, y_train) >>> print(classification_report(y_test, classifier.predict(X_test))) precision recall f1-score support 0 0.94 0.99 0.96 224 1 0.80 0.46 0.59 26 accuracy 0.93 250 macro avg 0.87 0.72 0.77 250 weighted avg 0.93 0.93 0.92 250 >>> classifier_tuned = TunedThresholdClassifierCV( ... classifier, scoring="balanced_accuracy" ... ).fit(X_train, y_train) >>> print( ... f"Cut-off point found at {classifier_tuned.best_threshold_:.3f}" ... ) Cut-off point found at 0.342 >>> print(classification_report(y_test, classifier_tuned.predict(X_test))) precision recall f1-score support 0 0.96 0.95 0.96 224 1 0.61 0.65 0.63 26 accuracy 0.92 250 macro avg 0.78 0.80 0.79 250 weighted avg 0.92 0.92 0.92 250
- decision_function(X)[Quelle]#
Entscheidungsfunktion für Stichproben in
Xunter Verwendung des trainierten Estimators.- Parameter:
- X{array-like, sparse matrix} der Form (n_samples, n_features)
Trainingsvektoren, wobei
n_samplesdie Anzahl der Stichproben undn_featuresdie Anzahl der Merkmale ist.
- Gibt zurück:
- decisionsndarray der Form (n_samples,)
Die vom trainierten Estimator berechnete Entscheidungsfunktion.
- fit(X, y, **params)[Quelle]#
Trainiert den Klassifikator.
- Parameter:
- X{array-like, sparse matrix} der Form (n_samples, n_features)
Trainingsdaten.
- yarray-like von Form (n_samples,)
Zielwerte.
- **paramsdict
Parameter, die an die
fit-Methode des zugrunde liegenden Klassifikators übergeben werden.
- Gibt zurück:
- selfobject
Gibt eine Instanz von self zurück.
- get_metadata_routing()[Quelle]#
Holt das Metadaten-Routing dieses Objekts.
Bitte prüfen Sie im Benutzerhandbuch, wie der Routing-Mechanismus funktioniert.
- Gibt zurück:
- routingMetadataRouter
Ein
MetadataRouter, der die Routing-Informationen kapselt.
- get_params(deep=True)[Quelle]#
Holt Parameter für diesen Schätzer.
- Parameter:
- deepbool, default=True
Wenn True, werden die Parameter für diesen Schätzer und die enthaltenen Unterobjekte, die Schätzer sind, zurückgegeben.
- Gibt zurück:
- paramsdict
Parameternamen, zugeordnet ihren Werten.
- predict(X)[Quelle]#
Sagen Sie das Ziel neuer Stichproben voraus.
- Parameter:
- X{array-like, sparse matrix} der Form (n_samples, n_features)
Die Stichproben, wie sie von
estimator.predictakzeptiert werden.
- Gibt zurück:
- class_labelsndarray der Form (n_samples,)
Die vorhergesagte Klasse.
- predict_log_proba(X)[Quelle]#
Vorhersagt der logarithmischen Klassenwahrscheinlichkeiten für
Xunter Verwendung des trainierten Estimators.- Parameter:
- X{array-like, sparse matrix} der Form (n_samples, n_features)
Trainingsvektoren, wobei
n_samplesdie Anzahl der Stichproben undn_featuresdie Anzahl der Merkmale ist.
- Gibt zurück:
- log_probabilitiesndarray der Form (n_samples, n_classes)
Die logarithmischen Klassenwahrscheinlichkeiten der Eingabestichproben.
- predict_proba(X)[Quelle]#
Vorhersagt der Klassenwahrscheinlichkeiten für
Xunter Verwendung des trainierten Estimators.- Parameter:
- X{array-like, sparse matrix} der Form (n_samples, n_features)
Trainingsvektoren, wobei
n_samplesdie Anzahl der Stichproben undn_featuresdie Anzahl der Merkmale ist.
- Gibt zurück:
- probabilitiesndarray der Form (n_samples, n_classes)
Die Klassenwahrscheinlichkeiten der Eingabemuster.
- score(X, y, sample_weight=None)[Quelle]#
Gibt die Genauigkeit für die bereitgestellten Daten und Bezeichnungen zurück.
Bei der Multi-Label-Klassifizierung ist dies die Subset-Genauigkeit, eine strenge Metrik, da für jede Stichprobe verlangt wird, dass jede Label-Menge korrekt vorhergesagt wird.
- Parameter:
- Xarray-like der Form (n_samples, n_features)
Teststichproben.
- yarray-like der Form (n_samples,) oder (n_samples, n_outputs)
Wahre Bezeichnungen für
X.- sample_weightarray-like der Form (n_samples,), Standardwert=None
Stichprobengewichte.
- Gibt zurück:
- scorefloat
Mittlere Genauigkeit von
self.predict(X)in Bezug aufy.
- set_params(**params)[Quelle]#
Setzt die Parameter dieses Schätzers.
Die Methode funktioniert sowohl bei einfachen Schätzern als auch bei verschachtelten Objekten (wie
Pipeline). Letztere haben Parameter der Form<component>__<parameter>, so dass es möglich ist, jede Komponente eines verschachtelten Objekts zu aktualisieren.- Parameter:
- **paramsdict
Schätzer-Parameter.
- Gibt zurück:
- selfestimator instance
Schätzer-Instanz.
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') TunedThresholdClassifierCV[Quelle]#
Konfiguriert, ob Metadaten für die
score-Methode angefordert werden sollen.Beachten Sie, dass diese Methode nur relevant ist, wenn dieser Schätzer als Unter-Schätzer innerhalb eines Meta-Schätzers verwendet wird und Metadaten-Routing mit
enable_metadata_routing=Trueaktiviert ist (siehesklearn.set_config). Bitte lesen Sie das Benutzerhandbuch, um zu erfahren, wie der Routing-Mechanismus funktioniert.Die Optionen für jeden Parameter sind
True: Metadaten werden angefordert und, falls vorhanden, anscoreübergeben. Die Anforderung wird ignoriert, wenn keine Metadaten vorhanden sind.False: Metadaten werden nicht angefordert und der Meta-Schätzer übergibt sie nicht anscore.None: Metadaten werden nicht angefordert und der Meta-Schätzer löst einen Fehler aus, wenn der Benutzer sie bereitstellt.str: Metadaten sollten mit diesem Alias an den Meta-Schätzer übergeben werden und nicht mit dem ursprünglichen Namen.
Der Standardwert (
sklearn.utils.metadata_routing.UNCHANGED) behält die bestehende Anforderung bei. Dies ermöglicht es Ihnen, die Anforderung für einige Parameter zu ändern und für andere nicht.Hinzugefügt in Version 1.3.
- Parameter:
- sample_weightstr, True, False, oder None, Standardwert=sklearn.utils.metadata_routing.UNCHANGED
Metadaten-Routing für den Parameter
sample_weightinscore.
- Gibt zurück:
- selfobject
Das aktualisierte Objekt.
Galeriebeispiele#
Post-Hoc-Anpassung des Entscheidungsschwellenwerts für kostenempfindliches Lernen
Post-hoc-Anpassung des Cut-off-Punkts der Entscheidungskfunktion