censored

A tidymodels package for survival models - Hannah Frick

For censored data,

i.e., data with two aspects

  • “how long until?”
  • “has it happened yet?”

use survival analysis

to take both aspects into account!
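As a minimal sketch of what these two aspects look like in R, the snippet below pairs a time with an event indicator using `Surv()` from the survival package. The numbers are made up for illustration and are not taken from the talk's data.

```r
library(survival)

# Hypothetical follow-up times (in years) for five animals.
time  <- c(5, 8, 12, 3, 20)
# 1 = the event was observed, 0 = not yet observed at last follow-up (censored).
event <- c(1, 0, 1, 0, 1)

# "How long until?" and "has it happened yet?" combined in one response object.
y <- Surv(time, event)
print(y)  # censored observations are printed with a trailing "+"

# Kaplan-Meier estimate of the survival curve, using both aspects.
km <- survfit(y ~ 1)
summary(km)
```

Dropping the censored rows, or treating their times as event times, would discard or distort part of this information; a `Surv` response keeps both.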




We’re extending support for

survival analysis in tidymodels




Consistency in

  • how you specify and fit models
  • how you predict and what you get back




Consistency in

  • how you specify and fit survival models
  • how you predict and what you get back

Specify and fit!

parsnip model specification

library(tidymodels)

rand_forest() %>% 
  set_engine("ranger") %>% 
  set_mode("regression")
#> Random Forest Model Specification (regression)
#> 
#> Computational engine: ranger

Additions for survival models

  • New model type proportional_hazards()
  • New mode "censored regression"
  • New engines for
    • Parametric models
    • Semi-parametric models
    • Tree-based models
  • Formula interface for all models, including stratification
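To make the list above concrete, here is a hedged sketch of a parametric model fitted with the new mode. It assumes the `lung` data set shipped with the survival package as a stand-in (the cetaceans data is not bundled with any package), and the object name `weibull_fit` is made up for illustration.

```r
library(parsnip)
library(survival)
library(censored)

# Assumption: `lung` from the survival package stands in for the talk's data.
weibull_fit <- 
  survival_reg(dist = "weibull") %>%
  set_engine("survival") %>%
  set_mode("censored regression") %>%
  fit(Surv(time, status) ~ age + sex, data = lung)

weibull_fit
```

The specification has the same three steps as any other parsnip model (model type, engine, mode), and the response is given through the usual `Surv()` formula.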

Cetaceans data

Adapted from Tidy Tuesday 2018-38:

cetaceans
#> # A tibble: 1,313 × 6
#>      age event species    sex   transfers born_in_captivity
#>    <dbl> <dbl> <fct>      <fct>     <int> <lgl>            
#>  1    28     0 Bottlenose F             0 TRUE             
#>  2    44     0 Bottlenose F             0 TRUE             
#>  3    39     0 Bottlenose M            13 TRUE             
#>  4    38     0 Bottlenose F             1 TRUE             
#>  5    38     0 Bottlenose M             2 TRUE             
#>  6    37     0 Bottlenose F             2 TRUE             
#>  7    36     0 Bottlenose M             2 TRUE             
#>  8    36     0 Bottlenose F             2 TRUE             
#>  9    35     0 Bottlenose M             3 TRUE             
#> 10    34     0 Bottlenose F             4 TRUE             
#> # … with 1,303 more rows

Proportional hazards model

library(censored)

mod_survival_censored <- 
  proportional_hazards()

Proportional hazards model

library(censored)

mod_survival_censored <- 
  proportional_hazards() %>%
  set_engine("survival") 

Proportional hazards model

library(censored)

mod_survival_censored <- 
  proportional_hazards() %>%
  set_engine("survival") %>% 
  set_mode("censored regression")

Proportional hazards model

library(censored)

mod_survival_censored <- 
  proportional_hazards() %>%
  set_engine("survival") %>% 
  set_mode("censored regression") %>%
  fit(
    Surv(age, event) ~ species + sex + transfers,
    data = cetaceans
  )

Proportional hazards model

library(censored)

mod_survival_censored <- 
  proportional_hazards() %>%
  set_engine("survival") %>%
  set_mode("censored regression") %>%
  fit(
    Surv(age, event) ~ species + sex + transfers + 
      strata(born_in_captivity),
    data = cetaceans
  )

Proportional hazards model

library(survival)

mod_survival <- coxph(
  Surv(age, event) ~ species + sex + transfers + strata(born_in_captivity),
  data = cetaceans
)

Switching to penalized model

library(survival)

mod_survival <- coxph(
  Surv(age, event) ~ species + sex + 
    transfers + strata(born_in_captivity),
  data = cetaceans
)

# glmnet takes a predictor matrix and a stratified response instead of a formula
library(glmnet)

x <- model.matrix(~ species + sex + 
                    transfers, 
                  data = cetaceans)[, -1]
y <- Surv(cetaceans$age, cetaceans$event)
y_strata <- glmnet::stratifySurv(
  y, strata = cetaceans$born_in_captivity
)

mod_glmnet <- glmnet(x, y_strata, 
                     family = "cox", 
                     lambda = 0.1)

Switching via censored

mod_survival_censored <- 
  proportional_hazards() %>%
  set_engine("survival") %>%
  set_mode("censored regression") %>%
  fit(
    Surv(age, event) ~ species + sex + 
      transfers + 
      strata(born_in_captivity),
    data = cetaceans
  )

mod_glmnet_censored <- 
  proportional_hazards(penalty = 0.1) %>%
  set_engine("glmnet") %>%
  set_mode("censored regression") %>%
  fit(
    Surv(age, event) ~ species + sex + 
      transfers + 
      strata(born_in_captivity),
    data = cetaceans
  )

More models, same syntax

mod_mboost_censored <- 
  boost_tree() %>%
  set_engine("mboost") %>%
  set_mode("censored regression") %>%
  fit(
    Surv(age, event) ~ species + sex + transfers,
    data = cetaceans
  )

Available in censored

All for the mode "censored regression".

model                    engine
bag_tree()               rpart
boost_tree()             mboost
decision_tree()          rpart
decision_tree()          partykit
proportional_hazards()   survival
proportional_hazards()   glmnet
rand_forest()            partykit
survival_reg()           survival
survival_reg()           flexsurv

Predict!

tidymodels prediction guarantee

  • The predictions are always inside a tibble.
  • The column names and types are unsurprising and predictable.
  • The number of rows in new_data and the output are the same.
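The three guarantees can be checked directly. The sketch below is an illustration under assumptions: it uses the `lung` data set from the survival package as a stand-in, and the names `cox_fit`, `new_data`, and `pred` are made up here.

```r
library(parsnip)
library(survival)
library(censored)

# Assumption: `lung` from the survival package stands in for the talk's data.
cox_fit <- 
  proportional_hazards() %>%
  set_engine("survival") %>%
  set_mode("censored regression") %>%
  fit(Surv(time, status) ~ age + sex, data = lung)

new_data <- lung[1:5, ]
new_data$age[1] <- NA  # a row with a missing predictor

pred <- predict(cox_fit, new_data = new_data, type = "time")

# One check per guarantee
inherits(pred, "tbl_df")      # predictions come back in a tibble
names(pred)                   # a single, predictably named column
nrow(pred) == nrow(new_data)  # one output row per input row, missing or not
```

Because the row count is preserved, predictions can be bound back onto `new_data` column-wise without any bookkeeping for dropped rows.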

Survival time

# Take three rows and blank out one predictor to show how missing values are handled
cetaceans_pred <- cetaceans[1:3,]
cetaceans_pred$species[1] <- NA

predict(
  mod_survival_censored,
  new_data = cetaceans_pred,
  type = "time"
)
#> # A tibble: 3 × 1
#>   .pred_time
#>        <dbl>
#> 1       NA  
#> 2       32.2
#> 3       59.3

Survival probability

pred_survival_censored <- predict(
  mod_survival_censored,
  new_data = cetaceans_pred,
  type = "survival",
  time = c(10, 15, 20, 40)
)

pred_survival_censored
#> # A tibble: 3 × 1
#>   .pred           
#>   <list>          
#> 1 <tibble [4 × 2]>
#> 2 <tibble [4 × 2]>
#> 3 <tibble [4 × 2]>

Survival probability

pred_survival_censored$.pred[[1]]
#> # A tibble: 4 × 2
#>   .time .pred_survival
#>   <dbl>          <dbl>
#> 1    10             NA
#> 2    15             NA
#> 3    20             NA
#> 4    40             NA

Survival probability

pred_survival_censored$.pred[[2]]
#> # A tibble: 4 × 2
#>   .time .pred_survival
#>   <dbl>          <dbl>
#> 1    10          0.748
#> 2    15          0.660
#> 3    20          0.584
#> 4    40          0.385

Approximating the survival curve

predict(
  mod_mboost_censored, 
  new_data = cetaceans[2:3,], 
  type = "survival",
  time = 1:80
) %>% 
  mutate(id = factor(2:3)) %>% 
  tidyr::unnest(cols = .pred)
#> # A tibble: 160 × 3
#>    .time .pred_survival id   
#>    <int>          <dbl> <fct>
#>  1     1          1     2    
#>  2     2          0.981 2    
#>  3     3          0.957 2    
#>  4     4          0.932 2    
#>  5     5          0.901 2    
#>  6     6          0.869 2    
#>  7     7          0.843 2    
#>  8     8          0.802 2    
#>  9     9          0.776 2    
#> 10    10          0.752 2    
#> # … with 150 more rows

Approximating the survival curve

predict(
  mod_mboost_censored, 
  new_data = cetaceans[2:3,], 
  type = "survival",
  time = 1:80
) %>% 
  mutate(id = factor(2:3)) %>% 
  tidyr::unnest(cols = .pred) %>% 
  ggplot(
    aes(x = .time, y = .pred_survival,
        col = id)
  ) +
  geom_step() +
  theme_bw() +
  theme(legend.position = "top") 

Prediction types in censored

  • For all models: "time" and "survival"
  • Depending on the engine: "hazard", "quantile", "linear_pred"
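As a small sketch of switching prediction types, the snippet below requests "time" and then "linear_pred" from the same fitted model. It assumes the `lung` data set from the survival package as a stand-in, a parametric `survival_reg()` model with the "survival" engine, and made-up object names.

```r
library(parsnip)
library(survival)
library(censored)

# Assumption: `lung` from the survival package stands in for the talk's data.
aft_fit <- 
  survival_reg() %>%
  set_engine("survival") %>%
  set_mode("censored regression") %>%
  fit(Surv(time, status) ~ age + sex, data = lung)

# Available for all models: predicted survival time
pred_time <- predict(aft_fit, new_data = lung[1:3, ], type = "time")

# Engine-dependent: the linear predictor
pred_lp <- predict(aft_fit, new_data = lung[1:3, ], type = "linear_pred")

pred_time
pred_lp
```

Only the `type` argument changes between the two calls; each result is a tibble with one row per row of `new_data`.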




A consistent interface to

survival models