Hinweis
Zum Ende gehen, um den vollständigen Beispielcode herunterzuladen oder dieses Beispiel über JupyterLite oder Binder in Ihrem Browser auszuführen.
SVM-Anova: SVM mit univariater Merkmalsauswahl#
Dieses Beispiel zeigt, wie man eine univariate Merkmalsauswahl durchführt, bevor man einen SVC (Support Vector Classifier) ausführt, um die Klassifikationsergebnisse zu verbessern. Wir verwenden den Iris-Datensatz (4 Merkmale) und fügen 36 nicht-informative Merkmale hinzu. Wir können feststellen, dass unser Modell die beste Leistung erzielt, wenn wir etwa 10 % der Merkmale auswählen.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
Laden einiger Daten zum Spielen#
import numpy as np
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
# Add non-informative features
rng = np.random.RandomState(0)
X = np.hstack((X, 2 * rng.random((X.shape[0], 36))))
Erstellen der Pipeline#
from sklearn.feature_selection import SelectPercentile, f_classif
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
# Create a feature-selection transform, a scaler and an instance of SVM that we
# combine together to have a full-blown estimator
clf = Pipeline(
[
("anova", SelectPercentile(f_classif)),
("scaler", StandardScaler()),
("svc", SVC(gamma="auto")),
]
)
Plotten des Kreuzvalidierungs-Scores als Funktion des Perzentils der Merkmale#
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
score_means = list()
score_stds = list()
percentiles = (1, 3, 6, 10, 15, 20, 30, 40, 60, 80, 100)
for percentile in percentiles:
clf.set_params(anova__percentile=percentile)
this_scores = cross_val_score(clf, X, y)
score_means.append(this_scores.mean())
score_stds.append(this_scores.std())
plt.errorbar(percentiles, score_means, np.array(score_stds))
plt.title("Performance of the SVM-Anova varying the percentile of features selected")
plt.xticks(np.linspace(0, 100, 11, endpoint=True))
plt.xlabel("Percentile")
plt.ylabel("Accuracy Score")
plt.axis("tight")
plt.show()

Gesamtlaufzeit des Skripts: (0 Minuten 0,283 Sekunden)
Verwandte Beispiele
Verschiedene SVM-Klassifikatoren im Iris-Datensatz plotten
Verschiedene SVM-Klassifikatoren im Iris-Datensatz plotten