4. 元数据路由#
备注
元数据路由API是实验性的,尚未针对所有估计器实现。请参阅 list of supported and unsupported models for more information.它可能会在没有通常的弃用周期的情况下发生变化。默认情况下,此功能未启用。您可以通过设置 enable_metadata_routing
标志以 True
>>> import sklearn
>>> sklearn.set_config(enable_metadata_routing=True)
请注意,本文档中介绍的方法和要求仅在您希望通过 metadata (例如 sample_weight
)到一种方法。如果你只是路过 X
和 y
并且没有其他参数/元数据到方法,例如 fit , transform 等,那么你就不需要设置任何东西了。
本指南演示了如何 metadata 可以在scikit-learn中的对象之间路由和传递。如果您正在开发scikit-learn兼容的估计器或元估计器,您可以查看我们的相关开发人员指南: 元数据路由 .
元数据是估计器、评分器或CV拆分器在用户显式将其作为参数传递时考虑的数据。例如, KMeans
接受 sample_weight
在其 fit()
方法,并考虑它来计算其质心。 classes
被一些分类器消耗, groups
在某些拆分器中使用,但除了X和y之外,传递到对象方法中的任何数据都可以被视为元数据。在scikit-learn版本1.3之前,没有单一的API可以像这些对象与其他对象结合使用时那样传递元数据,例如评分器接受 sample_weight
内 GridSearchCV
.
使用元数据路由API,我们可以使用以下方式将元数据传输到估计器、评分器和CV拆分器 meta-estimators (such作为 Pipeline
或 GridSearchCV
)或功能,例如 cross_validate
其将数据路由到其他对象。为了将元数据传递给类似 fit
或 score
使用元数据的对象必须 request 它。这是通过 set_{method}_request()
方法,其中 {method}
由请求元数据的方法的名称取代。例如,使用元数据的估计器 fit()
方法将使用 set_fit_request()
,得分者会使用 set_score_request()
.例如,这些方法允许我们指定要请求哪些元数据 set_fit_request(sample_weight=True)
.
对于分组拆分器,例如 GroupKFold
、a groups
默认情况下请求参数。下面的例子最好地证明了这一点。
4.1. 用法示例#
在这里,我们提供了一些示例来展示一些常见的用例。我们的目标是通过 sample_weight
和 groups
通过 cross_validate
,它将元数据路由到 LogisticRegressionCV
以及定制得分手 make_scorer
,这两者 can 在他们的方法中使用元数据。在这些示例中,我们希望单独设置是否在不同的 consumers .
本节中的示例需要以下导入和数据::
>>> import numpy as np
>>> from sklearn.metrics import make_scorer, accuracy_score
>>> from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
>>> from sklearn.model_selection import cross_validate, GridSearchCV, GroupKFold
>>> from sklearn.feature_selection import SelectKBest
>>> from sklearn.pipeline import make_pipeline
>>> n_samples, n_features = 100, 4
>>> rng = np.random.RandomState(42)
>>> X = rng.rand(n_samples, n_features)
>>> y = rng.randint(0, 2, size=n_samples)
>>> my_groups = rng.randint(0, 10, size=n_samples)
>>> my_weights = rng.rand(n_samples)
>>> my_other_weights = rng.rand(n_samples)
4.1.1. 加权评分和匹配#
The splitter used internally in LogisticRegressionCV
,
GroupKFold
, requests groups
by default. However, we need
to explicitly request sample_weight
for it and for our custom scorer by specifying
sample_weight=True
in LogisticRegressionCV
's set_fit_request()
method and in make_scorer
's set_score_request()
method. Both
consumers know how to use sample_weight
in their fit()
or
score()
methods. We can then pass the metadata in
cross_validate
which will route it to any active consumers:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(),
... scoring=weighted_acc
... ).set_fit_request(sample_weight=True)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... params={"sample_weight": my_weights, "groups": my_groups},
... cv=GroupKFold(),
... scoring=weighted_acc,
... )
请注意,在这个例子中, cross_validate
航线 my_weights
对得分手和 LogisticRegressionCV
.
If we would pass sample_weight
in the params of
cross_validate
, but not set any object to request it,
UnsetMetadataPassedError
would be raised, hinting to us that we need to explicitly set
where to route it. The same applies if params={"sample_weights": my_weights, ...}
were passed (note the typo, i.e. weights
instead of weight
), since
sample_weights
was not requested by any of its underlying objects.
4.1.2. 加权评分和非加权拟合#
When passing metadata such as sample_weight
into a router
(meta-estimators or routing function), all sample_weight
consumers require weights to be either explicitly requested or explicitly not
requested (i.e. True
or False
). Thus, to perform an unweighted fit, we need to
configure LogisticRegressionCV
to not request sample weights, so
that cross_validate
does not pass the weights along:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).set_fit_request(sample_weight=False)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... params={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )
如果 linear_model.LogisticRegressionCV.set_fit_request
还没有被叫到, cross_validate
会引发错误,因为 sample_weight
已通过, LogisticRegressionCV
不会被显式配置为识别权重。
4.1.3. 未加权特征选择#
只有当对象的方法知道如何使用元数据时,才可能路由元数据,这在大多数情况下意味着他们将其作为显式参数。只有这样,我们才能使用设置元数据的请求值 set_fit_request(sample_weight=True)
例如。这使对象成为 consumer .
Unlike LogisticRegressionCV
,
SelectKBest
can't consume weights and therefore no request
value for sample_weight
on its instance is set and sample_weight
is not routed
to it:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).set_fit_request(sample_weight=True)
>>> sel = SelectKBest(k=2)
>>> pipe = make_pipeline(sel, lr)
>>> cv_results = cross_validate(
... pipe,
... X,
... y,
... cv=GroupKFold(),
... params={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )
4.1.4. 不同的评分和适合权重#
尽管 make_scorer
和 LogisticRegressionCV
都期待钥匙 sample_weight
,我们可以使用别名将不同的权重传递给不同的消费者。在这个例子中,我们通过 scoring_weight
对于得分手,并且 fitting_weight
到 LogisticRegressionCV
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
... sample_weight="scoring_weight"
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).set_fit_request(sample_weight="fitting_weight")
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... params={
... "scoring_weight": my_weights,
... "fitting_weight": my_other_weights,
... "groups": my_groups,
... },
... scoring=weighted_acc,
... )
4.2. API接口#
A consumer 是一个接受并使用一些的对象(估计器、元估计器、记分器、拆分器) metadata 在其至少一种方法中(例如 fit
, predict
, inverse_transform
, transform
, score
, split
).只将元数据转发给其他对象(子估计器、评分器或拆分器)而不使用元数据本身的元估计器不是消费者。(Meta-)将元数据路由到其他对象的估计器是 routers .(n)(Meta)估计量可以是 consumer 和 router 同时还研究与讨论(Meta-)估计器和拆分器暴露了 set_{method}_request
每个接受至少一个元数据的方法的方法。例如,如果估计者支持 sample_weight
在 fit
和 score
,它暴露了 estimator.set_fit_request(sample_weight=value)
和 estimator.set_score_request(sample_weight=value)
.这里 value
可以是:
True
:方法请求一个sample_weight
.这意味着如果提供了元数据,就会使用它,否则不会引发错误。False
:方法不请求sample_weight
.None
:如果出现以下情况,路由器将引发错误sample_weight
已通过。几乎在所有情况下,这都是实例化对象时的默认值,并确保用户在传递元数据时显式设置元数据请求。唯一例外是Group*Fold
分配器。"param_name"
:别名sample_weight
如果我们想将不同的权重传递给不同的消费者。如果使用混叠,元估计器不应转发"param_name"
对消费者来说,但是sample_weight
相反,因为消费者会期待一个名为sample_weight
.这意味着对象所需的元数据之间的映射,例如sample_weight
以及用户提供的变量名,例如my_weights
是在路由器级别完成的,而不是由消费对象本身完成的。
对于评分者来说,以相同的方式请求元数据 set_score_request
.
如果元数据,例如 sample_weight
,由用户传递针对可能消费的所有对象的元数据请求 sample_weight
应由用户设置,否则路由器对象会引发错误。例如,以下代码引发错误,因为尚未显式指定是否 sample_weight
是否应传递给估算者的评分者::
>>> param_grid = {"C": [0.1, 1]}
>>> lr = LogisticRegression().set_fit_request(sample_weight=True)
>>> try:
... GridSearchCV(
... estimator=lr, param_grid=param_grid
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
[sample_weight] are passed but are not explicitly set as requested or not
requested for LogisticRegression.score, which is used within GridSearchCV.fit.
Call `LogisticRegression.set_score_request({metadata}=True/False)` for each metadata
you want to request/ignore. See the Metadata Routing User guide
<https://scikit-learn.org/stable/metadata_routing.html> for more information.
可以通过显式设置请求值来修复此问题:
>>> lr = LogisticRegression().set_fit_request(
... sample_weight=True
... ).set_score_request(sample_weight=False)
结束时 Usage Examples 部分,我们禁用元数据路由的配置标志::
>>> sklearn.set_config(enable_metadata_routing=False)
4.3. 元数据路由支持状态#
所有消费者(即仅消费元数据而不路由元数据的简单估计器)支持元数据路由,这意味着它们可以在支持元数据路由的元估计器内部使用。然而,元估计器的元数据路由支持的开发正在进行中,以下是支持和尚未支持元数据路由的元估计器和工具的列表。
元估计器和支持元数据路由的功能:
元估计器和工具还不支持元数据路由: