5 - Tuning models

Survival analysis with tidymodels

Tuning parameters

Some model or preprocessing parameters cannot be estimated directly from the data.

Some examples:

  • Tree depth in decision trees
  • Number of neighbors in a K-nearest neighbor model

Optimize tuning parameters

  • Try different values and measure their performance.
  • Find good values for these parameters.
  • Once the value(s) of the parameter(s) are determined, a model can be finalized by fitting the model to the entire training set.
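
The loop described above can be sketched in a few lines of base R. This is only an illustration under stated assumptions: the simulated data, the hand-written knn_predict() helper, and the candidate values for k are all hypothetical and not part of the workshop material, which relies on tidymodels functions instead.

```r
# Hypothetical example: tune k for a 1-D k-nearest-neighbor regression
set.seed(1)
n <- 200
x <- runif(n, 0, 10)
y <- sin(x) + rnorm(n, sd = 0.3)

# Predict each new point as the mean outcome of its k nearest training points
knn_predict <- function(x_train, y_train, x_new, k) {
  vapply(x_new, function(x0) {
    nearest <- order(abs(x_train - x0))[seq_len(k)]
    mean(y_train[nearest])
  }, numeric(1))
}

# 1. Try different values: a small set of candidate k values
candidates <- c(1, 5, 15, 50)

# 2. Measure their performance: 5-fold cross-validated RMSE per candidate
folds <- sample(rep(1:5, length.out = n))
cv_rmse <- vapply(candidates, function(k) {
  fold_rmse <- vapply(1:5, function(f) {
    held_out <- folds == f
    pred <- knn_predict(x[!held_out], y[!held_out], x[held_out], k)
    sqrt(mean((y[held_out] - pred)^2))
  }, numeric(1))
  mean(fold_rmse)
}, numeric(1))

results <- data.frame(k = candidates, rmse = cv_rmse)
results

# 3. Pick a good value, then finalize using the entire training set
best_k <- results$k[which.min(results$rmse)]
final_predict <- function(x_new) knn_predict(x, y, x_new, best_k)
final_predict(c(2.5, 7.5))
```

The same three steps (candidates, resampled performance, final fit on all training data) are what tune_grid(), select_best(), and last_fit() automate in the slides that follow.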

Optimize tuning parameters

The two main strategies for optimization are:

  • Grid search 💠 which tests a pre-defined set of candidate values

  • Iterative search 🌀 which suggests/estimates new values of candidate parameters to evaluate
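
To make "a pre-defined set of candidate values" concrete, here is a minimal sketch of building such a set with the dials package. The range c(2L, 40L) and the values in manual_grid are assumptions chosen for illustration, not values taken from the workshop.

```r
library(dials)

# Regular grid: 5 evenly spaced candidates across an assumed range for min_n
regular_grid <- grid_regular(min_n(range = c(2L, 40L)), levels = 5)
regular_grid

# A grid can also be written by hand: any data frame whose column names
# match the tuning parameter names can serve as the candidate set
manual_grid <- data.frame(min_n = c(5L, 10L, 20L, 40L))
manual_grid
```

Either object could be passed to the grid argument of tune_grid() in place of an integer. Iterative search needs no such grid up front; in tidymodels it is provided by functions such as tune_bayes(), which are not shown here.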

Specifying tuning parameters

Let’s take our previous random forest workflow and tag the minimum number of data points in each node (min_n) for tuning:

rf_spec <- rand_forest(min_n = tune()) %>% 
  set_engine("aorsf") %>% 
  set_mode("censored regression")

rf_wflow <- workflow(event_time ~ ., rf_spec)
rf_wflow
#> ══ Workflow ══════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: rand_forest()
#> 
#> ── Preprocessor ──────────────────────────────────────────────────────
#> event_time ~ .
#> 
#> ── Model ─────────────────────────────────────────────────────────────
#> Random Forest Model Specification (censored regression)
#> 
#> Main Arguments:
#>   min_n = tune()
#> 
#> Computational engine: aorsf
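
To confirm which arguments have been tagged, the tuning parameters can be listed directly from the workflow. This is a minimal sketch, assuming the tidymodels and censored packages are installed; extract_parameter_set_dials() is a tidymodels helper that does not appear in the slides.

```r
library(tidymodels)
library(censored)

# Same specification as in the slide: min_n is tagged with tune()
rf_spec <- rand_forest(min_n = tune()) %>%
  set_engine("aorsf") %>%
  set_mode("censored regression")

rf_wflow <- workflow(event_time ~ ., rf_spec)

# List every parameter in the workflow that is marked for tuning
rf_params <- extract_parameter_set_dials(rf_wflow)
rf_params
```

Only min_n should be listed, since it is the only argument set to tune().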

Try out multiple values

tune_grid() works similarly to fit_resamples() but evaluates multiple candidate parameter values:

set.seed(22)
rf_res <- tune_grid(
  rf_wflow,
  cat_folds,
  eval_time = c(90, 30, 60, 120),
  grid = 5
)

Compare results

Inspect the results and select the best-performing value(s) of the tuning parameter(s):

show_best(rf_res)
#> # A tibble: 5 × 8
#>   min_n .metric        .estimator .eval_time  mean     n std_err .config        
#>   <int> <chr>          <chr>           <dbl> <dbl> <int>   <dbl> <chr>          
#> 1    33 brier_survival standard           90 0.197    10 0.00937 Preprocessor1_…
#> 2    31 brier_survival standard           90 0.198    10 0.00953 Preprocessor1_…
#> 3    13 brier_survival standard           90 0.198    10 0.00956 Preprocessor1_…
#> 4    21 brier_survival standard           90 0.199    10 0.00960 Preprocessor1_…
#> 5     6 brier_survival standard           90 0.199    10 0.00977 Preprocessor1_…

best_parameter <- select_best(rf_res)
best_parameter
#> # A tibble: 1 × 2
#>   min_n .config             
#>   <int> <chr>               
#> 1    33 Preprocessor1_Model1

collect_metrics() and autoplot() are also available.
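
The following is a self-contained sketch of tuning and then inspecting the results with collect_metrics(), show_best(), and autoplot(). It rests on several assumptions: survival::lung stands in for the workshop's cat data, the three candidate min_n values and the evaluation times are arbitrary, the aorsf package is installed, and the installed version of tune is recent enough for show_best() and autoplot() to accept an eval_time argument.

```r
library(tidymodels)
library(censored)

# Stand-in data (assumption): survival::lung, not the workshop's data set
lung_surv <- survival::lung %>%
  select(time, status, age, sex, ph.ecog) %>%
  drop_na() %>%
  mutate(event_time = Surv(time, status), .keep = "unused")

set.seed(123)
lung_folds <- vfold_cv(lung_surv, v = 3)

rf_spec <- rand_forest(min_n = tune()) %>%
  set_engine("aorsf") %>%
  set_mode("censored regression")
rf_wflow <- workflow(event_time ~ ., rf_spec)

# Explicit candidate values instead of grid = 5 (values are arbitrary)
min_n_grid <- tibble(min_n = c(5, 15, 30))

set.seed(22)
lung_res <- tune_grid(
  rf_wflow,
  lung_folds,
  grid = min_n_grid,
  metrics = metric_set(brier_survival),
  eval_time = c(90, 180, 365)
)

# One row per candidate value and evaluation time
metrics_tbl <- collect_metrics(lung_res)
metrics_tbl

# Rank candidates at one explicitly chosen evaluation time
show_best(lung_res, metric = "brier_survival", eval_time = 90)

# Plot performance against min_n at that evaluation time
autoplot(lung_res, eval_time = 90)
```

Stating metric and eval_time explicitly makes clear which time point drives the ranking when several evaluation times were computed.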

The final fit

rf_wflow <- finalize_workflow(rf_wflow, best_parameter)

final_fit <- last_fit(rf_wflow, cat_split, eval_time = c(90, 30, 60, 120)) 

collect_metrics(final_fit)
#> # A tibble: 4 × 5
#>   .metric        .estimator .eval_time .estimate .config             
#>   <chr>          <chr>           <dbl>     <dbl> <chr>               
#> 1 brier_survival standard           30     0.221 Preprocessor1_Model1
#> 2 brier_survival standard           60     0.225 Preprocessor1_Model1
#> 3 brier_survival standard           90     0.160 Preprocessor1_Model1
#> 4 brier_survival standard          120     0.107 Preprocessor1_Model1
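
To see what finalize_workflow() does on its own, the sketch below substitutes a chosen value for the tune() placeholder without fitting anything. The value 33 is the one select_best() returned earlier; building the one-row tibble by hand and inspecting the result with extract_spec_parsnip() are illustrative choices, not steps from the slides.

```r
library(tidymodels)
library(censored)

rf_spec <- rand_forest(min_n = tune()) %>%
  set_engine("aorsf") %>%
  set_mode("censored regression")
rf_wflow <- workflow(event_time ~ ., rf_spec)

# Hand-built stand-in for the output of select_best()
chosen <- tibble(min_n = 33)

# Replace the tune() placeholder with the chosen value
final_wflow <- finalize_workflow(rf_wflow, chosen)

# The model specification now carries a concrete min_n
extract_spec_parsnip(final_wflow)

# No parameters remain tagged for tuning
remaining <- extract_parameter_set_dials(final_wflow)
nrow(remaining)
```

The finalized workflow is an ordinary, unfitted workflow; last_fit() then fits it on the training set and evaluates it on the test set.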

Your turn

Modify your model workflow to tune one or more parameters.

Use grid search to find the best parameter(s).
