12.1. 数组API支持(实验性)#
的 Array API specification defines a standard API for all array manipulation libraries with a NumPy-like API. Scikit-learn vendors pinned copies of array-api-compat 和 array-api-extra .
Scikit-learn对阵列API标准的支持需要环境变量 SCIPY_ARRAY_API
设置为 1
在导入之前 scipy
和 scikit-learn
:
export SCIPY_ARRAY_API=1
请注意,此环境变量仅供临时使用。欲了解更多详细信息,请参阅SciPy的 Array API documentation .
一些scikit-learning估计器主要依赖NumPy(而不是使用Cython)来实现其算法逻辑 fit
, predict
或 transform
方法可以配置为接受任何Array API兼容的输入数据结构,并自动将操作分配到底层命名空间,而不是依赖NumPy。
现阶段,这种支持是 considered experimental 并且必须按照下面的解释显式启用。
备注
目前只有 array-api-strict
, cupy
,而且 PyTorch
已知与scikit-learn的估计器一起工作。
以下视频概述了该标准的设计原则以及它如何促进阵列库之间的互操作性:
Scikit-learn on GPUs with Array API 通过 Thomas Fan PyData NYC 2023
12.1.1. 示例使用#
以下是演示如何使用的示例代码片段 CuPy 运行 LinearDiscriminantAnalysis
在GPU上::
>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy
>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = cupy.asarray(X_np)
>>> y_cu = cupy.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>
>>> with config_context(array_api_dispatch=True):
... lda = LinearDiscriminantAnalysis()
... X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>
模型训练后,数组的匹配属性也将来自与训练数据相同的数组API命名空间。例如,如果CuPy的Array API命名空间用于训练,则匹配的属性将位于图形处理器上。我们提供一个实验性的 _estimator_with_converted_arrays
将估计器属性从Array API传输到ndArray的实用程序::
>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array : array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>
12.1.1.1. PyTorch支持#
PyTorch张量由设置支持 array_api_dispatch=True
并直接传递张量::
>>> import torch
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)
>>> with config_context(array_api_dispatch=True):
... lda = LinearDiscriminantAnalysis()
... X_trans = lda.fit_transform(X_torch, y_torch)
>>> type(X_trans)
<class 'torch.Tensor'>
>>> X_trans.device.type
'cuda'
12.1.2. Support for Array API
-compatible inputs#
scikit-learn中支持Array API兼容输入的估计器和其他工具。
12.1.2.1. 估计#
decomposition.PCA
(与svd_solver="full"
,svd_solver="randomized"
和power_iteration_normalizer="QR"
)linear_model.Ridge
(与solver="svd"
)discriminant_analysis.LinearDiscriminantAnalysis
(与solver="svd"
)
12.1.2.2. 元估计量#
元估计器接受数组API输入,其条件是基本估计器还可以:
12.1.2.3. Metrics#
sklearn.metrics.cluster.entropy
sklearn.metrics.mean_poisson_deviance
(需要 enabling array API support for SciPy )sklearn.metrics.pairwise.euclidean_distances
(见 关于设备支持的注释 float64 )
12.1.2.4. 工具#
预计覆盖范围将随着时间的推移而增长。请关注专门的 meta-issue on GitHub 来追踪进展
12.1.2.5. 返回值类型和匹配属性#
当调用具有Array API兼容输入的函数或方法时,惯例是返回与输入数据相同的数组容器类型和设备的数组值。
类似地,当估计器与Array API兼容的输入拟合时,拟合的属性将是来自与输入相同的库并存储在相同设备上的数组。的 predict
和 transform
方法随后期望来自与传递给 fit
法
不过请注意,返回纯量值的评分函数返回Python纯量(通常是 float
实例)而不是数组纯量值。
12.1.3. 常见估计器检查#
添加 array_api_support
标签添加到估计器的标签集,以指示其支持数组API。这将启用作为常见测试的一部分的专门检查,以验证在使用vanilla NumPy和Array API输入时估计器的结果是否相同。
要运行这些检查,您需要安装 array-api-strict 在您的测试环境中。这使您可以在没有图形处理器的情况下运行检查。要运行全套检查,您还需要安装 PyTorch , CuPy 并且有一个图形处理器。无法执行或缺少依赖项的检查将自动跳过。因此,使用运行测试非常重要 -v
标记以查看跳过哪些检查:
pip install array-api-strict # and other libraries as needed
pytest -k "array_api" -v
运行scikit-learn测试 array-api-strict
应该有助于揭示与通过使用模拟非中央处理设备处理多个设备输入相关的大多数代码问题。这允许快速迭代开发和调试数组API相关代码。
然而,为了确保完全处理分配在实际图形处理设备上的PyTorch或CuPy输入,有必要针对这些库和硬件运行测试。这可以通过使用 Google Colab 或利用我们的CI基础设施处理拉取请求(出于成本原因由维护人员手动触发)。
12.1.3.1. 有关MPS设备支持的注释#
在macOS上,PyTorch可以使用Metal Performance Shaders(MPS)来访问硬件加速器(例如M1或M2芯片的内部GPU组件)。然而,在撰写本文时,MPS设备对PyTorch的支持还不完整。请参阅以下github issue以了解更多详细信息:
要在PyTorch中启用MPS支持,请设置环境变量 PYTORCH_ENABLE_MPS_FALLBACK=1
在运行测试之前:
PYTORCH_ENABLE_MPS_FALLBACK=1 pytest -k "array_api" -v
在编写所有scikit-learn测试时应该通过,然而,计算速度不一定比使用中央处理设备更好。
12.1.3.2. 关于设备支持的注释 float64
#
scikit-learn中的某些操作将自动对浮点值执行操作 float64
防止溢出并确保正确性的精确性(例如, metrics.pairwise.euclidean_distances
).但是,阵列名称空间和设备的某些组合,例如 PyTorch on MPS
(见 有关MPS设备支持的注释 )不支持 float64
数据类型。在这些情况下,scikit-learn将恢复使用 float32
相反,数据类型。与不使用数组API调度或使用具有 float64
支持.