Metadata Routing#
This document shows how you can use the metadata routing mechanism in scikit-learn to route metadata to the estimators, scorers, and CV splitters consuming them.
To better understand the rest of the document, we need to introduce two concepts: routers and consumers. A router is an object that forwards some given data and metadata to other objects. In most cases, a router is a meta-estimator, i.e. an estimator that takes another estimator as a parameter. A function such as sklearn.model_selection.cross_validate, which takes an estimator as a parameter and forwards data and metadata, is also a router.
A consumer, on the other hand, is an object that accepts and uses some given metadata. For instance, an estimator taking into account sample_weight in its fit method is a consumer of sample_weight.
It is possible for an object to be both a router and a consumer. For instance, a meta-estimator may take sample_weight into account in certain calculations, but it may also route it to the underlying estimator.
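As a hedged, self-contained sketch (not part of the original example), the following illustrates this split: cross_validate plays the router, while LinearRegression consumes sample_weight in its fit. The names X_demo, y_demo, w_demo and reg_demo are invented for illustration, and the call assumes a scikit-learn version whose cross_validate accepts a params argument.
# Hedged sketch of the router/consumer split; all names are local to this sketch.
import numpy as np
from sklearn import set_config
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate

set_config(enable_metadata_routing=True)
rng_demo = np.random.RandomState(0)
X_demo, y_demo, w_demo = rng_demo.rand(30, 3), rng_demo.rand(30), rng_demo.rand(30)

# The consumer declares that its `fit` wants `sample_weight` (and that its
# `score` does not), and the router forwards the metadata given via `params`:
reg_demo = (
    LinearRegression()
    .set_fit_request(sample_weight=True)
    .set_score_request(sample_weight=False)
)
cross_validate(reg_demo, X_demo, y_demo, cv=3, params={"sample_weight": w_demo})
set_config(enable_metadata_routing=False)  # restore the default for now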
First, a few imports and some random data for the rest of the script.
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from pprint import pprint
import numpy as np
from sklearn import set_config
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    get_routing_for_object,
    process_routing,
)
from sklearn.utils.validation import check_is_fitted
n_samples, n_features = 100, 4
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
y = rng.randint(0, 2, size=n_samples)
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)
Metadata routing is only available if explicitly enabled:
set_config(enable_metadata_routing=True)
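As a side note (a hedged aside, not part of the original flow), the same flag can also be enabled only temporarily with a context manager, so the global configuration is left untouched outside the block:
from sklearn import config_context

# Enable routing only inside the `with` block; the previous configuration is
# restored on exit.
with config_context(enable_metadata_routing=True):
    pass  # code that relies on metadata routing goes here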
This utility function is a dummy to check if a metadata is passed:
def check_metadata(obj, **kwargs):
    for key, value in kwargs.items():
        if value is not None:
            print(
                f"Received {key} of length = {len(value)} in {obj.__class__.__name__}."
            )
        else:
            print(f"{key} is None in {obj.__class__.__name__}.")
A utility function to nicely print the routing information of an object:
def print_routing(obj):
pprint(obj.get_metadata_routing()._serialize())
Consuming Estimator#
Here we demonstrate how an estimator can expose the required API to support metadata routing as a consumer. Imagine a simple classifier accepting sample_weight as a metadata in its fit and groups in its predict method:
class ExampleClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        # all classifiers need to expose a classes_ attribute once they're fit.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        check_metadata(self, groups=groups)
        # return a constant value of 1, not a very smart classifier!
        return np.ones(len(X))
The above estimator now has all it needs to consume metadata. This is accomplished by some magic done in BaseEstimator. The class now exposes three methods: set_fit_request, set_predict_request, and get_metadata_routing. There is also a set_score_request for sample_weight, since ClassifierMixin implements a score method accepting sample_weight. The same applies to estimators which inherit from RegressorMixin.
By default, no metadata is requested, which we can see as:
print_routing(ExampleClassifier())
{'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}
The above output means that sample_weight and groups are not requested by ExampleClassifier, and if a router is given those metadata, it should raise an error, since the user has not explicitly set whether they are required or not. The same is true for sample_weight in the score method, which is inherited from ClassifierMixin. In order to explicitly set request values for those metadata, we can use these methods:
est = (
ExampleClassifier()
.set_fit_request(sample_weight=False)
.set_predict_request(groups=True)
.set_score_request(sample_weight=False)
)
print_routing(est)
{'fit': {'sample_weight': False},
'predict': {'groups': True},
'score': {'sample_weight': False}}
Note
Please note that as long as the above estimator is not used in a meta-estimator, the user does not need to set any requests for the metadata, and the set values are ignored, since a consumer does not validate or route given metadata. A simple usage of the above estimator would work as expected.
est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.
array([1., 1., 1.])
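To make the note above concrete, here is a short hedged sketch: even with a request value explicitly set to False, the consumer itself neither validates nor routes, so passing the metadata directly still works. The variable est_direct is introduced only for this illustration.
# Hedged sketch: outside a meta-estimator the request value is ignored.
est_direct = ExampleClassifier().set_fit_request(sample_weight=False)
est_direct.fit(X, y, sample_weight=my_weights)  # still receives the weights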
Routing Meta-Estimator#
Now, we show how to design a meta-estimator to be a router. As a simplified example, here is a meta-estimator which doesn't do much other than routing the metadata.
class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        # This method defines the routing for this meta-estimator.
        # In order to do so, a `MetadataRouter` instance is created, and the
        # routing is added to it. More explanations follow below.
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
            .add(caller="score", callee="score"),
        )
        return router

    def fit(self, X, y, **fit_params):
        # `get_routing_for_object` returns a copy of the `MetadataRouter`
        # constructed by the above `get_metadata_routing` method, that is
        # internally called.
        request_router = get_routing_for_object(self)
        # Meta-estimators are responsible for validating the given metadata.
        # `method` refers to the parent's method, i.e. `fit` in this example.
        request_router.validate_metadata(params=fit_params, method="fit")
        # `MetadataRouter.route_params` maps the given metadata to the metadata
        # required by the underlying estimator based on the routing information
        # defined by the MetadataRouter. The output of type `Bunch` has a key
        # for each consuming object and those hold keys for their consuming
        # methods, which then contain key for the metadata which should be
        # routed to them.
        routed_params = request_router.route_params(params=fit_params, caller="fit")

        # A sub-estimator is fitted and its classes are attributed to the
        # meta-estimator.
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = get_routing_for_object(self)
        # then we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying `predict` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)
Let's break down the different parts of the above code.
First, get_routing_for_object takes our meta-estimator (self) and returns a MetadataRouter or, if the object is a consumer, a MetadataRequest, based on the output of the estimator's get_metadata_routing method.
Then in each method, we use the route_params method to construct a dictionary of the form {"object_name": {"method_name": {"metadata": value}}} to pass to the underlying estimator's method. The object_name (estimator in the above routed_params.estimator.fit example) is the same as the one added in get_metadata_routing. validate_metadata makes sure all given metadata are requested to avoid silent bugs.
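As a hedged illustration of these two steps (not part of the original example), we can build such a router by hand and inspect what route_params returns; demo_router and demo_routed are invented names, and the comments only describe the expected shape of the result.
demo_router = get_routing_for_object(
    MetaClassifier(estimator=ExampleClassifier().set_fit_request(sample_weight=True))
)
demo_routed = demo_router.route_params(
    params={"sample_weight": my_weights}, caller="fit"
)
# Expected shape: a Bunch with one key per consuming object ("estimator"),
# holding per-method dicts such as demo_routed.estimator.fit.
print(list(demo_routed.estimator.fit))  # expected: ['sample_weight']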
Next, we illustrate the different behaviors and notably the types of errors raised.
meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
meta_est.fit(X, y, sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
Note that the above example calls our utility function check_metadata() via the ExampleClassifier. It checks that sample_weight is correctly passed to it. If it is not, as in the following example, it would print that sample_weight is None:
meta_est.fit(X, y)
sample_weight is None in ExampleClassifier.
If we pass an unknown metadata, an error is raised:
try:
    meta_est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object.
And if we pass a metadata which is not explicitly requested:
try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore.
Also, if we explicitly set it as not requested, but it is provided:
meta_est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)

try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object.
Another concept to introduce is aliased metadata. This is when an estimator requests a metadata with a different variable name than the default variable name. For instance, in a setting where there are two estimators in a pipeline, one could request sample_weight1 and the other sample_weight2. Note that this doesn't change what the estimator expects, it only tells the meta-estimator how to map the provided metadata to what is required. Here's an example where we pass aliased_sample_weight to the meta-estimator, but the meta-estimator understands that aliased_sample_weight is an alias for sample_weight and passes it as sample_weight to the underlying estimator:
meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
Passing sample_weight here will fail since it is requested with an alias and sample_weight with that name is not requested:
try:
    meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object.
This leads us to get_metadata_routing. The way routing works in scikit-learn is that consumers request what they need, and routers pass that along. Additionally, a router exposes what it requires itself so that it can be used inside another router, e.g. a pipeline inside a grid search object. The output of get_metadata_routing, which is a dictionary representation of a MetadataRouter, includes the complete tree of requested metadata by all nested objects and their corresponding method routings, i.e. which method of a sub-estimator is used in which method of a meta-estimator:
print_routing(meta_est)
{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
As you can see, the only metadata requested for method fit is "sample_weight" with "aliased_sample_weight" as the alias. The MetadataRouter class enables us to easily create the routing object which would create the output we need for our get_metadata_routing.
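As a hedged aside (not in the original example), we can also check the distinction mentioned earlier: get_routing_for_object should return a MetadataRequest for a plain consumer and a MetadataRouter for a meta-estimator.
print(type(get_routing_for_object(ExampleClassifier())).__name__)  # expected: MetadataRequest
print(type(get_routing_for_object(meta_est)).__name__)  # expected: MetadataRouter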
In order to understand how aliases work in meta-estimators, imagine our meta-estimator inside another one:
meta_meta_est = MetaClassifier(estimator=meta_est).fit(
    X, y, aliased_sample_weight=my_weights
)
Received sample_weight of length = 100 in ExampleClassifier.
In the above example, this is how the fit method of meta_meta_est will call its sub-estimator's fit methods:
# user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:
meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
    ...

    # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`
    self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
        ...

        # the second sub-estimator (`est`) expects `sample_weight`
        self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
            ...
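A hedged way to check the chain sketched above is to print the routing of the outer meta-estimator; the nested structure should show that the innermost sample_weight is reachable through the aliased_sample_weight key.
print_routing(meta_meta_est)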
Consuming and Routing Meta-Estimator#
For a slightly more complex example, consider a meta-estimator that routes metadata to an underlying estimator as before, but also uses some metadata in its own methods. This meta-estimator is a consumer and a router at the same time. Implementing one is very similar to what we had before, but with a few tweaks.
class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # defining metadata routing request values for usage in the meta-estimator
            .add_self_request(self)
            # defining metadata routing request values for usage in the sub-estimator
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict")
                .add(caller="score", callee="score"),
            )
        )
        return router

    # Since `sample_weight` is used and consumed here, it should be defined as
    # an explicit argument in the method's signature. All other metadata which
    # are only routed, will be passed as `**fit_params`:
    def fit(self, X, y, sample_weight, **fit_params):
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # We add `sample_weight` to the `fit_params` dictionary.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        request_router = get_routing_for_object(self)
        request_router.validate_metadata(params=fit_params, method="fit")
        routed_params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = get_routing_for_object(self)
        # we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying ``predict`` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)
The key parts where the above meta-estimator differs from our previous one are accepting sample_weight explicitly in fit and including it in fit_params. Since sample_weight is an explicit argument, we can be sure that set_fit_request(sample_weight=...) is present for this method. This meta-estimator is both a consumer and a router of sample_weight.
In get_metadata_routing, we add self to the routing using add_self_request to indicate that this estimator is consuming sample_weight as well as being a router; this also adds a $self_request key to the routing information, as illustrated below. Now let's look at some examples:
No metadata requested:
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
sample_weight requested by the sub-estimator:
meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': True},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
sample_weight requested by the meta-estimator:
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': True},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
Note the difference in the requested metadata representations above.
We can also alias the metadata to pass different values to the fit methods of the meta-estimator and the sub-estimator:
meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'clf_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
However, fit of the meta-estimator only needs the alias for the sub-estimator and addresses its own sample weight as sample_weight, since it doesn't validate and route its own required metadata:
meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Alias only on the sub-estimator:
This is useful when we don't want the meta-estimator to use the metadata, but the sub-estimator should.
meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
The meta-estimator cannot use aliased_sample_weight itself, because it expects it to be passed as sample_weight. This would apply even if set_fit_request(sample_weight=True) had been set on it.
Simple Pipeline#
A slightly more complicated use-case is a meta-estimator resembling a Pipeline. Here is a meta-estimator which accepts a transformer and a classifier. When calling its fit method, it applies the transformer's fit and transform before running the classifier on the transformed data. Upon predict, it applies the transformer's transform before predicting with the classifier's predict method on the transformed new data.
class SimplePipeline(ClassifierMixin, BaseEstimator):
    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # We add the routing for the transformer.
            .add(
                transformer=self.transformer,
                method_mapping=MethodMapping()
                # The metadata is routed such that it retraces how
                # `SimplePipeline` internally calls the transformer's `fit` and
                # `transform` methods in its own methods (`fit` and `predict`).
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
            # We add the routing for the classifier.
            .add(
                classifier=self.classifier,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router

    def fit(self, X, y, **fit_params):
        routed_params = process_routing(self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(
            X, y, **routed_params.transformer.fit
        )
        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **routed_params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        routed_params = process_routing(self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )
        return self.classifier_.predict(
            X_transformed, **routed_params.classifier.predict
        )
Note the usage of MethodMapping to declare which methods of the child estimator (callee) are used in which methods of the meta-estimator (caller). As you can see, SimplePipeline uses the transformer's transform and fit methods in fit, and its transform method in predict, and that's what you see implemented in the routing structure of the pipeline class.
Another difference between the above example and the previous ones is the usage of process_routing, which processes the input parameters, does the required validation, and returns the routed_params which we had created by hand in previous examples. This reduces the boilerplate code a developer needs to write in each meta-estimator's method. Developers are strongly recommended to use this function unless there is a good reason against it.
In order to test the above pipeline, let's add an example transformer.
class ExampleTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        return self.fit(X, y, sample_weight).transform(X, groups)
Note that in the above example, we have implemented fit_transform, which calls fit and transform with the appropriate metadata. This is only required if transform accepts metadata, since the default fit_transform implementation in TransformerMixin doesn't pass metadata to transform.
Now we can test our pipeline and see if the metadata is correctly passed around. This example uses our SimplePipeline, our ExampleTransformer, and our RouterConsumerClassifier, which in turn uses our ExampleClassifier.
pipe = SimplePipeline(
    transformer=ExampleTransformer()
    # we set transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we set transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    )
    # and we want the meta-estimator to receive sample_weight as well
    .set_fit_request(sample_weight=True),
)
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)
Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.
array([1., 1., 1.])
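As a hedged aside (not part of the original example), process_routing can also be called on the fitted pipeline from outside to inspect how it would distribute the metadata; routed_demo is an invented name, and the comments only describe the expected keys, which should mirror the names used in get_metadata_routing.
routed_demo = process_routing(pipe, "fit", sample_weight=my_weights, groups=my_groups)
print(sorted(routed_demo))  # expected: ['classifier', 'transformer']
print(sorted(routed_demo.transformer.fit))  # expected: ['sample_weight']
print(sorted(routed_demo.transformer.transform))  # expected: ['groups']
print(sorted(routed_demo.classifier.fit))  # expected: ['sample_weight']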
Deprecation / Default Value Change#
In this section we show how one should deal with the case where a router also becomes a consumer, especially when it consumes the same metadata as its sub-estimator, or when a consumer starts consuming a metadata which it did not consume in an older release. In that case, a warning should be raised for a while, to let users know the behavior has changed from previous versions.
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        routed_params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router
As explained above, the following is a valid usage if my_weights is not supposed to be passed as sample_weight to MetaRegressor itself:
reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
reg.fit(X, y, sample_weight=my_weights)
Now imagine we further develop MetaRegressor and it now also consumes sample_weight:
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    # show warning to remind user to explicitly set the value with
    # `.set_{method}_request(sample_weight={boolean})`
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, **fit_params):
        routed_params = process_routing(
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        check_metadata(self, sample_weight=sample_weight)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router
The above implementation is almost the same as MetaRegressor, and, because of the default request value defined in __metadata_request__fit, a warning is raised when fitting.
with warnings.catch_warnings(record=True) as record:
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
Received sample_weight of length = 100 in WeightedMetaRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.
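Following the message above, here is a hedged sketch of how a user would silence the warning: explicitly setting the request replaces the WARN default, so the meta-estimator should then consume the weights without warning. The exact absence of warnings is an assumption based on the message text.
with warnings.catch_warnings(record=True) as record:
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).set_fit_request(sample_weight=True).fit(X, y, sample_weight=my_weights)
print(len(record))  # no routing warning is expected once the request is set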
When an estimator consumes a metadata which it did not consume before, the following pattern can be used to warn users.
class ExampleRegressor(RegressorMixin, BaseEstimator):
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def predict(self, X):
        return np.zeros(shape=(len(X)))
with warnings.catch_warnings(record=True) as record:
    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
sample_weight is None in ExampleRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.
At the end, we disable the configuration flag for metadata routing:
set_config(enable_metadata_routing=False)
Third Party Development and scikit-learn Dependency#
As seen above, information is communicated between classes using MetadataRequest and MetadataRouter. It is strongly not advised, but possible to vendor the tools related to metadata routing if you strictly want to have a scikit-learn compatible estimator without depending on the scikit-learn package. If all of the following conditions are met, you do not need to modify your code at all:
- your estimator inherits from BaseEstimator;
- the parameters consumed by your estimator's methods, e.g. fit, are explicitly defined in the method's signature, as opposed to being *args or **kwargs;
- your estimator does not route any metadata to the underlying objects, i.e. it's not a router.
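To illustrate, here is a hypothetical third-party estimator (a hedged sketch, not from the original example) that satisfies all three conditions and therefore needs no metadata-routing specific code of its own:
class ThirdPartyRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical vendor estimator: consumes sample_weight, routes nothing."""

    def fit(self, X, y, sample_weight=None):
        # `sample_weight` is an explicit parameter, not hidden in **kwargs,
        # so `BaseEstimator` can generate `set_fit_request` automatically.
        self.mean_ = np.average(y, weights=sample_weight)
        return self

    def predict(self, X):
        # predict the (weighted) mean of the training targets for every sample
        return np.full(shape=len(X), fill_value=self.mean_)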