Gemeinsame Merkmalsauswahl mit Multi-Task Lasso#

Das Multi-Task Lasso ermöglicht die gemeinsame Anpassung mehrerer Regressionsprobleme, wobei die ausgewählten Merkmale über die Aufgaben hinweg gleich bleiben müssen. Dieses Beispiel simuliert sequentielle Messungen, wobei jede Aufgabe ein Zeitpunkt ist und die relevanten Merkmale ihre Amplitude über die Zeit ändern, aber gleich bleiben. Das Multi-Task Lasso erzwingt, dass Merkmale, die zu einem Zeitpunkt ausgewählt werden, für alle Zeitpunkte ausgewählt werden. Dies macht die Merkmalsauswahl durch das Lasso stabiler.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

Daten generieren#

import numpy as np

rng = np.random.RandomState(42)

# Generate some 2D coefficients with sine waves with random frequency and phase
n_samples, n_features, n_tasks = 100, 30, 40
n_relevant_features = 5
coef = np.zeros((n_tasks, n_features))
times = np.linspace(0, 2 * np.pi, n_tasks)
for k in range(n_relevant_features):
    coef[:, k] = np.sin((1.0 + rng.randn(1)) * times + 3 * rng.randn(1))

X = rng.randn(n_samples, n_features)
Y = np.dot(X, coef.T) + rng.randn(n_samples, n_tasks)

Modelle anpassen#

from sklearn.linear_model import Lasso, MultiTaskLasso

coef_lasso_ = np.array([Lasso(alpha=0.5).fit(X, y).coef_ for y in Y.T])
coef_multi_task_lasso_ = MultiTaskLasso(alpha=1.0).fit(X, Y).coef_

Unterstützung und Zeitreihen plotten#

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 5))
plt.subplot(1, 2, 1)
plt.spy(coef_lasso_)
plt.xlabel("Feature")
plt.ylabel("Time (or Task)")
plt.text(10, 5, "Lasso")
plt.subplot(1, 2, 2)
plt.spy(coef_multi_task_lasso_)
plt.xlabel("Feature")
plt.ylabel("Time (or Task)")
plt.text(10, 5, "MultiTaskLasso")
fig.suptitle("Coefficient non-zero location")

feature_to_plot = 0
plt.figure()
lw = 2
plt.plot(coef[:, feature_to_plot], color="seagreen", linewidth=lw, label="Ground truth")
plt.plot(
    coef_lasso_[:, feature_to_plot], color="cornflowerblue", linewidth=lw, label="Lasso"
)
plt.plot(
    coef_multi_task_lasso_[:, feature_to_plot],
    color="gold",
    linewidth=lw,
    label="MultiTaskLasso",
)
plt.legend(loc="upper center")
plt.axis("tight")
plt.ylim([-1.1, 1.1])
plt.show()
  • Coefficient non-zero location
  • plot multi task lasso support

Gesamtlaufzeit des Skripts: (0 Minuten 0,180 Sekunden)

Verwandte Beispiele

L1-basierte Modelle für sparse Signale

L1-basierte Modelle für sparse Signale

Lasso auf dichten und spärlichen Daten

Lasso auf dichten und spärlichen Daten

Lasso, Lasso-LARS und Elastic Net Pfade

Lasso, Lasso-LARS und Elastic Net Pfade

Lasso-Modellauswahl: AIC-BIC / Kreuzvalidierung

Lasso-Modellauswahl: AIC-BIC / Kreuzvalidierung

Galerie generiert von Sphinx-Gallery