Rekursive Merkmalseliminierung mit Kreuzvalidierung#

Ein Beispiel für rekursive Merkmalseliminierung (RFE) mit automatischer Abstimmung der Anzahl der mit Kreuzvalidierung ausgewählten Merkmale.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

Datengenerierung#

Wir erstellen eine Klassifizierungsaufgabe mit 3 informativen Merkmalen. Die Einführung von 2 zusätzlichen redundanten (d. h. korrelierten) Merkmalen führt dazu, dass die ausgewählten Merkmale je nach Kreuzvalidierungs-Fold variieren. Die verbleibenden Merkmale sind nicht informativ, da sie zufällig gezogen werden.

from sklearn.datasets import make_classification

n_features = 15
feat_names = [f"feature_{i}" for i in range(15)]

X, y = make_classification(
    n_samples=500,
    n_features=n_features,
    n_informative=3,
    n_redundant=2,
    n_repeated=0,
    n_classes=8,
    n_clusters_per_class=1,
    class_sep=0.8,
    random_state=0,
)

Modelltraining und -auswahl#

Wir erstellen das RFE-Objekt und berechnen die kreuzvalidierten Scores. Die Scoring-Strategie „Accuracy“ optimiert den Anteil korrekt klassifizierter Stichproben.

from sklearn.feature_selection import RFECV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

min_features_to_select = 1  # Minimum number of features to consider
clf = LogisticRegression()
cv = StratifiedKFold(5)

rfecv = RFECV(
    estimator=clf,
    step=1,
    cv=cv,
    scoring="accuracy",
    min_features_to_select=min_features_to_select,
    n_jobs=2,
)
rfecv.fit(X, y)

print(f"Optimal number of features: {rfecv.n_features_}")
Optimal number of features: 3

In diesem Fall erweist sich das Modell mit 3 Merkmalen (das dem wahren generativen Modell entspricht) als am optimalsten.

Diagramm Anzahl der Merkmale VS. Kreuzvalidierungs-Scores#

import matplotlib.pyplot as plt
import pandas as pd

data = {
    key: value
    for key, value in rfecv.cv_results_.items()
    if key in ["n_features", "mean_test_score", "std_test_score"]
}
cv_results = pd.DataFrame(data)
plt.figure()
plt.xlabel("Number of features selected")
plt.ylabel("Mean test accuracy")
plt.errorbar(
    x=cv_results["n_features"],
    y=cv_results["mean_test_score"],
    yerr=cv_results["std_test_score"],
)
plt.title("Recursive Feature Elimination \nwith correlated features")
plt.show()
Recursive Feature Elimination  with correlated features

Aus der obigen Grafik kann man weiter ein Plateau gleichwertiger Scores (ähnlicher Mittelwert und überlappende Fehlerbalken) für 3 bis 5 ausgewählte Merkmale erkennen. Dies ist das Ergebnis der Einführung korrelierter Merkmale. Tatsächlich kann das von RFE ausgewählte optimale Modell je nach Kreuzvalidierungstechnik innerhalb dieses Bereichs liegen. Die Testgenauigkeit nimmt über 5 ausgewählten Merkmalen hinaus ab. Das bedeutet, dass das Beibehalten nicht informativer Merkmale zu Overfitting führt und daher die statistische Leistung der Modelle beeinträchtigt.

import numpy as np

for i in range(cv.n_splits):
    mask = rfecv.cv_results_[f"split{i}_support"][
        rfecv.n_features_ - 1
    ]  # mask of features selected by the RFE
    features_selected = np.ma.compressed(np.ma.masked_array(feat_names, mask=1 - mask))
    print(f"Features selected in fold {i}: {features_selected}")
Features selected in fold 0: ['feature_3' 'feature_4' 'feature_8']
Features selected in fold 1: ['feature_3' 'feature_4' 'feature_8']
Features selected in fold 2: ['feature_3' 'feature_4' 'feature_8']
Features selected in fold 3: ['feature_3' 'feature_4' 'feature_8']
Features selected in fold 4: ['feature_3' 'feature_4' 'feature_8']

In den fünf Folds sind die ausgewählten Merkmale konsistent. Das sind gute Nachrichten, es bedeutet, dass die Auswahl über die Folds hinweg stabil ist und bestätigt, dass diese Merkmale die informativsten sind.

Gesamtlaufzeit des Skripts: (0 Minuten 0,661 Sekunden)

Verwandte Beispiele

Modellkomplexität und kreuzvalidierter Score ausbalancieren

Modellkomplexität und kreuzvalidierter Score ausbalancieren

Pipeline ANOVA SVM

Pipeline ANOVA SVM

Benutzerdefinierte Refit-Strategie einer Gitter-Suche mit Kreuzvalidierung

Benutzerdefinierte Refit-Strategie einer Gitter-Suche mit Kreuzvalidierung

Post-hoc-Anpassung des Cut-off-Punkts der Entscheidungskfunktion

Post-hoc-Anpassung des Cut-off-Punkts der Entscheidungskfunktion

Galerie generiert von Sphinx-Gallery