library("mcboost")
library("mlr3")
library("mlr3proba")
library("mlr3pipelines")
library("mlr3learners")
library("tidyverse")
# Fix the random number generator seed so that the randomized steps in this
# script (e.g. the holdout splits) give the same result on every run.
set.seed(27099)
To show the basic functionality of `MCBoostSurv`, we provide a minimal example on the standard survival data set `rats`. After loading and pre-processing the data, we train an `mlr3` learner on the training data. We instantiate an `MCBoostSurv` instance with the default parameters. Then, we run the `$multicalibrate()` method on our data to start multi-calibration in survival analysis. With `$predict_probs()`, we can get multicalibrated predictions.
# --- Task preparation -------------------------------------------------------
# Load the built-in "rats" survival task and one-hot encode its features.
task = tsk("rats")
prep_pipe = po("encode", param_vals = list(method = "one-hot"))
prep = prep_pipe$train(list(task))[[1]]

# --- Data split by row id ---------------------------------------------------
# train: rows 1-199, validation: rows 200-250, test: rows 256-300.
# NOTE(review): rows 251-255 end up in none of the three sets - confirm that
# this gap is intended.
train = prep$clone()$filter(1:199)
val = prep$clone()$filter(200:250)
test = prep$clone()$filter(256:300)

# --- Baseline survival model ------------------------------------------------
# The learner is trained in place on the training split; it serves as the
# initial predictor that gets multicalibrated below.
baseline = lrn("surv.ranger")
baseline$train(train)

# --- Multi-calibration on the validation split ------------------------------
mc_surv = MCBoostSurv$new(init_predictor = baseline)
val_features = val$data(cols = val$feature_names)
val_labels = val$data(cols = val$target_names)
mc_surv$multicalibrate(data = val_features, labels = val_labels)

# --- Multicalibrated predictions for the test split -------------------------
test_features = test$data(cols = test$feature_names)
mc_surv$predict_probs(test_features)
Internally, `mcboostsurv` runs the following procedure `max_iter` times (similar to `mcboost`, just for distributions over time):
1. Predict on X using the model from the previous iteration, `init_predictor` in the first iteration.
2. Compute the residuals `res = y - y_hat` for all time points.
3. Split predictions into `num_buckets` according to `y_hat` and time.
4. Fit the auditor (`auditor_fitter`) (here called `c(x)`) on the data in each bucket with target variable `r` (the residuals `res`).
5. Compute `misscal = mean(c(x) * res(x))`.
6. If `misscal > alpha`: for the bucket with the highest `misscal`, update the model using the prediction `c(x)`. Else: stop the procedure.

Based on this, we can now show multicalibration on a data set with two sensitive attributes (age and gender). Again, we load and pre-process the data.
library(survival)
# Prepare the pbc data: drop the identifier column, recode `status` into a
# 0/1 event indicator (1 only where the original status equals 2) and keep
# complete rows only.
data_pbc = pbc %>%
  select(-id) %>%
  mutate(status = as.numeric(status == 2)) %>%
  drop_na()

# Wrap the cleaned data in a survival task with `time` as the time column and
# `status` as the event column.
pbc_backend = as_data_backend(data_pbc)
task_pbc = TaskSurv$new(
  "pbc",
  backend = pbc_backend,
  time = "time",
  event = "status"
)
# Create the data split in two holdout steps.
# Step 1: 80% of the task for model development, 20% held out for testing.
train_test = rsmp("holdout", ratio = 0.8)
train_test$instantiate(task_pbc)
train_g = train_test$train_set(1)
test_ids = train_test$test_set(1)

# Step 2: split the development rows again, 75% for training the initial
# model and 25% for validation (used for multicalibration).
dev_task = task_pbc$clone()$filter(train_g)
train_val = rsmp("holdout", ratio = 0.75)
train_val$instantiate(dev_task)
train_ids = train_val$train_set(1)
val_ids = train_val$test_set(1)
# Build the distributional survival model step by step:
# 1. encoding pipeline followed by the xgboost survival learner,
# 2. wrapped in the "distrcompositor" pipeline,
# 3. converted back into a single learner object.
xgb_encoded = as_learner(prep_pipe %>>% lrn("surv.xgboost"))
xgb_distr = as_learner(ppl("distrcompositor", learner = xgb_encoded))

# Fit the model on the training rows only.
train_task = task_pbc$clone()$filter(train_ids)
xgb_distr$train(train_task)
# Initialize mcboost as a pipeline-based learner.
# Hyperparameters passed on to the multicalibration step.
# NOTE(review): alpha looks like the miscalibration threshold and eta the
# update step size - confirm against the MCBoostSurv documentation.
mc_params = list(
  alpha = 1e-6,
  eta = 0.2,
  time_buckets = 2,
  num_buckets = 1
)

# Initial predictor: the encoding step followed by the xgboost model.
init_learner = as_learner(prep_pipe %>>% xgb_distr)

# Full graph: encoding followed by the mcboostsurv pipeline, exposed as a
# learner so it can be trained and used for prediction like any other learner.
mc_graph = prep_pipe %>>% ppl_mcboostsurv(
  learner = init_learner,
  param_vals = mc_params
)
mcboost_learner = as_learner(mc_graph)
# Multicalibrate the model: training the mcboost learner on the validation
# rows runs the multicalibration step on data not used to fit the xgboost model.
mcboost_learner$train(task_pbc$clone()$filter(val_ids))

# Get new predictions on the held-out test rows.
# The test task is built once and reused for both models, so the
# multicalibrated and the plain xgboost predictions refer to exactly the same
# observations (previously `test_task` was created but never used, and the
# same clone/filter was repeated for every predict call).
test_task = task_pbc$clone()$filter(test_ids)
pred_pbc_mc = mcboost_learner$predict(test_task)
pred_pbc_xgb = xgb_distr$predict(test_task)