Metadata Routing

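This document shows how to use the metadata routing mechanism in scikit-learn to route metadata to the estimators, scorers, and CV splitters that consume them.

To better understand the rest of this document, we need to introduce two concepts: routers and consumers. A router is an object that forwards some given data and metadata to other objects. In most cases, a router is a meta-estimator, i.e. an estimator that takes another estimator as a parameter. A function such as sklearn.model_selection.cross_validate, which takes an estimator as a parameter and forwards data and metadata, is also a router.

A consumer, on the other hand, is an object that accepts and uses some given metadata. For instance, an estimator that takes sample_weight into account in its fit method is a consumer of sample_weight.

An object can be both a router and a consumer. For instance, a meta-estimator may take sample_weight into account in certain calculations, but it may also route it to the underlying estimator.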

First, a few imports and some random data for the rest of the script.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from pprint import pprint

import numpy as np

from sklearn import set_config
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
    MetadataRouter,
    MethodMapping,
    get_routing_for_object,
    process_routing,
)
from sklearn.utils.validation import check_is_fitted

n_samples, n_features = 100, 4
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
y = rng.randint(0, 2, size=n_samples)
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)

Metadata routing is only available if explicitly enabled:

set_config(enable_metadata_routing=True)

This utility function is a dummy that checks whether a metadata is passed:

def check_metadata(obj, **kwargs):
    for key, value in kwargs.items():
        if value is not None:
            print(
                f"Received {key} of length = {len(value)} in {obj.__class__.__name__}."
            )
        else:
            print(f"{key} is None in {obj.__class__.__name__}.")

A utility function to nicely print the routing information of an object:

def print_routing(obj):
    pprint(obj.get_metadata_routing()._serialize())

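Consuming Estimator

Here we demonstrate how an estimator can expose the required API to support metadata routing as a consumer. Imagine a simple classifier that accepts sample_weight as metadata in its fit method and groups in its predict method: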

class ExampleClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        # all classifiers need to expose a classes_ attribute once they're fit.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        check_metadata(self, groups=groups)
        # return a constant value of 1, not a very smart classifier!
        return np.ones(len(X))

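The above estimator now has all it needs to consume metadata. This is accomplished by some magic done in BaseEstimator. The above class now exposes three methods: set_fit_request, set_predict_request, and get_metadata_routing. There is also a set_score_request for sample_weight, which is present because ClassifierMixin implements a score method accepting sample_weight. The same applies to regressors that inherit from RegressorMixin.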

By default, no metadata is requested, as we can see here:

print_routing(ExampleClassifier())
{'fit': {'sample_weight': None},
 'predict': {'groups': None},
 'score': {'sample_weight': None}}

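The above output means that sample_weight and groups are not requested by ExampleClassifier, and if a router is given those metadata, it should raise an error, since the user has not explicitly set whether they are required or not. The same is true for sample_weight in the score method, which is inherited from ClassifierMixin. To explicitly set request values for those metadata, we can use these methods: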

est = (
    ExampleClassifier()
    .set_fit_request(sample_weight=False)
    .set_predict_request(groups=True)
    .set_score_request(sample_weight=False)
)
print_routing(est)
{'fit': {'sample_weight': False},
 'predict': {'groups': True},
 'score': {'sample_weight': False}}

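Note

As long as the above estimator is not used inside a meta-estimator, the user does not need to set any requests for the metadata, and any values that are set are ignored, since a consumer does not validate or route given metadata. Simple usage of the above estimator works as expected.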

est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.

array([1., 1., 1.])

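Routing Meta-Estimator

Now we show how to design a meta-estimator to be a router. As a simplified example, here is a meta-estimator that does not do much other than routing the metadata.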

class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        # This method defines the routing for this meta-estimator.
        # In order to do so, a `MetadataRouter` instance is created, and the
        # routing is added to it. More explanations follow below.
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
            .add(caller="score", callee="score"),
        )
        return router

    def fit(self, X, y, **fit_params):
        # `get_routing_for_object` returns a copy of the `MetadataRouter`
        # constructed by the above `get_metadata_routing` method, that is
        # internally called.
        request_router = get_routing_for_object(self)
        # Meta-estimators are responsible for validating the given metadata.
        # `method` refers to the parent's method, i.e. `fit` in this example.
        request_router.validate_metadata(params=fit_params, method="fit")
        # `MetadataRouter.route_params` maps the given metadata to the metadata
        # required by the underlying estimator based on the routing information
        # defined by the MetadataRouter. The output of type `Bunch` has a key
        # for each consuming object and those hold keys for their consuming
        # methods, which then contain key for the metadata which should be
        # routed to them.
        routed_params = request_router.route_params(params=fit_params, caller="fit")

        # A sub-estimator is fitted and its classes are attributed to the
        # meta-estimator.
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = get_routing_for_object(self)
        # then we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying `predict` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

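Let's break down the different parts of the above code.

First, get_routing_for_object takes our meta-estimator (self) and returns a MetadataRouter or, if the object is a consumer, a MetadataRequest, based on the output of the object's get_metadata_routing method.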

Then in each method, we use the route_params method to construct a dictionary of the form {"object_name": {"method_name": {"metadata": value}}} to pass to the underlying estimator's method. The object_name (estimator in the above routed_params.estimator.fit example) is the same as the one added in the get_metadata_routing. validate_metadata makes sure all given metadata are requested to avoid silent bugs.

Next, we illustrate the different behaviors, and in particular the types of errors raised.

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
meta_est.fit(X, y, sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())

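Note that the above example calls our utility function check_metadata() via the ExampleClassifier. It checks that sample_weight is correctly passed to it. If it is not, as in the following example, it prints that sample_weight is None: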

meta_est.fit(X, y)
sample_weight is None in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())

If we pass an unknown metadata, an error is raised:

try:
    meta_est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object.

And if we pass a metadata that is not explicitly requested:

try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore.

Also, if we explicitly set it as not requested, but it is provided:

meta_est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)
try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object.

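Another concept to introduce is aliased metadata. This is when an estimator requests a metadata under a variable name different from the default one. For instance, in a setting where there are two estimators in a pipeline, one could request sample_weight1 and the other sample_weight2. Note that this does not change what the estimator expects; it only tells the meta-estimator how to map the provided metadata to what is required. Here is an example where we pass aliased_sample_weight to the meta-estimator, but the meta-estimator understands that aliased_sample_weight is an alias for sample_weight and passes it as sample_weight to the underlying estimator: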

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())

Passing sample_weight here will fail, since it is requested under an alias and no metadata named sample_weight is requested:

try:
    meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object.

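This leads us to get_metadata_routing. The way routing works in scikit-learn is that consumers request what they need, and routers pass that along. Additionally, a router exposes what it requires itself, so that it can be used inside another router, e.g. a pipeline inside a grid search object. The output of get_metadata_routing, which is a dictionary representation of a MetadataRouter, includes the complete tree of metadata requested by all nested objects and their corresponding method routings, i.e. which method of a sub-estimator is used in which method of a meta-estimator: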

print_routing(meta_est)
{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

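As you can see, the only metadata requested for the fit method is "sample_weight", with "aliased_sample_weight" as its alias. The MetadataRouter class enables us to easily create the routing object that produces the output we need for get_metadata_routing.

To understand how aliases work in meta-estimators, imagine our meta-estimator inside another one: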

meta_meta_est = MetaClassifier(estimator=meta_est).fit(
    X, y, aliased_sample_weight=my_weights
)
Received sample_weight of length = 100 in ExampleClassifier.

In the above example, this is how the fit method of meta_meta_est calls the fit methods of its sub-estimators:

# user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:
meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
    ...

    # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`
    self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
        ...

        # the second sub-estimator (`est`) expects `sample_weight`
        self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
            ...

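Consuming and Routing Meta-Estimator

For a slightly more complex example, consider a meta-estimator that routes metadata to an underlying estimator as before, but also uses some metadata in its own methods. This meta-estimator is a consumer and a router at the same time. Implementing one is very similar to what we had before, with a few tweaks.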

class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # defining metadata routing request values for usage in the meta-estimator
            .add_self_request(self)
            # defining metadata routing request values for usage in the sub-estimator
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict")
                .add(caller="score", callee="score"),
            )
        )
        return router

    # Since `sample_weight` is used and consumed here, it should be defined as
    # an explicit argument in the method's signature. All other metadata which
    # are only routed, will be passed as `**fit_params`:
    def fit(self, X, y, sample_weight, **fit_params):
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # We add `sample_weight` to the `fit_params` dictionary.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        request_router = get_routing_for_object(self)
        request_router.validate_metadata(params=fit_params, method="fit")
        routed_params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = get_routing_for_object(self)
        # we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying ``predict`` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

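The key parts where the above meta-estimator differs from our previous one are accepting sample_weight explicitly in fit and including it in fit_params. Since sample_weight is an explicit argument, we can be sure that set_fit_request(sample_weight=...) is present for this method. The meta-estimator is both a consumer and a router of sample_weight.

In get_metadata_routing, we add self to the routing using add_self_request to indicate that this estimator consumes sample_weight as well as being a router; this also adds a $self_request key to the routing information, as illustrated below. Now let's look at some examples:

  • No metadata requested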

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • sample_weight requested by the sub-estimator

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': True},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • sample_weight requested by the meta-estimator

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': True},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

Note the difference in the requested metadata representations above.

  • We can also alias the metadata to pass different values to the fit methods of the meta-estimator and the sub-estimator:

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'clf_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

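However, fit of the meta-estimator only needs the alias for the sub-estimator and addresses its own sample weight as sample_weight, since it does not validate and route its own required metadata: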

meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
RouterConsumerClassifier(estimator=ExampleClassifier())

  • Alias only on the sub-estimator:

This is useful when we do not want the meta-estimator to use the metadata, but the sub-estimator should.

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

The meta-estimator cannot use aliased_sample_weight, because it expects the metadata to be passed as sample_weight. This would apply even if set_fit_request(sample_weight=True) were set on it.

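Simple Pipeline

A slightly more complicated use case is a meta-estimator resembling a Pipeline. Here is a meta-estimator that accepts a transformer and a classifier. When its fit method is called, it applies the transformer's fit and transform before running the classifier on the transformed data. Upon predict, it applies the transformer's transform before predicting with the classifier's predict method on the transformed new data.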

class SimplePipeline(ClassifierMixin, BaseEstimator):
    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            # We add the routing for the transformer.
            .add(
                transformer=self.transformer,
                method_mapping=MethodMapping()
                # The metadata is routed such that it retraces how
                # `SimplePipeline` internally calls the transformer's `fit` and
                # `transform` methods in its own methods (`fit` and `predict`).
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
            # We add the routing for the classifier.
            .add(
                classifier=self.classifier,
                method_mapping=MethodMapping()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router

    def fit(self, X, y, **fit_params):
        routed_params = process_routing(self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(
            X, y, **routed_params.transformer.fit
        )
        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **routed_params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        routed_params = process_routing(self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )
        return self.classifier_.predict(
            X_transformed, **routed_params.classifier.predict
        )

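Note the usage of MethodMapping to declare which methods of the child estimator (callee) are used in which methods of the meta-estimator (caller). As you can see, SimplePipeline uses the transformer's transform and fit methods in fit, and its transform method in predict, and that is what you see implemented in the routing structure of the pipeline class.

Another difference between the above example and the previous ones is the usage of process_routing, which processes the input parameters, does the required validation, and returns the routed_params that we created by hand in the previous examples. This reduces the boilerplate code a developer needs to write in each meta-estimator's method. Developers are strongly recommended to use this function unless there is a good reason not to.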

To test the above pipeline, let's add an example transformer.

class ExampleTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        return self.fit(X, y, sample_weight).transform(X, groups)

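Note that in the above example, we have implemented fit_transform, which calls fit and transform with the appropriate metadata. This is only required if transform accepts metadata, since the default fit_transform implementation in TransformerMixin does not pass metadata to transform.

Now we can test our pipeline and see whether metadata is correctly passed around. This example uses our SimplePipeline, our ExampleTransformer, and our RouterConsumerClassifier, which uses our ExampleClassifier.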

pipe = SimplePipeline(
    transformer=ExampleTransformer()
    # we set transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we set transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    )
    # and we want the meta-estimator to receive sample_weight as well
    .set_fit_request(sample_weight=True),
)
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)
Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.

array([1., 1., 1.])

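Deprecation / Default Value Change

In this section we show how to handle the case where a router also becomes a consumer, especially when it consumes the same metadata as its sub-estimator, or where a consumer starts consuming a metadata that it did not consume in an older release. In this case, a warning should be raised for a while, to let users know that the behavior has changed from previous versions.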

class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        routed_params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router

As explained above, this is a valid usage if my_weights are not supposed to be passed as sample_weight to MetaRegressor:

reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
reg.fit(X, y, sample_weight=my_weights)

Now imagine we further develop MetaRegressor, and it now also consumes sample_weight:

class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    # show warning to remind user to explicitly set the value with
    # `.set_{method}_request(sample_weight={boolean})`
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, **fit_params):
        routed_params = process_routing(
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        check_metadata(self, sample_weight=sample_weight)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )
        )
        return router

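The above implementation is almost the same as MetaRegressor, and because of the default request value defined in __metadata_request__fit, a warning is raised when it is fitted.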

with warnings.catch_warnings(record=True) as record:
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
Received sample_weight of length = 100 in WeightedMetaRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

When an estimator consumes a metadata that it did not consume before, the following pattern can be used to warn users about it.

class ExampleRegressor(RegressorMixin, BaseEstimator):
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def predict(self, X):
        return np.zeros(shape=(len(X)))


with warnings.catch_warnings(record=True) as record:
    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
sample_weight is None in ExampleRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

Finally, we disable the configuration flag for metadata routing:

set_config(enable_metadata_routing=False)

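Third Party Development and scikit-learn Dependency

As seen above, information is communicated between classes using MetadataRequest and MetadataRouter. It is strongly discouraged, but possible, to vendor the tools related to metadata routing if you strictly want to have a scikit-learn compatible estimator without depending on the scikit-learn package. If all of the following conditions are met, you do not need to modify your code at all:

  • Your estimator inherits from BaseEstimator.

  • The parameters consumed by your estimator's methods, e.g. fit, are explicitly defined in the method's signature, as opposed to being passed through *args or **kwargs.

  • Your estimator does not route any metadata to the underlying objects, i.e. it is not a router.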
