通常情况下,在第一次评估模型性能前,我们无法确定那个模型会最终跟测试集一起使用,仅仅通过数据集的划分是不足以可靠的评估模型的性能的。
重采样方法
交叉验证方法
交叉验证方法是一种常用的方法,它将数据集划分为 \(V\) 组个互斥的子集,称为折叠(fold)或折(run),每次迭代使用其中 \(V-1\) 组用于拟合模型,剩余一组用于测试模型的性能。
在resamle包中,交叉验证使用vfold_cv()函数实现。
set.seed(1001)
ames_folds <- vfold_cv(ames_train, v = 10)
ames_folds
# 10-fold cross-validation
# A tibble: 10 × 2
splits id
<list> <chr>
1 <split [2109/235]> Fold01
2 <split [2109/235]> Fold02
3 <split [2109/235]> Fold03
4 <split [2109/235]> Fold04
5 <split [2110/234]> Fold05
6 <split [2110/234]> Fold06
7 <split [2110/234]> Fold07
8 <split [2110/234]> Fold08
9 <split [2110/234]> Fold09
10 <split [2110/234]> Fold10
tidymodels中提供了高级接口以提取重抽样后的分析集、评估集的数据。后续的内容会会有相关介绍。
交叉验证有多重变体,常用的变体包括以下几类。
重复交叉验证
vfold_cv(ames_train, v = 10, repeats = 5)
# 10-fold cross-validation repeated 5 times
# A tibble: 50 × 3
splits id id2
<list> <chr> <chr>
1 <split [2109/235]> Repeat1 Fold01
2 <split [2109/235]> Repeat1 Fold02
3 <split [2109/235]> Repeat1 Fold03
4 <split [2109/235]> Repeat1 Fold04
5 <split [2110/234]> Repeat1 Fold05
6 <split [2110/234]> Repeat1 Fold06
7 <split [2110/234]> Repeat1 Fold07
8 <split [2110/234]> Repeat1 Fold08
9 <split [2110/234]> Repeat1 Fold09
10 <split [2110/234]> Repeat1 Fold10
# ℹ 40 more rows
蒙特卡洛交叉验证(MCCV)
mc_cv(ames_train, prop = 0.9, times = 20)
# Monte Carlo cross-validation (0.9/0.1) with 20 resamples
# A tibble: 20 × 2
splits id
<list> <chr>
1 <split [2109/235]> Resample01
2 <split [2109/235]> Resample02
3 <split [2109/235]> Resample03
4 <split [2109/235]> Resample04
5 <split [2109/235]> Resample05
6 <split [2109/235]> Resample06
7 <split [2109/235]> Resample07
8 <split [2109/235]> Resample08
9 <split [2109/235]> Resample09
10 <split [2109/235]> Resample10
11 <split [2109/235]> Resample11
12 <split [2109/235]> Resample12
13 <split [2109/235]> Resample13
14 <split [2109/235]> Resample14
15 <split [2109/235]> Resample15
16 <split [2109/235]> Resample16
17 <split [2109/235]> Resample17
18 <split [2109/235]> Resample18
19 <split [2109/235]> Resample19
20 <split [2109/235]> Resample20
- 当数据量较大(>10,000样本)且需评估模型稳定性时,优先使用蒙特卡洛交叉验证。
- 当数据量较小或需要严格保证每个样本都被测试时(如医疗诊断模型),优先使用普通交叉验证。
自助法
- 自助法通过有放回的随机抽样得到自助样本集,并且与训练集的样本数相同。
- 自助法中的评估集包含了所有没有被选入分析集的数据。
-
tidymodels包中使用bootstraps()函数实现自助法。
bootstraps(ames_train, times = 5) # 5次迭代的自助法重抽样
# Bootstrap sampling
# A tibble: 5 × 2
splits id
<list> <chr>
1 <split [2344/832]> Bootstrap1
2 <split [2344/897]> Bootstrap2
3 <split [2344/881]> Bootstrap3
4 <split [2344/862]> Bootstrap4
5 <split [2344/854]> Bootstrap5
滚动预测原点(rolling forecast origin resampling)
- 当数据具有明显的时间特征时,可以使用该重抽样的方法。
-
tidymodels中使用rolling_origin()函数实现滚动预测原点。
评估模型性能
通过fit_resamples()函数可以对重抽样的结果进行拟合。
-
fit_resamples()函数的第一个参数是一个parsnip模型对象或workflow/workflowset对象。
-
fit_resamples()函数没有data参数,取而代之的第二个参数是resamples参数,该参数接受bootstraps()、vfold_cv()、mc_cv()等函数的返回值。
-
metrics参数定义一组性能指标。默认情况下,回归模型使用 \(RMSE\) 和 \(R^2\),分类模型默认使用准确率和AUC。
-
control参数是一个由control_resamples()函数生成的控制对象,用于指定模型拟合的细节,主要包括
-
verbose:是否打印日志。
-
save_pred:是否保存评估集的预测结果。
-
save_workflow:是否保存工作流。
-
extract:是否保存重抽样过程中创建的模型。
set.seed(1003)
rf_model <- rand_forest(trees = 1000) %>%
set_engine("ranger") %>%
set_mode("regression")
rf_wflow <- workflow() %>%
add_formula(
Sale_Price ~
Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + Latitude + Longitude
) %>%
add_model(rf_model)
rf_fit <- rf_wflow %>%
fit(data = ames_train)
keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)
rf_res <- rf_wflow %>%
fit_resamples(
resamples = ames_folds,
control = keep_pred
)
rf_res
# Resampling results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [2109/235]> Fold01 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
2 <split [2109/235]> Fold02 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
3 <split [2109/235]> Fold03 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
4 <split [2109/235]> Fold04 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
5 <split [2110/234]> Fold05 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
6 <split [2110/234]> Fold06 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
7 <split [2110/234]> Fold07 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
8 <split [2110/234]> Fold08 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
9 <split [2110/234]> Fold09 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
10 <split [2110/234]> Fold10 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
# 提取评估集的性能指标
collect_metrics(rf_res)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 0.0710 10 0.00257 pre0_mod0_post0
2 rsq standard 0.837 10 0.00923 pre0_mod0_post0
# 提取评估集的预测结果
assess_res <- collect_predictions(rf_res)
assess_res
# A tibble: 2,344 × 5
.pred id Sale_Price .row .config
<dbl> <chr> <dbl> <int> <chr>
1 5.28 Fold01 5.13 10 pre0_mod0_post0
2 4.89 Fold01 4.84 27 pre0_mod0_post0
3 5.10 Fold01 5.11 47 pre0_mod0_post0
4 5.10 Fold01 5.13 52 pre0_mod0_post0
5 5.35 Fold01 5.40 59 pre0_mod0_post0
6 5.27 Fold01 5.27 63 pre0_mod0_post0
7 5.06 Fold01 5.07 65 pre0_mod0_post0
8 5.02 Fold01 5.03 66 pre0_mod0_post0
9 5.36 Fold01 5.44 67 pre0_mod0_post0
10 5.58 Fold01 5.67 68 pre0_mod0_post0
# ℹ 2,334 more rows
保存重抽样对象
TODO: 如何保存重抽样对象,以便后续使用?
总结
- 重抽样是一种评估模型性能的有效方法。
-
fit_resamples()函数可以对重抽样的结果进行拟合,并计算性能指标。此函数也会被用在模型调优等其他场景。
- 如果样本量较小,建议选择重复 10 折交叉验证;
- 如果样本量足够大,比如几万,几十万这种,随便选,都可以;
- 如果目的不是得到最好的模型表现,而是为了在不同模型间进行选择,建议使用
bootstrap。
在后续的章节中会用的代码如下:
library(tidymodels)
data(ames)
ames <- ames %>%
mutate(Sale_Price = log10(Sale_Price))
set.seed(502)
ames_split <- initial_split(ames, prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
ames_rec <-
recipe(
Sale_Price ~
Neighborhood +
Gr_Liv_Area +
Year_Built +
Bldg_Type +
Latitude +
Longitude,
data = ames_train
) %>%
step_log(Sale_Price, base = 10) %>%
step_other(Neighborhood, threshold = 0.01) %>%
step_dummy(all_nominal_predictors()) %>%
step_interact(~ Gr_Liv_Area:starts_with("Bldg_Type_")) %>%
step_ns(Latitude, Longitude, deg_free = 20)
lm_model <- linear_reg() %>%
set_engine("lm") %>%
set_mode("regression")
lm_wflow <- workflow() %>%
add_recipe(ames_rec) %>%
add_model(lm_model)
lm_fit <- lm_wflow %>%
fit(data = ames_train)
rf_model <- rand_forest(trees = 1000) %>%
set_engine("ranger") %>%
set_mode("regression")
rf_wflow <- workflow() %>%
add_formula(
Sale_Price ~
Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + Latitude + Longitude
) %>%
add_model(rf_model)
set.seed(1001)
ames_folds <- vfold_cv(ames_train, v = 10)
keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)
set.seed(1003)
rf_res <- rf_wflow %>%
fit_resamples(resamples = ames_folds, control = keep_pred)
rf_res
# Resampling results
# 10-fold cross-validation
# A tibble: 10 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [2107/235]> Fold01 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
2 <split [2107/235]> Fold02 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
3 <split [2108/234]> Fold03 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
4 <split [2108/234]> Fold04 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
5 <split [2108/234]> Fold05 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
6 <split [2108/234]> Fold06 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
7 <split [2108/234]> Fold07 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
8 <split [2108/234]> Fold08 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
9 <split [2108/234]> Fold09 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>
10 <split [2108/234]> Fold10 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>