备注
Go to the end 下载完整的示例代码。或者通过浏览器中的MysterLite或Binder运行此示例
Gradient Boosting中的提前停止#
Gradient Boosting是一种集成技术,它结合了多个弱学习器(通常是决策树)来创建稳健且强大的预测模型。它以迭代的方式做到这一点,其中每个新阶段(树)都会纠正之前阶段的错误。
早期停止是Gradient Boosting中的一种技术,它使我们能够找到构建模型所需的最佳迭代次数,该模型可以很好地推广到未见数据并避免过度逼近。概念很简单:我们留出一部分数据集作为验证集(使用指定 validation_fraction
)来评估模型在训练期间的表现。由于模型是用额外的阶段(树)迭代构建的,因此它在验证集上的性能是作为步骤数的函数来监控的。
当模型在验证集上的性能达到稳定或稳定(在 tol
)在一定数量的连续阶段(指定为 n_iter_no_change
).这表明该模型已经达到了进一步迭代可能导致过度适应的地步,是时候停止训练了。
当应用提前停止时,可以使用 n_estimators_
属性总体而言,提前停止是在模型性能和梯度提升效率之间取得平衡的宝贵工具。
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
数据准备#
首先,我们加载并准备加州住房价格数据集进行训练和评估。它对数据集进行子集化,将其拆分为训练集和验证集。
import time
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
data = fetch_california_housing()
X, y = data.data[:600], data.target[:600]
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
模型训练与比较#
两 GradientBoostingRegressor
模型接受训练:一个有提前停止,另一个没有提前停止。目的是比较他们的表现。它还计算训练时间和 n_estimators_
被两个模型使用。
params = dict(n_estimators=1000, max_depth=5, learning_rate=0.1, random_state=42)
gbm_full = GradientBoostingRegressor(**params)
gbm_early_stopping = GradientBoostingRegressor(
**params,
validation_fraction=0.1,
n_iter_no_change=10,
)
start_time = time.time()
gbm_full.fit(X_train, y_train)
training_time_full = time.time() - start_time
n_estimators_full = gbm_full.n_estimators_
start_time = time.time()
gbm_early_stopping.fit(X_train, y_train)
training_time_early_stopping = time.time() - start_time
estimators_early_stopping = gbm_early_stopping.n_estimators_
误差计算#
代码计算 mean_squared_error
适用于上一节中训练的模型的训练和验证数据集。它计算每次助推迭代的误差。目的是评估模型的性能和收敛性。
train_errors_without = []
val_errors_without = []
train_errors_with = []
val_errors_with = []
for i, (train_pred, val_pred) in enumerate(
zip(
gbm_full.staged_predict(X_train),
gbm_full.staged_predict(X_val),
)
):
train_errors_without.append(mean_squared_error(y_train, train_pred))
val_errors_without.append(mean_squared_error(y_val, val_pred))
for i, (train_pred, val_pred) in enumerate(
zip(
gbm_early_stopping.staged_predict(X_train),
gbm_early_stopping.staged_predict(X_val),
)
):
train_errors_with.append(mean_squared_error(y_train, train_pred))
val_errors_with.append(mean_squared_error(y_val, val_pred))
可视化比较#
它包括三个子情节:
绘制两个模型在助推迭代上的训练错误。
绘制两个模型在助推迭代中的验证错误。
创建一个条形图来比较提前停止和不提前停止的模型的训练时间和估计量。
fig, axes = plt.subplots(ncols=3, figsize=(12, 4))
axes[0].plot(train_errors_without, label="gbm_full")
axes[0].plot(train_errors_with, label="gbm_early_stopping")
axes[0].set_xlabel("Boosting Iterations")
axes[0].set_ylabel("MSE (Training)")
axes[0].set_yscale("log")
axes[0].legend()
axes[0].set_title("Training Error")
axes[1].plot(val_errors_without, label="gbm_full")
axes[1].plot(val_errors_with, label="gbm_early_stopping")
axes[1].set_xlabel("Boosting Iterations")
axes[1].set_ylabel("MSE (Validation)")
axes[1].set_yscale("log")
axes[1].legend()
axes[1].set_title("Validation Error")
training_times = [training_time_full, training_time_early_stopping]
labels = ["gbm_full", "gbm_early_stopping"]
bars = axes[2].bar(labels, training_times)
axes[2].set_ylabel("Training Time (s)")
for bar, n_estimators in zip(bars, [n_estimators_full, estimators_early_stopping]):
height = bar.get_height()
axes[2].text(
bar.get_x() + bar.get_width() / 2,
height + 0.001,
f"Estimators: {n_estimators}",
ha="center",
va="bottom",
)
plt.tight_layout()
plt.show()

The difference in training error between the gbm_full
and the
gbm_early_stopping
stems from the fact that gbm_early_stopping
sets
aside validation_fraction
of the training data as internal validation set.
Early stopping is decided based on this internal validation score.
总结#
在我们的例子中, GradientBoostingRegressor
加州住房价格数据集的模型,我们已经证明了提前停止的实际好处:
Preventing Overfitting: 我们展示了验证误差如何在某个点后稳定或开始增加,这表明该模型可以更好地推广到未见数据。这是通过在过度匹配发生之前停止训练过程来实现的。
Improving Training Efficiency: 我们比较了提前停止和不提前停止的模型之间的训练时间。提前停止的模型实现了相当的准确性,同时需要的估计量明显减少,从而导致训练速度更快。
Total running time of the script: (0分2.884秒)
相关实例
Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>
_