使用预先计算的格拉姆矩阵和加权样本来匹配弹性网络#

下面的示例展示了如何在使用加权样本时预计算gram矩阵 ElasticNet .

如果使用加权样本,则在计算gram矩阵之前,必须先将设计矩阵居中,然后根据权重矩阵的平方根重新缩放。

备注

sample_weight 也会重新缩放向总和 n_samples ,请参阅

文档 sample_weight 参数以 fit .

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

让我们首先加载数据集并创建一些样本权重。

import numpy as np

from sklearn.datasets import make_regression

rng = np.random.RandomState(0)

n_samples = int(1e5)
X, y = make_regression(n_samples=n_samples, noise=0.5, random_state=rng)

sample_weight = rng.lognormal(size=n_samples)
# normalize the sample weights
normalized_weights = sample_weight * (n_samples / (sample_weight.sum()))

要安装弹性网,请使用 precompute 选项与样本权重一起使用时,我们必须首先将设计矩阵居中,并在计算gram矩阵之前通过归一化权重对其进行重新缩放。

X_offset = np.average(X, axis=0, weights=normalized_weights)
X_centered = X - np.average(X, axis=0, weights=normalized_weights)
X_scaled = X_centered * np.sqrt(normalized_weights)[:, np.newaxis]
gram = np.dot(X_scaled.T, X_scaled)

我们现在可以继续进行试穿。我们必须通过中心设计矩阵, fit 否则,弹性净估计器将检测到它是非中心的,并丢弃我们传递的gram矩阵。然而,如果我们通过缩放的设计矩阵,预处理代码将第二次错误地重新缩放它。

from sklearn.linear_model import ElasticNet

lm = ElasticNet(alpha=0.01, precompute=gram)
lm.fit(X_centered, y, sample_weight=normalized_weights)
ElasticNet(alpha=0.01,
           precompute=array([[ 9.98809919e+04, -4.48938813e+02, -1.03237920e+03, ...,
        -2.25349312e+02, -3.53959628e+02, -1.67451144e+02],
       [-4.48938813e+02,  1.00768662e+05,  1.19112072e+02, ...,
        -1.07963978e+03,  7.47987268e+01, -5.76195467e+02],
       [-1.03237920e+03,  1.19112072e+02,  1.00393284e+05, ...,
        -3.07582983e+02,  6.66670169e+02,  2.65799352e+02],
       ...,
       [-2.25349312e+02, -1.07963978e+03, -3.07582983e+02, ...,
         9.99891212e+04, -4.58195950e+02, -1.58667835e+02],
       [-3.53959628e+02,  7.47987268e+01,  6.66670169e+02, ...,
        -4.58195950e+02,  9.98350372e+04,  5.60836363e+02],
       [-1.67451144e+02, -5.76195467e+02,  2.65799352e+02, ...,
        -1.58667835e+02,  5.60836363e+02,  1.00911944e+05]],
      shape=(100, 100)))
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Total running time of the script: (0分0.646秒)

相关实例

支持者:加权样本

SVM: Weighted samples

新元:加权样本

SGD: Weighted samples

元数据路由

Metadata Routing

最近邻回归

Nearest Neighbors regression

Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io> _