开发scikit-learn估计器#
无论您是提议将估计器包含在scikit-learn中,还是开发与scikit-learn兼容的单独包,还是为自己的项目实现自定义组件,本章详细介绍了如何开发与scikit-learn管道和模型选择工具安全交互的对象。
本节详细介绍了您应该为scikit-learn兼容估计器使用和实现的公共API。在scikit-learn内部,我们实验并使用一些私人工具,我们的目标始终是在它们足够稳定后将其公开,以便您也可以在自己的项目中使用它们。
scikit-learn对象的API#
有两种主要类型的估计器。您可以将第一组视为简单的估计器,它由大多数估计器组成,例如 LogisticRegression
或 RandomForestClassifier
.第二类是元估计量,它们是包裹其他估计量的估计量。 Pipeline
和 GridSearchCV
是元估计量的两个示例。
在这里,我们从一些词汇表术语开始,然后说明如何实现自己的估计器。
scikit-learn API的元素在 常用术语和API元素词汇表 .
不同对象#
scikit-learn中的主要对象是(一个类可以实现多个接口):
- 估计器:
基对象实现
fit
从数据中学习的方法,要么::estimator = estimator.fit(data, targets)
或者::
estimator = estimator.fit(data)
- 预测器:
对于监督学习或一些无监督问题,实现::
prediction = predictor.predict(data)
分类算法通常还提供一种量化预测确定性的方法,可以使用
decision_function
或predict_proba
probability = predictor.predict_proba(data)
- Transformer:
用于以监督或无监督的方式修改数据(例如,通过添加、更改或删除列,但不是通过添加或删除行)。实施::
new_data = transformer.transform(data)
当配合和变形一起执行比单独执行更有效时,请实现::
new_data = transformer.fit_transform(data)
- 模型:
一个可以提供 goodness of fit 测量或可能性看不见的数据,实现(越高越好)::
score = model.score(data)
估计#
API有一个主要对象:估计器。估计器是一个适合基于一些训练数据的模型的对象,并且能够推断新数据的一些属性。例如,它可以是分类器或回归器。所有估计器都实施匹配方法::
estimator.fit(X, y)
在估计器实现的所有方法中, fit
通常是您想要自己实现的。其它方法如 set_params
, get_params
等实施于 BaseEstimator
你应该继承的。您可能需要从更多的mixin中继承,我们将在后面解释。
实例化#
这涉及对象的创建。对象的 __init__
方法可能会接受常数作为确定估计器行为的参数(例如 alpha
常数 SGDClassifier
).然而,它不应该将实际的训练数据作为论据,因为这是留给 fit()
方法:
clf2 = SGDClassifier(alpha=2.3)
clf3 = SGDClassifier([[1, 2], [2, 3]], [-1, 1]) # WRONG!
理想情况下,接受的论点 __init__
都应该是具有默认值的关键字参数。换句话说,用户应该能够实例化估计器,而无需向其传递任何参数。在某些情况下,如果参数没有合理的默认值,则可以将其保留为没有默认值。在scikit-learn本身中,只有在某些元估计器中,次估计器参数是必需的参数。
大多数参数对应于描述模型或估计器试图解决的优化问题的超参数。其他参数可能定义估计器的行为方式,例如定义存储一些数据的缓存位置。这些初始参数(或参数)总是被估计器记住。另请注意,它们不应记录在“属性”部分下,而应记录在该估计器的“参数”部分下。
In addition, every keyword argument accepted by __init__
should
correspond to an attribute on the instance. Scikit-learn relies on this to
find the relevant attributes to set on an estimator when doing model selection.
总而言之, __init__
应该看起来像::
def __init__(self, param1=1, param2=2):
self.param1 = param1
self.param2 = param2
应该没有逻辑,甚至没有输入验证,参数不应该被更改;这也意味着理想情况下它们不应该是可变对象,如列表或字典。如果它们是可变的,则应在修改之前复制它们。相应的逻辑应该放在使用参数的地方,通常在 fit
.以下是错误的::
def __init__(self, param1=1, param2=2, param3=3):
# WRONG: parameters should not be modified
if param1 > 1:
param2 += 1
self.param1 = param1
# WRONG: the object's attributes should have exactly the name of
# the argument in the constructor
self.param3 = param2
推迟验证的原因是,如果 __init__
包括输入验证,则必须在 set_params
,用于以下算法 GridSearchCV
.
Also it is expected that parameters with trailing _
are not to be set
inside the __init__
method. More details on attributes that are not init
arguments come shortly.
拟合#
接下来你可能要做的是估计模型中的一些参数。这是在 fit()
方法,这就是训练发生的地方。例如,这是您学习或估计线性模型系数的计算。
的 fit()
该方法将训练数据作为参数,在无监督学习的情况下可以是一个数组,在监督学习的情况下可以是两个数组。训练数据附带的其他元数据,例如 sample_weight
,也可以传递给 fit
作为关键字参数。
请注意,该模型是用以下方式进行的 X
和 y
,但该对象没有引用 X
和 y
.然而,也有一些例外,例如在预计算内核的情况下,必须存储这些数据以供预测方法使用。
参数 |
|
---|---|
X |
形状类似阵列(n_samples,n_features) |
y |
形状类似阵列(n_samples,) |
kwargs |
可选数据相关参数 |
样本数量,即 X.shape[0]
宜同 y.shape[0]
.如果不满足此要求,则出现类型的异常 ValueError
应该提出。
y
在无监督学习的情况下可能会被忽视。然而,为了能够将估计器用作可以混合监督和无监督变换器的管道的一部分,即使是无监督估计器也需要接受 y=None
第二个位置的关键字参数被估计器忽略。出于同样的原因, fit_predict
, fit_transform
, score
和 partial_fit
方法需要接受 y
如果它们得到实施,第二个论点。
该方法应该返回对象 (self
).此模式对于能够在IPython会话中实现快速一行代码非常有用,例如::
y_predicted = SGDClassifier(alpha=10).fit(X_train, y_train).predict(X_test)
根据算法的性质, fit
有时也可以接受额外的关键字参数。但是,任何可以在访问数据之前赋值的参数都应该是 __init__
关键字参数。理想情况下, fit parameters should be restricted to directly data dependent variables .例如,根据数据矩阵预先计算的Gram矩阵或亲和矩阵 X
取决于数据。容忍停止标准 tol
不直接取决于数据(尽管根据某个评分函数的最佳值可能是)。
当 fit
被称为,之前的任何呼叫 fit
应该被忽视。一般来说,打电话 estimator.fit(X1)
然后 estimator.fit(X2)
应该和只打电话一样 estimator.fit(X2)
.然而,在实践中,当 fit
取决于一些随机过程,请参阅 random_state .此规则的另一个例外是当超参数 warm_start
设置为 True
对于支持它的估计者来说。 warm_start=True
意味着重新使用估计器的可训练参数的先前状态,而不是使用默认的初始化策略。
估计属性#
根据scikit-learn惯例,您想要作为公共属性向用户公开并且已从数据中估计或学习的属性必须始终具有以尾随强调结尾的名称,例如,某些回归估计器的系数将存储在 coef_
属性后 fit
已被呼叫。同样,您在此过程中学习并且希望存储但不向用户公开的属性应该有一个前面的强调线,例如 _intermediate_coefs
.您需要将第一组(后面有一个星号)记录为“属性”,而无需将第二组(后面有一个星号)记录为“属性”。
当您调用时,预计将覆盖估计的属性 fit
第二次。
通用属性#
期望表格输入的估计器应该设置 n_features_in_
attribute at fit
time to indicate the number of features that the estimator expects for subsequent calls to predict or transform. See SLEP010 有关详细信息
类似地,如果估计者被赋予大熊猫或极地等参数,他们应该设置 feature_names_in_
attribute to indicate the features names of the input data, detailed in SLEP007 .使用 validate_data
会自动为您设置这些属性。
滚动您自己的估计器#
如果您想实现与scikit-learn兼容的新估计器,除了上面概述的scikit-learn API之外,您还应该注意scikit-learn的几个内部功能。您可以通过运行以下命令来检查您的估计器是否遵守scikit-learn界面和标准 check_estimator
在一个例子上。的 parametrize_with_checks
还可以使用pytest装饰器(请参阅其文档字符串了解详细信息和可能的交互 pytest
):
>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.tree import DecisionTreeClassifier
>>> check_estimator(DecisionTreeClassifier()) # passes
[...]
制作与scikit-learn估计器界面兼容的类的主要动机可能是您想将其与模型评估和选择工具(例如 GridSearchCV
和 Pipeline
.
在详细说明下面所需的界面之前,我们描述了两种更轻松实现正确界面的方法。
您可以检查上述估计器是否通过了所有常见检查::
>>> from sklearn.utils.estimator_checks import check_estimator
>>> check_estimator(TemplateClassifier()) # passes
get_params和set_params#
所有scikit-learn估计器都具有 get_params
和 set_params
功能协调发展的
的 get_params
函数不接受任何参数,并返回 __init__
估计器的参数及其值。
它需要一个关键词参数, deep
,它接收布尔值,该值确定该方法是否应该返回子估计量的参数(仅与元估计量相关)。的默认值 deep
是 True
.例如,考虑以下估计器::
>>> from sklearn.base import BaseEstimator
>>> from sklearn.linear_model import LogisticRegression
>>> class MyEstimator(BaseEstimator):
... def __init__(self, subestimator=None, my_extra_param="random"):
... self.subestimator = subestimator
... self.my_extra_param = my_extra_param
参数 deep
控制是否 subestimator
应该报告。因此当 deep=True
,输出将是::
>>> my_estimator = MyEstimator(subestimator=LogisticRegression())
>>> for param, value in my_estimator.get_params(deep=True).items():
... print(f"{param} -> {value}")
my_extra_param -> random
subestimator__C -> 1.0
subestimator__class_weight -> None
subestimator__dual -> False
subestimator__fit_intercept -> True
subestimator__intercept_scaling -> 1
subestimator__l1_ratio -> None
subestimator__max_iter -> 100
subestimator__multi_class -> deprecated
subestimator__n_jobs -> None
subestimator__penalty -> l2
subestimator__random_state -> None
subestimator__solver -> lbfgs
subestimator__tol -> 0.0001
subestimator__verbose -> 0
subestimator__warm_start -> False
subestimator -> LogisticRegression()
如果元估计器采用多个子估计器,那么这些子估计器通常都有名称(例如,在 Pipeline
对象),在这种情况下,密钥应该成为 <name>__C
, <name>__class_weight
等。
当 deep=False
,输出将是::
>>> for param, value in my_estimator.get_params(deep=False).items():
... print(f"{param} -> {value}")
my_extra_param -> random
subestimator -> LogisticRegression()
另一方面, set_params
采用的参数 __init__
作为关键字参数,将它们解压为形式的dict 'parameter': value
并使用此指令设置估计器的参数。它返回估计器本身。
的 set_params
例如,函数用于在网格搜索期间设置参数。
克隆#
As already mentioned that when constructor arguments are mutable, they should be
copied before modifying them. This also applies to constructor arguments which are
estimators. That's why meta-estimators such as GridSearchCV
create a copy of the given estimator before modifying it.
然而,在scikit-learn中,当我们复制估计器时,我们会得到一个不适合的估计器,其中仅复制构造器参数(有一些例外,例如与某些内部机器相关的属性,例如元数据路由)。
负责此行为的功能是 clone
.
估计者可以自定义的行为 base.clone
通过重写 base.BaseEstimator.__sklearn_clone__
法 __sklearn_clone__
must return an instance of the estimator. _ _sklearn_clone__'当估计器需要保持某个状态时, :func:`base.clone
被称为估计器。例如, FrozenEstimator
利用这个。
估计类型#
在简单估计器(与元估计器相反)中,最常见的类型是转换器、分类器、回归器和集群算法。
Transformers 继承自 TransformerMixin
,并实施 transform
法这些是估计器,它们接受输入并以某种方式对其进行转换。请注意,它们永远不应更改输入样本的数量以及 transform
应该以相同的给定顺序与其输入样本相对应。
Regressors 继承自 RegressorMixin
,并实施 predict
法他们应该接受数字 y
在他们 fit
法回归者使用 r2_score
默认情况下, score
法
Classifiers 继承自 ClassifierMixin
.如果适用,分类器可以实现 decision_function
返回原始决策值,基于该值 predict
可以做出决定。如果支持计算概率,分类器还可以实现 predict_proba
和 predict_log_proba
.
Classifiers should accept y
(target) arguments to fit
that are sequences (lists,
arrays) of either strings or integers. They should not assume that the class labels are
a contiguous range of integers; instead, they should store a list of classes in a
classes_
attribute or property. The order of class labels in this attribute should
match the order in which predict_proba
, predict_log_proba
and
decision_function
return their values. The easiest way to achieve this is to put:
self.classes_, y = np.unique(y, return_inverse=True)
在 fit
. 这将返回一个新的 y
它包含范围[0, n_classes
).
分类器的 predict
方法应返回包含来自的类标签的数组 classes_
.在实现 decision_function
,这可以通过::
def predict(self, X):
D = self.decision_function(X)
return self.classes_[np.argmax(D, axis=1)]
的 multiclass
模块包含用于处理多类和多标签问题的有用函数。
Clustering algorithms inherit from ClusterMixin
. Ideally, they should
accept a y
parameter in their fit
method, but it should be ignored. Clustering
algorithms should set a labels_
attribute, storing the labels assigned to each
sample. If applicable, they can also implement a predict
method, returning the
labels assigned to newly given samples.
如果需要检查给定估计器的类型,例如在元估计器中,则可以检查给定对象是否实现 transform
转换器的方法,否则使用助手函数,例如 is_classifier
或 is_regressor
.
开发人员API set_output
#
与 SLEP018 ,scikit-learn介绍了 set_output
用于配置transformer以输出pandas DataFrames的API。的 set_output
如果Transformer定义了API,则自动定义 get_feature_names_out 和子类 base.TransformerMixin
. get_feature_names_out 用于获取pandas输出的列名。
base.OneToOneFeatureMixin
和 base.ClassNamePrefixFeaturesOutMixin
对于定义有用的混合 get_feature_names_out . base.OneToOneFeatureMixin
当Transformer在输入要素和输出要素之间具有一一对应关系时, StandardScaler
. base.ClassNamePrefixFeaturesOutMixin
当Transformer需要生成自己的要素名称时(例如) PCA
.
您可以选择退出 set_output
API通过设置 auto_wrap_output_keys=None
定义自定义子类别时::
class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
return X
def get_feature_names_out(self, input_features=None):
...
的默认值 auto_wrap_output_keys
is ("transform",)
, which automatically wraps fit_transform
and transform
. The TransformerMixin
uses the _ _init_subclass__'消费机制 `auto_wrap_output_keys
and pass all other keyword arguments to its super class. Super classes' _ _init_subclass__'应该 **not** 取决于 `auto_wrap_output_keys
.
对于返回多个阵列的变压器 transform
,自动换行将只换行第一个数组,而不会改变其他数组。
看到 介绍 set_output API 了解如何使用API的示例。
开发人员API check_is_fitted
#
默认情况下 check_is_fitted
检查实例中是否有任何带有尾部强调的属性,例如 coef_
. An estimator can change the behavior by implementing a _ _sklearn_is_fitted__'方法不接受输入并返回布尔值。如果存在这种方法, :func:`~sklearn.utils.validation.check_is_fitted
简单地返回其输出。
看到 __sklearn_is_fitted__ 作为开发人员API 了解如何使用API的示例。
用于HTML表示的开发人员API#
警告
HTML表示API是实验性的,API可能会发生变化。
继承自 BaseEstimator
在交互式编程环境(例如Deliveryter笔记本)中显示其自身的HTML表示。例如,我们可以显示这个HTML图表::
from sklearn.base import BaseEstimator
BaseEstimator()
原始HTML表示是通过调用该函数获得的 estimator_html_repr
在估计器实例上。
要自定义链接到评估人员文档的URL(例如,当单击“?“图标),重写 _doc_link_module
and _ Doc_link_templatform '属性。此外,您还可以提供 `_doc_link_url_param_generator
method. Set _ doc_link_module`指向包含估算器的(顶级)模块的名称。如果该值与顶级模块名称不匹配,则HTML表示将不包含指向文档的链接。对于scikit-learn估计器,这被设置为 `"sklearn"
.
的 _doc_link_template
is used to construct the final URL. By default, it can contain two variables: estimator_module
(the full name of the module containing the estimator) and estimator_name
(the class name of the estimator). If you need more variables you should implement the _ Doc_link_url_param_generator '方法,该方法应该返回变量及其值的字典。此字典将用于呈现 `_doc_link_template
.
编码指南#
以下是有关如何编写新代码以包含在scikit-learn中的一些指南,这些指南可能适合在外部项目中采用。当然,也有特殊情况,这些规则也会有例外。然而,在提交新代码时遵循这些规则可以使审查更容易,以便可以在更短的时间内集成新代码。
格式统一的代码使共享代码所有权变得更容易。scikit-learn项目试图严格遵循中详细介绍的官方Python指南 PEP8 详细介绍了代码应该如何格式化和缩进。请阅读并遵循它。
此外,我们还添加了以下准则:
使用强调线分隔非Class名称中的单词:
n_samples
而不是nsamples
.Avoid multiple statements on one line. Prefer a line return after a control flow statement (
if
/for
).对于scikit-learn内部的引用使用相对导入。
单元测试是前一条规则的例外;它们应该使用绝对导入,就像客户端代码一样。必然结果是,如果
sklearn.foo
输出在中实现的类或函数sklearn.foo.bar.baz
,测试应该从sklearn.foo
.Please don't use
import *
in any case. It is considered harmful by the official Python recommendations. It makes the code harder to read as the origin of symbols is no longer explicitly referenced, but most important, it prevents using a static analysis tool like pyflakes to automatically find bugs in scikit-learn.使用 numpy docstring standard 在你所有的文档中。
我们喜欢的代码的一个很好的例子可以在 here .
输入验证#
模块 sklearn.utils
包含用于进行输入验证和转换的各种功能。有时, np.asarray
足以验证;做 not 使用 np.asanyarray
或 np.atleast_2d
,因为那些让NumPy的 np.matrix
通过,它具有不同的API(例如, *
意味着点积 np.matrix
,但阿达玛产品在 np.ndarray
).
其他情况请务必致电 check_array
任何传递给scikit-learn API函数的类似数组的参数。要使用的确切参数主要取决于 scipy.sparse
必须接受矩阵。
有关更多信息,请参阅 开发人员的实用程序 页.
随机数#
如果您的代码依赖于随机数生成器,请不要使用 numpy.random.random()
或类似的例行公事。 为了确保错误检查的可重复性,例程应该接受关键字 random_state
并利用它来构建一个 numpy.random.RandomState
object.看到 sklearn.utils.check_random_state
在 开发人员的实用程序 .
以下是使用上述一些准则的简单代码示例::
from sklearn.utils import check_array, check_random_state
def choose_random_sample(X, random_state=0):
"""Choose a random point from X.
Parameters
----------
X : array-like of shape (n_samples, n_features)
An array representing the data.
random_state : int or RandomState instance, default=0
The seed of the pseudo random number generator that selects a
random sample. Pass an int for reproducible output across multiple
function calls.
See :term:`Glossary <random_state>`.
Returns
-------
x : ndarray of shape (n_features,)
A random point selected from X.
"""
X = check_array(X)
random_state = check_random_state(random_state)
i = random_state.randint(X.shape[0])
return X[i]
如果在估计器中使用随机性而不是独立函数,则适用一些额外的指导方针。
首先,估计者应该采取 random_state
其论点 __init__
默认值为 None
.它应该存储该参数的值, unmodified ,在属性中 random_state
. fit
可以叫 check_random_state
以获得实际的随机数生成器。如果出于某种原因,需要随机性 fit
,RNG应存储在属性中 random_state_
.下面的例子应该可以清楚地说明这一点:
class GaussianNoise(BaseEstimator, TransformerMixin):
"""This estimator ignores its input and returns random Gaussian noise.
It also does not adhere to all scikit-learn conventions,
but showcases how to handle randomness.
"""
def __init__(self, n_components=100, random_state=None):
self.random_state = random_state
self.n_components = n_components
# the arguments are ignored anyway, so we make them optional
def fit(self, X=None, y=None):
self.random_state_ = check_random_state(self.random_state)
def transform(self, X):
n_samples = X.shape[0]
return self.random_state_.randn(n_samples, self.n_components)
这种设置的原因是可重复性:当估计器是 fit
两次对相同的数据进行验证,两次都应该产生相同的模型,因此在 fit
,而不是 __init__
.
测试中的数字断言#
当断言连续值数组的准相等时,请使用 sklearn.utils._testing.assert_allclose
.
相对容差是从提供的数组dtypes中自动推断出来的(特别是对于float 32和float 64 dtypes),但您可以通过 rtol
.
When comparing arrays of zero-elements, please do provide a non-zero value for
the absolute tolerance via atol
.
有关详细信息,请参阅 sklearn.utils._testing.assert_allclose
.