5 - Tuning models
Survival analysis with tidymodels
Tuning parameters
Some model or preprocessing parameters cannot be estimated directly from the data.
Some examples:
- Tree depth in decision trees
- Number of neighbors in a K-nearest neighbor model
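In tidymodels, parameters like these are flagged for tuning with a tune() placeholder. A minimal sketch, assuming the parsnip argument names tree_depth and neighbors, with a classification mode chosen only for illustration (it is not the censored regression setting used later in these slides):

```r
library(tidymodels)

# Tree depth in a decision tree, marked for tuning
tree_spec <- decision_tree(tree_depth = tune()) %>%
  set_mode("classification")

# Number of neighbors in a K-nearest neighbor model, marked for tuning
knn_spec <- nearest_neighbor(neighbors = tune()) %>%
  set_mode("classification")

# Printing a specification shows which arguments are tagged with tune()
tree_spec
knn_spec
```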
Optimize tuning parameters
- Try different values and measure their performance.
- Find good values for these parameters.
- Once the value(s) of the parameter(s) are determined, a model can be finalized by fitting the model to the entire training set.
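The two main strategies for optimization are:
- Grid search, which tests a pre-defined set of candidate values
- Iterative search, which suggests new candidate values to evaluate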
Specifying tuning parameters
Let's take our previous random forest workflow and tag the minimum number of data points in each node for tuning:
rf_spec <- rand_forest(min_n = tune()) %>%
  set_engine("aorsf") %>%
  set_mode("censored regression")
rf_wflow <- workflow(event_time ~ ., rf_spec)
rf_wflow
#> ══ Workflow ══════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: rand_forest()
#>
#> ── Preprocessor ──────────────────────────────────────────────────────
#> event_time ~ .
#>
#> ── Model ─────────────────────────────────────────────────────────────
#> Random Forest Model Specification (censored regression)
#>
#> Main Arguments:
#> min_n = tune()
#>
#> Computational engine: aorsf
Try out multiple values
tune_grid() works similarly to fit_resamples() but covers multiple parameter values:
set.seed(22)
rf_res <- tune_grid(
  rf_wflow,
  cat_folds,
  eval_time = c(90, 30, 60, 120),
  grid = 5
)
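With grid = 5, tune_grid() generates five candidate values on its own. The candidates can also be supplied explicitly as a data frame with one column per tuned parameter. A sketch under assumptions: the values and range below are illustrative and not taken from the slides, and rf_wflow and cat_folds are the objects created earlier.

```r
library(tidymodels)

# Option 1: hand-picked candidate values (illustrative only)
min_n_grid <- tibble(min_n = c(5, 10, 20, 30, 40))

# Option 2: a regular grid built with dials over an assumed range
min_n_regular <- grid_regular(min_n(range = c(5L, 40L)), levels = 5)

# Either data frame can be passed as `grid`. This part only runs when the
# workflow and resamples from the earlier slides exist in the session.
if (exists("rf_wflow") && exists("cat_folds")) {
  set.seed(22)
  rf_res_manual <- tune_grid(
    rf_wflow,
    cat_folds,
    eval_time = c(90, 30, 60, 120),
    grid = min_n_grid
  )
}
```

An explicit grid is handy when you already know which region of the parameter space is worth exploring.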
Compare results
Inspecting results and selecting the best-performing hyperparameter(s):
show_best(rf_res)
#> # A tibble: 5 × 8
#>   min_n .metric        .estimator .eval_time  mean     n std_err .config
#>   <int> <chr>          <chr>           <dbl> <dbl> <int>   <dbl> <chr>
#> 1    33 brier_survival standard           90 0.197    10 0.00937 Preprocessor1_…
#> 2    31 brier_survival standard           90 0.198    10 0.00953 Preprocessor1_…
#> 3    13 brier_survival standard           90 0.198    10 0.00956 Preprocessor1_…
#> 4    21 brier_survival standard           90 0.199    10 0.00960 Preprocessor1_…
#> 5     6 brier_survival standard           90 0.199    10 0.00977 Preprocessor1_…
best_parameter <- select_best(rf_res)
best_parameter
#> # A tibble: 1 × 2
#>   min_n .config
#>   <int> <chr>
#> 1    33 Preprocessor1_Model1
collect_metrics() and autoplot() are also available.
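These helpers can be combined to look at a tuning result from several angles. A minimal sketch: summarize_tuning() is a hypothetical helper (not part of tidymodels), and the metric and eval_time arguments assume a recent version of tune that supports survival metrics. The show_best() output above reports results at evaluation time 90, the first value passed to eval_time, so that value is used as the default here.

```r
library(tidymodels)

# Hypothetical helper: gathers the usual views of a tuning result
# such as `rf_res`.
summarize_tuning <- function(res, metric = "brier_survival", eval_time = 90) {
  list(
    # All metrics, averaged over resamples, per candidate and evaluation time
    metrics = collect_metrics(res),
    # Top candidates for one metric at one evaluation time
    top = show_best(res, metric = metric, eval_time = eval_time, n = 3),
    # Plot of performance against the tuned parameter values
    plot = autoplot(res)
  )
}

# Usage, assuming `rf_res` from the tuning step:
# out <- summarize_tuning(rf_res)
# out$top
```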
The final fit
rf_wflow <- finalize_workflow(rf_wflow, best_parameter)
final_fit <- last_fit(rf_wflow, cat_split, eval_time = c(90, 30, 60, 120))
collect_metrics(final_fit)
#> # A tibble: 4 × 5
#>   .metric        .estimator .eval_time .estimate .config
#>   <chr>          <chr>           <dbl>     <dbl> <chr>
#> 1 brier_survival standard           30     0.221 Preprocessor1_Model1
#> 2 brier_survival standard           60     0.225 Preprocessor1_Model1
#> 3 brier_survival standard           90     0.160 Preprocessor1_Model1
#> 4 brier_survival standard          120     0.107 Preprocessor1_Model1
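The last_fit() result also contains the fitted workflow, which can be extracted and used for prediction. A sketch, assuming the final_fit and cat_split objects from above. predict_survival() is a hypothetical helper, and predicting from a censored regression model assumes the censored package is loaded in the session.

```r
library(tidymodels)

# Hypothetical helper: pulls the fitted workflow out of a last_fit() result
# and predicts survival probabilities at the requested evaluation times.
predict_survival <- function(final_fit, new_data,
                             eval_time = c(30, 60, 90, 120)) {
  fitted_wflow <- extract_workflow(final_fit)
  predict(
    fitted_wflow,
    new_data = new_data,
    type = "survival",
    eval_time = eval_time
  )
}

# Usage, assuming `final_fit` and `cat_split` from the slides:
# library(censored)
# predict_survival(final_fit, testing(cat_split))
```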
Your turn
Modify your model workflow to tune one or more parameters.
Use grid search to find the best parameter(s).