validation_curve#
- sklearn.model_selection.validation_curve(estimator, X, y, *, param_name, param_range, groups=None, cv=None, scoring=None, n_jobs=None, pre_dispatch='all', verbose=0, error_score=nan, params=None)[source]#
Validierungskurve.
Trainings- und Test-Scores für variierende Parameterwerte bestimmen.
Berechnet Scores für einen Estimator mit verschiedenen Werten eines spezifizierten Parameters. Dies ist ähnlich wie eine Grid-Suche mit einem Parameter. Es werden jedoch auch Trainings-Scores berechnet und es handelt sich lediglich um ein Hilfsmittel zur grafischen Darstellung der Ergebnisse.
Lesen Sie mehr im Benutzerhandbuch.
- Parameter:
- estimatorObjekttyp, der die „fit“-Methode implementiert
Ein Objekt dieses Typs, das für jede Validierung geklont wird. Es muss auch „predict“ implementieren, es sei denn,
scoringist ein Callable, das nicht auf „predict“ angewiesen ist, um einen Score zu berechnen.- X{array-like, sparse matrix} der Form (n_samples, n_features)
Trainingsvektor, wobei
n_samplesdie Anzahl der Stichproben undn_featuresdie Anzahl der Merkmale ist.- yarray-ähnlich der Form (n_samples,) oder (n_samples, n_outputs) oder None
Ziel relativ zu X für Klassifikation oder Regression; None für unüberwachtes Lernen.
- param_namestr
Name des Parameters, der variiert werden soll.
- param_rangearray-ähnlich von Form (n_values,)
Die zu bewertenden Werte des Parameters.
- groupsarray-like of shape (n_samples,), default=None
Gruppenbezeichnungen für die Stichproben, die bei der Aufteilung des Datensatzes in Trainings-/Testdatensätze verwendet werden. Nur in Verbindung mit einer `Group`-Instanz von cv (z. B.
GroupKFold) verwendet.Geändert in Version 1.6:
groupskann nur übergeben werden, wenn Metadata Routing nicht übersklearn.set_config(enable_metadata_routing=True)aktiviert ist. Wenn Routing aktiviert ist, übergeben Siegroupszusammen mit anderen Metadaten über das Argumentparams. Z. B.:validation_curve(..., params={'groups': groups}).- cvint, Kreuzvalidierungsgenerator oder iterierbar, Standardwert=None
Bestimmt die Strategie der Kreuzvalidierungsaufteilung. Mögliche Eingaben für cv sind
None, um die Standard-5-fache Kreuzvalidierung zu verwenden,
int, um die Anzahl der Folds in einem
(Stratified)KFoldanzugeben,Eine iterierbare Liste, die (Trainings-, Test-) Splits als Indizes-Arrays liefert.
Für ganzzahlige/None-Eingaben wird, wenn der Schätzer ein Klassifikator ist und `y` entweder binär oder mehrklassig ist,
StratifiedKFoldverwendet. In allen anderen Fällen wirdKFoldverwendet. Diese Splitter werden mit `shuffle=False` instanziiert, sodass die Splits über Aufrufe hinweg gleich sind.Siehe Benutzerhandbuch für die verschiedenen Kreuzvalidierungsstrategien, die hier verwendet werden können.
Geändert in Version 0.22: Der Standardwert von
cv, wenn None, hat sich von 3-Fold auf 5-Fold geändert.- scoringstr oder callable, Standardwert=None
Bewertungsmethode zur Auswertung der Trainings- und Testdatensätze.
str: siehe Zeichenkettennamen für Bewerter für Optionen.
callable: Ein Scorer-Callable-Objekt (z. B. Funktion) mit der Signatur
scorer(estimator, X, y). Siehe Callable Scorer für Details.None: das Standard-Bewertungskriterium desestimatorwird verwendet.
- n_jobsint, default=None
Anzahl der parallel auszuführenden Jobs. Das Training des Estimators und die Berechnung des Scores werden über die Kombinationen jedes Parameterwerts und jeder Kreuzvalidierungsaufteilung parallelisiert.
Nonebedeutet 1, es sei denn, Sie befinden sich in einemjoblib.parallel_backend-Kontext.-1bedeutet die Verwendung aller Prozessoren. Siehe Glossar für weitere Details.- pre_dispatchint oder str, Standard=’all’
Anzahl der vorab verteilten Jobs für die parallele Ausführung (Standard ist 'all'). Die Option kann den zugeordneten Speicher reduzieren. Die Zeichenkette kann ein Ausdruck wie '2*n_jobs' sein.
- verboseint, default=0
Steuert die Ausführlichkeit: je höher, desto mehr Meldungen.
- error_score‘raise’ oder numerisch, Standard=np.nan
Wert, der der Punktzahl zugewiesen wird, wenn beim Anpassen des Schätzers ein Fehler auftritt. Wenn auf 'raise' gesetzt, wird der Fehler ausgelöst. Wenn ein numerischer Wert angegeben wird, wird FitFailedWarning ausgelöst.
Hinzugefügt in Version 0.20.
- paramsdict, Standardwert=None
An den Estimator, den Scorrer und das Kreuzvalidierungsobjekt zu übergebende Parameter.
Wenn
enable_metadata_routing=False(Standard): Parameter, die direkt an diefit-Methode des Estimators übergeben werden.Wenn
enable_metadata_routing=True: Sicher an diefit-Methode des Estimators, an den Scorrer und an das Kreuzvalidierungsobjekt weitergeleitete Parameter. Siehe Benutzerhandbuch für Metadaten-Routing für weitere Details.
Hinzugefügt in Version 1.6.
- Gibt zurück:
- train_scoresArray der Form (n_ticks, n_cv_folds)
Ergebnisse auf Trainingsdatensätzen.
- test_scoresArray der Form (n_ticks, n_cv_folds)
Ergebnisse auf Testdatensätzen.
Siehe auch
ValidationCurveDisplay.from_estimatorZeichnet die Validierungskurve unter Berücksichtigung eines Estimators, der Daten und des zu variierenden Parameters.
Anmerkungen
Siehe Effekt der Modellregularisierung auf Trainings- und Testfehler
Beispiele
>>> import numpy as np >>> from sklearn.datasets import make_classification >>> from sklearn.model_selection import validation_curve >>> from sklearn.linear_model import LogisticRegression >>> X, y = make_classification(n_samples=1_000, random_state=0) >>> logistic_regression = LogisticRegression() >>> param_name, param_range = "C", np.logspace(-8, 3, 10) >>> train_scores, test_scores = validation_curve( ... logistic_regression, X, y, param_name=param_name, param_range=param_range ... ) >>> print(f"The average train accuracy is {train_scores.mean():.2f}") The average train accuracy is 0.81 >>> print(f"The average test accuracy is {test_scores.mean():.2f}") The average test accuracy is 0.81
Galeriebeispiele#
Skalierung des Regularisierungsparameters für SVCs