LearningCurveDisplay#
- class sklearn.model_selection.LearningCurveDisplay(*, train_sizes, train_scores, test_scores, score_name=None)[Quelle]#
Visualisierung der Lernkurve.
Es wird empfohlen,
from_estimatorzu verwenden, um eineLearningCurveDisplay-Instanz zu erstellen. Alle Parameter werden als Attribute gespeichert.Lesen Sie mehr im Benutzerhandbuch für allgemeine Informationen zur Visualisierungs-API und detaillierte Dokumentation zur Lernkurven-Visualisierung.
Hinzugefügt in Version 1.2.
- Parameter:
- train_sizesndarray der Form (n_unique_ticks,)
Anzahl der Trainingsbeispiele, die zur Erzeugung der Lernkurve verwendet wurden.
- train_scoresndarray der Form (n_ticks, n_cv_folds)
Ergebnisse auf Trainingsdatensätzen.
- test_scoresndarray der Form (n_ticks, n_cv_folds)
Ergebnisse auf Testdatensätzen.
- score_namestr, Standard=None
Der Name der in
learning_curveverwendeten Punktzahl. Er überschreibt den aus demscoring-Parameter abgeleiteten Namen. WennscoreNoneist, verwenden wir"Score", wennnegate_scoreFalseist, und"Negative score"andernfalls. Wennscoringeine Zeichenkette oder aufrufbar ist, leiten wir den Namen ab. Wir ersetzen_durch Leerzeichen und großschreiben den ersten Buchstaben. Wir entfernenneg_und ersetzen es durch"Negative", wennnegate_scoreFalseist, oder entfernen es andernfalls einfach.
- Attribute:
- ax_matplotlib Axes
Achsen mit der Lernkurve.
- figure_matplotlib Figure
Abbildung, die die Lernkurve enthält.
- errorbar_Liste von matplotlib Artist oder None
Wenn
std_display_style"errorbar"ist, handelt es sich um eine Liste vonmatplotlib.container.ErrorbarContainer-Objekten. Wenn ein anderer Stil verwendet wird, isterrorbar_None.- lines_Liste von matplotlib Artist oder None
Wenn
std_display_style"fill_between"ist, handelt es sich um eine Liste vonmatplotlib.lines.Line2D-Objekten, die den mittleren Trainings- und Testergebnissen entsprechen. Wenn ein anderer Stil verwendet wird, istline_None.- fill_between_Liste von matplotlib Artist oder None
Wenn
std_display_style"fill_between"ist, handelt es sich um eine Liste vonmatplotlib.collections.PolyCollection-Objekten. Wenn ein anderer Stil verwendet wird, istfill_between_None.
Siehe auch
sklearn.model_selection.learning_curveBerechnet die Lernkurve.
Beispiele
>>> import matplotlib.pyplot as plt >>> from sklearn.datasets import load_iris >>> from sklearn.model_selection import LearningCurveDisplay, learning_curve >>> from sklearn.tree import DecisionTreeClassifier >>> X, y = load_iris(return_X_y=True) >>> tree = DecisionTreeClassifier(random_state=0) >>> train_sizes, train_scores, test_scores = learning_curve( ... tree, X, y) >>> display = LearningCurveDisplay(train_sizes=train_sizes, ... train_scores=train_scores, test_scores=test_scores, score_name="Score") >>> display.plot() <...> >>> plt.show()
- classmethod from_estimator(estimator, X, y, *, groups=None, train_sizes=array([0.1, 0.33, 0.55, 0.78, 1.]), cv=None, scoring=None, exploit_incremental_learning=False, n_jobs=None, pre_dispatch='all', verbose=0, shuffle=False, random_state=None, error_score=nan, fit_params=None, ax=None, negate_score=False, score_name=None, score_type='both', std_display_style='fill_between', line_kw=None, fill_between_kw=None, errorbar_kw=None)[Quelle]#
Erstellt eine Lernkurvenanzeige aus einem Schätzer.
Lesen Sie mehr im Benutzerhandbuch für allgemeine Informationen zur Visualisierungs-API und detaillierte Dokumentation zur Lernkurven-Visualisierung.
- Parameter:
- estimatorObjekttyp, der die Methoden "fit" und "predict" implementiert
Ein Objekt dieses Typs, das für jede Validierung geklont wird.
- Xarray-like der Form (n_samples, n_features)
Trainingsdaten, wobei
n_samplesdie Anzahl der Stichproben undn_featuresdie Anzahl der Merkmale ist.- yarray-ähnlich der Form (n_samples,) oder (n_samples, n_outputs) oder None
Ziel relativ zu X für Klassifikation oder Regression; None für unüberwachtes Lernen.
- groupsarray-like of shape (n_samples,), default=None
Gruppenbezeichnungen für die Stichproben, die bei der Aufteilung des Datensatzes in Trainings-/Testdatensätze verwendet werden. Nur in Verbindung mit einer `Group`-Instanz von cv (z. B.
GroupKFold) verwendet.- train_sizesarray-ähnlich der Form (n_ticks,), Standard=np.linspace(0.1, 1.0, 5)
Relative oder absolute Anzahl von Trainingsbeispielen, die zur Erzeugung der Lernkurve verwendet werden. Wenn der dtype float ist, wird er als Bruchteil der maximalen Größe des Trainingsdatensatzes betrachtet (die durch die ausgewählte Validierungsmethode bestimmt wird), d.h. er muss zwischen (0, 1] liegen. Andernfalls wird er als absolute Größe der Trainingsdatensätze interpretiert. Beachten Sie, dass für die Klassifikation die Anzahl der Stichproben in der Regel groß genug sein muss, um mindestens eine Stichprobe aus jeder Klasse zu enthalten.
- cvint, Kreuzvalidierungsgenerator oder iterierbar, Standardwert=None
Bestimmt die Strategie der Kreuzvalidierungsaufteilung. Mögliche Eingaben für cv sind
None, um die Standard-5-fache Kreuzvalidierung zu verwenden,
int, um die Anzahl der Folds in einem
(Stratified)KFoldanzugeben,Eine iterierbare Liste, die (Trainings-, Test-) Splits als Indizes-Arrays liefert.
Für ganzzahlige/None-Eingaben wird, wenn der Schätzer ein Klassifikator ist und `y` entweder binär oder mehrklassig ist,
StratifiedKFoldverwendet. In allen anderen Fällen wirdKFoldverwendet. Diese Splitter werden mit `shuffle=False` instanziiert, sodass die Splits über Aufrufe hinweg gleich sind.Siehe Benutzerhandbuch für die verschiedenen Kreuzvalidierungsstrategien, die hier verwendet werden können.
- scoringstr oder callable, Standardwert=None
Die verwendete Bewertungsfunktion zur Berechnung der Lernkurve. Optionen
str: siehe Zeichenkettennamen für Bewerter für Optionen.
callable: Ein Scorer-Callable-Objekt (z. B. Funktion) mit der Signatur
scorer(estimator, X, y). Siehe Callable Scorer für Details.None: das Standard-Bewertungskriterium desestimatorwird verwendet.
- exploit_incremental_learningbool, Standard=False
Wenn der Schätzer inkrementelles Lernen unterstützt, wird dies verwendet, um das Anpassen für verschiedene Trainingsdatensatzgrößen zu beschleunigen.
- n_jobsint, default=None
Anzahl der parallel auszuführenden Jobs. Das Anpassen des Schätzers und die Berechnung der Bewertung werden für die verschiedenen Trainings- und Testdatensätze parallelisiert.
Nonebedeutet 1, es sei denn, Sie befinden sich in einemjoblib.parallel_backend-Kontext.-1bedeutet die Verwendung aller Prozessoren. Siehe Glossar für weitere Details.- pre_dispatchint oder str, Standard=’all’
Anzahl der vorab verteilten Jobs für die parallele Ausführung (Standard ist 'all'). Die Option kann den zugeordneten Speicher reduzieren. Die Zeichenkette kann ein Ausdruck wie '2*n_jobs' sein.
- verboseint, default=0
Steuert die Ausführlichkeit: je höher, desto mehr Meldungen.
- shufflebool, default=False
Gibt an, ob die Trainingsdaten vor der Entnahme von Präfixen basierend auf `train_sizes` gemischt werden sollen.
- random_stateint, RandomState-Instanz oder None, default=None
Wird verwendet, wenn
shuffleTrue ist. Übergeben Sie eine Ganzzahl für reproduzierbare Ergebnisse über mehrere Funktionsaufrufe hinweg. Siehe Glossar.- error_score‘raise’ oder numerisch, Standard=np.nan
Wert, der der Punktzahl zugewiesen wird, wenn beim Anpassen des Schätzers ein Fehler auftritt. Wenn auf 'raise' gesetzt, wird der Fehler ausgelöst. Wenn ein numerischer Wert angegeben wird, wird FitFailedWarning ausgelöst.
- fit_paramsdict, Standard=None
Parameter, die an die fit-Methode des Schätzers übergeben werden.
- axmatplotlib Axes, Standard=None
Axes-Objekt, auf dem geplottet werden soll. Wenn
None, wird eine neue Figur und Achse erstellt.- negate_scorebool, Standard=False
Ob die über
learning_curveerhaltenen Punktzahlen negiert werden sollen oder nicht. Dies ist besonders nützlich, wenn der Fehler mitneg_*inscikit-learnverwendet wird.- score_namestr, Standard=None
Der Name der Punktzahl, die zur Beschriftung der y-Achse des Plots verwendet wird. Er überschreibt den aus dem
scoring-Parameter abgeleiteten Namen. WennscoreNoneist, verwenden wir"Score", wennnegate_scoreFalseist, und"Negative score"andernfalls. Wennscoringeine Zeichenkette oder aufrufbar ist, leiten wir den Namen ab. Wir ersetzen_durch Leerzeichen und großschreiben den ersten Buchstaben. Wir entfernenneg_und ersetzen es durch"Negative", wennnegate_scoreFalseist, oder entfernen es andernfalls einfach.- score_type{“test”, “train”, “both”}, Standard=”both”
Der zu plottende Punktzahltyp. Kann einer von
"test","train"oder"both"sein.- std_display_style{“errorbar”, “fill_between”} oder None, Standard=”fill_between”
Der Stil, der zur Anzeige der Standardabweichung der Punktzahl um die mittlere Punktzahl verwendet wird. Wenn
None, wird keine Darstellung der Standardabweichung angezeigt.- line_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.plotübergeben werden, um die mittlere Punktzahl zu zeichnen.- fill_between_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.fill_betweenübergeben werden, um die Standardabweichung der Punktzahl zu zeichnen.- errorbar_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.errorbarübergeben werden, um die mittlere Punktzahl und die Standardabweichung der Punktzahl zu zeichnen.
- Gibt zurück:
- display
LearningCurveDisplay Objekt, das berechnete Werte speichert.
- display
Beispiele
>>> import matplotlib.pyplot as plt >>> from sklearn.datasets import load_iris >>> from sklearn.model_selection import LearningCurveDisplay >>> from sklearn.tree import DecisionTreeClassifier >>> X, y = load_iris(return_X_y=True) >>> tree = DecisionTreeClassifier(random_state=0) >>> LearningCurveDisplay.from_estimator(tree, X, y) <...> >>> plt.show()
- plot(ax=None, *, negate_score=False, score_name=None, score_type='both', std_display_style='fill_between', line_kw=None, fill_between_kw=None, errorbar_kw=None)[Quelle]#
Visualisierung plotten.
- Parameter:
- axmatplotlib Axes, Standard=None
Axes-Objekt, auf dem geplottet werden soll. Wenn
None, wird eine neue Figur und Achse erstellt.- negate_scorebool, Standard=False
Ob die über
learning_curveerhaltenen Punktzahlen negiert werden sollen oder nicht. Dies ist besonders nützlich, wenn der Fehler mitneg_*inscikit-learnverwendet wird.- score_namestr, Standard=None
Der Name der Punktzahl, die zur Beschriftung der y-Achse des Plots verwendet wird. Er überschreibt den aus dem
scoring-Parameter abgeleiteten Namen. WennscoreNoneist, verwenden wir"Score", wennnegate_scoreFalseist, und"Negative score"andernfalls. Wennscoringeine Zeichenkette oder aufrufbar ist, leiten wir den Namen ab. Wir ersetzen_durch Leerzeichen und großschreiben den ersten Buchstaben. Wir entfernenneg_und ersetzen es durch"Negative", wennnegate_scoreFalseist, oder entfernen es andernfalls einfach.- score_type{“test”, “train”, “both”}, Standard=”both”
Der zu plottende Punktzahltyp. Kann einer von
"test","train"oder"both"sein.- std_display_style{“errorbar”, “fill_between”} oder None, Standard=”fill_between”
Der Stil, der zur Anzeige der Standardabweichung der Punktzahl um die mittlere Punktzahl verwendet wird. Wenn None, wird keine Darstellung der Standardabweichung angezeigt.
- line_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.plotübergeben werden, um die mittlere Punktzahl zu zeichnen.- fill_between_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.fill_betweenübergeben werden, um die Standardabweichung der Punktzahl zu zeichnen.- errorbar_kwdict, Standard=None
Zusätzliche Schlüsselwortargumente, die an
plt.errorbarübergeben werden, um die mittlere Punktzahl und die Standardabweichung der Punktzahl zu zeichnen.
- Gibt zurück:
- display
LearningCurveDisplay Objekt, das berechnete Werte speichert.
- display
Galeriebeispiele#
Lernkurven plotten und die Skalierbarkeit von Modellen prüfen