9  评估模型性能-重采样

通常情况下,在第一次评估模型性能前,我们无法确定那个模型会最终跟测试集一起使用,仅仅通过数据集的划分是不足以可靠的评估模型的性能的。

9.1 重采样方法

9.1.1 交叉验证方法

交叉验证方法是一种常用的方法,它将数据集划分为 \(V\) 组个互斥的子集,称为折叠(fold)或折(run),每次迭代使用其中 \(V-1\) 组用于拟合模型,剩余一组用于测试模型的性能。

resamle包中,交叉验证使用vfold_cv()函数实现。

set.seed(1001)
ames_folds <- vfold_cv(ames_train, v = 10)
ames_folds
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits             id    
   <list>             <chr> 
 1 <split [2109/235]> Fold01
 2 <split [2109/235]> Fold02
 3 <split [2109/235]> Fold03
 4 <split [2109/235]> Fold04
 5 <split [2110/234]> Fold05
 6 <split [2110/234]> Fold06
 7 <split [2110/234]> Fold07
 8 <split [2110/234]> Fold08
 9 <split [2110/234]> Fold09
10 <split [2110/234]> Fold10

tidymodels中提供了高级接口以提取重抽样后的分析集、评估集的数据。后续的内容会会有相关介绍。

交叉验证有多重变体,常用的变体包括以下几类。

9.1.1.1 重复交叉验证

  • 简单来说就是在V折交叉验证的基础上,在重复R次,即产生V折的过程被重复了R次。此时,最终重抽样性能的估计是基于 \(V \times R\) 个结果的来的。

  • resamle包中,重复交叉验证仅需要在vfold_cv()函数中增加repeats参数即可。

vfold_cv(ames_train, v = 10, repeats = 5)
#  10-fold cross-validation repeated 5 times 
# A tibble: 50 × 3
   splits             id      id2   
   <list>             <chr>   <chr> 
 1 <split [2109/235]> Repeat1 Fold01
 2 <split [2109/235]> Repeat1 Fold02
 3 <split [2109/235]> Repeat1 Fold03
 4 <split [2109/235]> Repeat1 Fold04
 5 <split [2110/234]> Repeat1 Fold05
 6 <split [2110/234]> Repeat1 Fold06
 7 <split [2110/234]> Repeat1 Fold07
 8 <split [2110/234]> Repeat1 Fold08
 9 <split [2110/234]> Repeat1 Fold09
10 <split [2110/234]> Repeat1 Fold10
# ℹ 40 more rows

9.1.1.2 蒙特卡洛交叉验证(MCCV)

  • 与普通交叉验证的区别在于,MCCV在分配分析集和评估集数据比例时可以自由指定,这有可能导致评估集之间不是互斥的(即同一个样本可能出现在不同的评估集中)。

  • resamle包中,MCCV使用mc_cv()函数实现。

mc_cv(ames_train, prop = 0.9, times = 20)
# Monte Carlo cross-validation (0.9/0.1) with 20 resamples 
# A tibble: 20 × 2
   splits             id        
   <list>             <chr>     
 1 <split [2109/235]> Resample01
 2 <split [2109/235]> Resample02
 3 <split [2109/235]> Resample03
 4 <split [2109/235]> Resample04
 5 <split [2109/235]> Resample05
 6 <split [2109/235]> Resample06
 7 <split [2109/235]> Resample07
 8 <split [2109/235]> Resample08
 9 <split [2109/235]> Resample09
10 <split [2109/235]> Resample10
11 <split [2109/235]> Resample11
12 <split [2109/235]> Resample12
13 <split [2109/235]> Resample13
14 <split [2109/235]> Resample14
15 <split [2109/235]> Resample15
16 <split [2109/235]> Resample16
17 <split [2109/235]> Resample17
18 <split [2109/235]> Resample18
19 <split [2109/235]> Resample19
20 <split [2109/235]> Resample20
Note
  • 当数据量较大(>10,000样本)且需评估模型稳定性时,优先使用蒙特卡洛交叉验证。
  • 当数据量较小或需要严格保证每个样本都被测试时(如医疗诊断模型),优先使用普通交叉验证。

9.1.2 自助法

  • 自助法通过有放回的随机抽样得到自助样本集,并且与训练集的样本数相同。
  • 自助法中的评估集包含了所有没有被选入分析集的数据。
  • tidymodels包中使用bootstraps()函数实现自助法。
bootstraps(ames_train, times = 5) # 5次迭代的自助法重抽样
# Bootstrap sampling 
# A tibble: 5 × 2
  splits             id        
  <list>             <chr>     
1 <split [2344/832]> Bootstrap1
2 <split [2344/897]> Bootstrap2
3 <split [2344/881]> Bootstrap3
4 <split [2344/862]> Bootstrap4
5 <split [2344/854]> Bootstrap5

9.1.3 滚动预测原点(rolling forecast origin resampling)

  1. 当数据具有明显的时间特征时,可以使用该重抽样的方法。
  2. tidymodels中使用rolling_origin()函数实现滚动预测原点。

9.2 评估模型性能

通过fit_resamples()函数可以对重抽样的结果进行拟合。

  • fit_resamples()函数的第一个参数是一个parsnip模型对象或workflow/workflowset对象。
  • fit_resamples()函数没有data参数,取而代之的第二个参数是resamples参数,该参数接受bootstraps()vfold_cv()mc_cv()等函数的返回值。
  • metrics参数定义一组性能指标。默认情况下,回归模型使用 \(RMSE\)\(R^2\),分类模型默认使用准确率AUC
  • control参数是一个由control_resamples()函数生成的控制对象,用于指定模型拟合的细节,主要包括
    • verbose:是否打印日志。
    • save_pred:是否保存评估集的预测结果。
    • save_workflow:是否保存工作流。
    • extract:是否保存重抽样过程中创建的模型。
set.seed(1003)
rf_model <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")

rf_wflow <- workflow() %>%
  add_formula(
    Sale_Price ~
      Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + Latitude + Longitude
  ) %>%
  add_model(rf_model)

rf_fit <- rf_wflow %>%
  fit(data = ames_train)

keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)
rf_res <- rf_wflow %>%
  fit_resamples(
    resamples = ames_folds,
    control = keep_pred
  )
rf_res
# Resampling results
# 10-fold cross-validation 
# A tibble: 10 × 5
   splits             id     .metrics         .notes           .predictions
   <list>             <chr>  <list>           <list>           <list>      
 1 <split [2109/235]> Fold01 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 2 <split [2109/235]> Fold02 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 3 <split [2109/235]> Fold03 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 4 <split [2109/235]> Fold04 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 5 <split [2110/234]> Fold05 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 6 <split [2110/234]> Fold06 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 7 <split [2110/234]> Fold07 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 8 <split [2110/234]> Fold08 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 9 <split [2110/234]> Fold09 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
10 <split [2110/234]> Fold10 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
# 提取评估集的性能指标
collect_metrics(rf_res)
# A tibble: 2 × 6
  .metric .estimator   mean     n std_err .config        
  <chr>   <chr>       <dbl> <int>   <dbl> <chr>          
1 rmse    standard   0.0710    10 0.00257 pre0_mod0_post0
2 rsq     standard   0.837     10 0.00923 pre0_mod0_post0
# 提取评估集的预测结果
assess_res <- collect_predictions(rf_res)
assess_res
# A tibble: 2,344 × 5
   .pred id     Sale_Price  .row .config        
   <dbl> <chr>       <dbl> <int> <chr>          
 1  5.28 Fold01       5.13    10 pre0_mod0_post0
 2  4.89 Fold01       4.84    27 pre0_mod0_post0
 3  5.10 Fold01       5.11    47 pre0_mod0_post0
 4  5.10 Fold01       5.13    52 pre0_mod0_post0
 5  5.35 Fold01       5.40    59 pre0_mod0_post0
 6  5.27 Fold01       5.27    63 pre0_mod0_post0
 7  5.06 Fold01       5.07    65 pre0_mod0_post0
 8  5.02 Fold01       5.03    66 pre0_mod0_post0
 9  5.36 Fold01       5.44    67 pre0_mod0_post0
10  5.58 Fold01       5.67    68 pre0_mod0_post0
# ℹ 2,334 more rows

9.3 保存重抽样对象

TODO: 如何保存重抽样对象,以便后续使用?

9.4 总结

  • 重抽样是一种评估模型性能的有效方法。
  • fit_resamples()函数可以对重抽样的结果进行拟合,并计算性能指标。此函数也会被用在模型调优等其他场景。
Note
  1. 如果样本量较小,建议选择重复 10 折交叉验证;
  2. 如果样本量足够大,比如几万,几十万这种,随便选,都可以;
  3. 如果目的不是得到最好的模型表现,而是为了在不同模型间进行选择,建议使用 bootstrap

在后续的章节中会用的代码如下:

library(tidymodels)
data(ames)
ames <- ames %>%
  mutate(Sale_Price = log10(Sale_Price))

set.seed(502)
ames_split <- initial_split(ames, prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

ames_rec <-
  recipe(
    Sale_Price ~
      Neighborhood +
        Gr_Liv_Area +
        Year_Built +
        Bldg_Type +
        Latitude +
        Longitude,
    data = ames_train
  ) %>%
  step_log(Sale_Price, base = 10) %>%
  step_other(Neighborhood, threshold = 0.01) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_interact(~ Gr_Liv_Area:starts_with("Bldg_Type_")) %>%
  step_ns(Latitude, Longitude, deg_free = 20)

lm_model <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

lm_wflow <- workflow() %>%
  add_recipe(ames_rec) %>%
  add_model(lm_model)

lm_fit <- lm_wflow %>%
  fit(data = ames_train)

rf_model <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")

rf_wflow <- workflow() %>%
  add_formula(
    Sale_Price ~
      Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + Latitude + Longitude
  ) %>%
  add_model(rf_model)

set.seed(1001)
ames_folds <- vfold_cv(ames_train, v = 10)

keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)

set.seed(1003)
rf_res <- rf_wflow %>%
  fit_resamples(resamples = ames_folds, control = keep_pred)

rf_res
# Resampling results
# 10-fold cross-validation 
# A tibble: 10 × 5
   splits             id     .metrics         .notes           .predictions
   <list>             <chr>  <list>           <list>           <list>      
 1 <split [2107/235]> Fold01 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 2 <split [2107/235]> Fold02 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 3 <split [2108/234]> Fold03 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 4 <split [2108/234]> Fold04 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 5 <split [2108/234]> Fold05 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 6 <split [2108/234]> Fold06 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 7 <split [2108/234]> Fold07 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 8 <split [2108/234]> Fold08 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
 9 <split [2108/234]> Fold09 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>    
10 <split [2108/234]> Fold10 <tibble [2 × 4]> <tibble [0 × 4]> <tibble>