Successive Halving Iterations

This example illustrates how a successive halving search (HalvingGridSearchCV and HalvingRandomSearchCV) iteratively chooses the best parameter combination out of multiple candidates.

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import randint

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingRandomSearchCV

We first define the parameter space and train a HalvingRandomSearchCV instance.

rng = np.random.RandomState(0)

X, y = datasets.make_classification(n_samples=400, n_features=12, random_state=rng)

clf = RandomForestClassifier(n_estimators=20, random_state=rng)

param_dist = {
    "max_depth": [3, None],
    "max_features": randint(1, 6),
    "min_samples_split": randint(2, 11),
    "bootstrap": [True, False],
    "criterion": ["gini", "entropy"],
}

rsh = HalvingRandomSearchCV(
    estimator=clf, param_distributions=param_dist, factor=2, random_state=rng
)
rsh.fit(X, y)
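Once fitted, the winning parameter combination can be read directly off the search object. A minimal sketch using the standard best_params_ and best_score_ attributes of fitted scikit-learn searches (the printed values depend on the run):

# Sketch: inspect the winner of the last iteration. ``best_params_`` and
# ``best_score_`` are standard attributes of fitted scikit-learn searches.
print("Best parameters:", rsh.best_params_)
print(f"Mean CV score of the best candidate: {rsh.best_score_:.3f}")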


We can now use the cv_results_ attribute of the search estimator to inspect and plot the evolution of the search.

# One row per (candidate, iteration) pair: pivot so that each column tracks a
# candidate's mean test score across the successive halving iterations.
results = pd.DataFrame(rsh.cv_results_)
results["params_str"] = results.params.apply(str)
results.drop_duplicates(subset=("params_str", "iter"), inplace=True)
mean_scores = results.pivot(
    index="iter", columns="params_str", values="mean_test_score"
)
ax = mean_scores.plot(legend=False, alpha=0.6)

# Label each x tick with the iteration number, the number of training samples
# used and the number of candidates still in the race.
labels = [
    f"iter={i}\nn_samples={rsh.n_resources_[i]}\nn_candidates={rsh.n_candidates_[i]}"
    for i in range(rsh.n_iterations_)
]

ax.set_xticks(range(rsh.n_iterations_))
ax.set_xticklabels(labels, rotation=45, multialignment="left")
ax.set_title("Scores of candidates over iterations")
ax.set_ylabel("mean test score", fontsize=15)
ax.set_xlabel("iterations", fontsize=15)
plt.tight_layout()
plt.show()
[Figure: Scores of candidates over iterations]

Number of candidates and amount of resources at each iteration

At the first iteration, a small amount of resources is used. The resource here is the number of samples that the estimators are trained on. All candidates are evaluated.

At the second iteration, only the best half of the candidates is evaluated. The amount of allocated resources is doubled: candidates are evaluated on twice as many samples.

This process is repeated until the last iteration, where only 2 candidates are left. The best candidate is the candidate that has the best score at the last iteration.
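The same schedule can be printed from the per-iteration attributes already used above to build the tick labels. A small sketch (the exact numbers depend on the run; with factor=2 the number of candidates roughly halves while the number of samples doubles from one iteration to the next):

# Sketch: print the successive-halving schedule. ``n_candidates_`` and
# ``n_resources_`` are the per-iteration attributes used for the plot above.
for i in range(rsh.n_iterations_):
    print(
        f"iter {i}: n_candidates={rsh.n_candidates_[i]}, "
        f"n_samples={rsh.n_resources_[i]}"
    )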

Total running time of the script: (0 minutes 3.912 seconds)

Related examples

Comparison between grid search and successive halving

Comparing randomized search and grid search for hyperparameter estimation

Release Highlights for scikit-learn 0.24

Custom refit strategy of a grid search with cross-validation
