DeepHit fits a neural network that estimates the probability mass function (PMF) of the event time in a discrete-time survival model. This is the single-event (non-competing-risks) implementation.
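A PMF over discrete time intervals maps onto a survival curve, and the minimal base-R sketch below shows how. It assumes made-up network outputs (logits) for one subject and four intervals and does not call pycox. The softmax-then-cumulative-sum construction is standard, but pycox's own prediction code may handle the final interval differently.

```r
# Minimal sketch (not pycox code): turning network outputs into a survival curve.
# The logits are hypothetical values for one subject and four time intervals.
softmax <- function(z) {
  e <- exp(z - max(z))  # subtract the max for numerical stability
  e / sum(e)
}

logits <- c(0.5, 1.2, 0.3, -0.4)
pmf  <- softmax(logits)  # P(T falls in interval j); non-negative, sums to 1
cdf  <- cumsum(pmf)      # F(t_j) = P(T <= t_j)
surv <- 1 - cdf          # S(t_j) = P(T > t_j)

print(round(rbind(pmf = pmf, surv = surv), 3))
```

The PMF is non-negative and sums to one. In this sketch, the survival probabilities are therefore non-increasing and reach zero at the last interval.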
deephit(
  formula = NULL,
  data = NULL,
  reverse = FALSE,
  time_variable = "time",
  status_variable = "status",
  x = NULL,
  y = NULL,
  frac = 0,
  cuts = 10,
  cutpoints = NULL,
  scheme = c("equidistant", "quantiles"),
  cut_min = 0,
  activation = "relu",
  custom_net = NULL,
  num_nodes = c(32L, 32L),
  batch_norm = TRUE,
  dropout = NULL,
  device = NULL,
  mod_alpha = 0.2,
  sigma = 0.1,
  early_stopping = FALSE,
  best_weights = FALSE,
  min_delta = 0,
  patience = 10L,
  batch_size = 256L,
  epochs = 1L,
  verbose = FALSE,
  num_workers = 0L,
  shuffle = TRUE,
  ...
)
formula (formula(1)): Object specifying the model fit. The left-hand side of the formula should describe a survival::Surv() object.
data (data.frame(1)): Training data as a data.frame-like object. It is coerced internally with stats::model.matrix().
reverse (logical(1)): If TRUE, fits the estimator on the censoring distribution. Otherwise (default), it fits the survival distribution.
time_variable (character(1)): Alternative method to call the function. Name of the 'time' variable. Required if neither formula nor x and y are given.
status_variable (character(1)): Alternative method to call the function. Name of the 'status' variable. Required if neither formula nor x and y are given.
x (data.frame(1)): Alternative method to call the function. Required if formula, time_variable and status_variable are not given. Data frame-like object of features, coerced internally with model.matrix().
y ([survival::Surv()]): Alternative method to call the function. Required if formula, time_variable and status_variable are not given. Survival outcome of right-censored observations.
frac (numeric(1)): Fraction of the data to use as a validation dataset. The default, 0, means no separate validation dataset is used.
cuts (integer(1)): Number of cut points used to discretise the time axis.
cutpoints (numeric()): Alternative to cuts that provides the exact cut points for discretisation. If cutpoints is non-NULL, cuts is ignored.
scheme (character(1)): Method of discretisation, either "equidistant" (default) or "quantiles". See reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
cut_min (integer(1)): Starting duration for discretisation. See reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
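The base-R sketch below illustrates the discretisation controlled by cuts, scheme and cut_min. The helper make_cuts() is hypothetical and only imitates the two schemes. The equidistant branch spaces cut points evenly from cut_min to the largest observed time. The quantile branch simply takes empirical quantiles of the observed times, whereas pycox's own quantile scheme may place cuts differently (for example, by taking censoring into account). Two further choices are illustrative assumptions: mapping times to intervals with findInterval(), and using one output node per cut point.

```r
# Hypothetical helper, NOT part of survivalmodels or pycox: a simplified
# imitation of the "equidistant" and "quantiles" discretisation schemes.
make_cuts <- function(time, cuts = 10L, scheme = c("equidistant", "quantiles"),
                      cut_min = 0) {
  scheme <- match.arg(scheme)
  if (scheme == "equidistant") {
    # Evenly spaced grid from cut_min up to the largest observed time.
    seq(cut_min, max(time), length.out = cuts)
  } else {
    # Simplification: empirical quantiles of the observed times, ignoring
    # censoring and cut_min.
    unique(unname(quantile(time, probs = seq(0, 1, length.out = cuts))))
  }
}

time <- c(0.5, 1.2, 2.8, 3.1, 4.9, 7.5, 9.0, 10.0)
cp <- make_cuts(time, cuts = 5L)
print(cp)

# Assign each time to an interval index (one plausible convention; pycox's
# exact mapping of durations to indices may differ).
idx <- findInterval(time, cp, rightmost.closed = TRUE)
print(idx)

# Assumption for illustration: one network output node per cut point.
n_outputs <- length(cp)
```

With custom_net, the network's output size has to match the chosen discretisation. In this sketch, that means length(cp) outputs.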
activation (character(1)): See get_pycox_activation.
custom_net (torch.nn.modules.module.Module(1)): Optional custom network built with build_pytorch_net. If not supplied, the default architecture is used. When building a custom network, note that the number of output channels depends on cuts or cutpoints.
num_nodes, batch_norm, dropout (integer()/logical(1)/numeric(1)): See build_pytorch_net.
device (integer(1)|character(1)): Passed to pycox.models.DeepHitSingle. Specifies the device on which to compute models.
mod_alpha (numeric(1)): Weighting in (0,1) for combining the likelihood loss (L1) and the rank loss (L2). See the reference and reticulate::py_help(pycox$models$DeepHitSingle) for more detail.
sigma (numeric(1)): Parameter of the function eta used in the rank loss (L2) of the reference. See the reference and reticulate::py_help(pycox$models$DeepHitSingle) for more detail.
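The base-R sketch below computes a loss of the form alpha * L1 + (1 - alpha) * L2 on toy data, which shows what mod_alpha and sigma do. L1 is the negative log-likelihood of the PMF. L2 is the exponential rank penalty exp(-(F_i(T_i) - F_j(T_i)) / sigma) summed over comparable pairs, as described in the reference. In this sketch, mod_alpha is assumed to play the role of alpha. The function deephit_loss_sketch() and its arguments are hypothetical. pycox averages the rank term differently and also treats some tied pairs as comparable, so the numbers will not match it exactly.

```r
# Hypothetical sketch of the DeepHit single-event loss; the function name and
# arguments are illustrative and not part of survivalmodels or pycox.
deephit_loss_sketch <- function(pmf, idx, event, alpha = 0.2, sigma = 0.1) {
  # pmf:   n x m matrix (m >= 2); row i is subject i's PMF over m intervals
  # idx:   interval index of each subject's observed time
  # event: 1 = event observed, 0 = right-censored
  n <- nrow(pmf)
  cdf <- t(apply(pmf, 1, cumsum))  # cdf[i, k] = F_i(t_k)
  at_obs <- cbind(seq_len(n), idx)

  # Likelihood part (L1): events contribute the PMF at their interval,
  # censored subjects contribute the survival probability 1 - F_i(t_idx).
  lik <- ifelse(event == 1, pmf[at_obs], 1 - cdf[at_obs])
  nll <- -mean(log(pmax(lik, 1e-7)))

  # Rank part (L2): pair (i, j) is comparable if i had an event in an earlier
  # interval than j's observed time (ties are ignored in this sketch).
  comparable <- outer(seq_len(n), seq_len(n),
                      function(i, j) event[i] == 1 & idx[i] < idx[j])
  # cdf_gap[i, j] = F_i(T_i) - F_j(T_i); a small or negative gap means i was
  # not ranked as riskier than j, which exp(-gap / sigma) penalises.
  cdf_gap <- matrix(cdf[at_obs], n, n) - t(cdf[, idx, drop = FALSE])
  rank <- if (any(comparable)) mean(exp(-cdf_gap[comparable] / sigma)) else 0

  c(loss = alpha * nll + (1 - alpha) * rank, nll = nll, rank = rank)
}

# Toy example: three subjects, three intervals; subject 3 is censored.
pmf <- rbind(c(0.7, 0.2, 0.1), c(0.1, 0.6, 0.3), c(0.2, 0.3, 0.5))
idx <- c(1L, 2L, 2L)
event <- c(1, 1, 0)
out <- deephit_loss_sketch(pmf, idx, event, alpha = 0.2, sigma = 0.1)
print(out)
```

In this sketch, alpha = 1 keeps only the likelihood term and alpha = 0 keeps only the rank term. A smaller sigma makes the rank penalty sharper.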
early_stopping, best_weights, min_delta, patience (logical(1)/logical(1)/numeric(1)/integer(1)): See get_pycox_callbacks.
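The early-stopping arguments can be illustrated with a small, hypothetical base-R function that scans a sequence of validation losses. Training stops once the loss has failed to improve by more than min_delta for patience consecutive epochs. The epoch with the best loss is presumably the one whose weights best_weights = TRUE would restore. This sketch shows the general patience rule only. It is not the callback that get_pycox_callbacks builds, and that callback's exact comparison may differ.

```r
# Hypothetical sketch of patience-based early stopping (not pycox/torchtuples code).
early_stop_epoch <- function(val_loss, min_delta = 0, patience = 10L) {
  best <- Inf
  best_epoch <- NA_integer_
  wait <- 0L
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best - min_delta) {
      # Improvement larger than min_delta: record it and reset the counter.
      best <- val_loss[epoch]
      best_epoch <- epoch
      wait <- 0L
    } else {
      # No sufficient improvement: stop after `patience` such epochs in a row.
      wait <- wait + 1L
      if (wait >= patience) {
        return(list(stopped_at = epoch, best_epoch = best_epoch))
      }
    }
  }
  list(stopped_at = length(val_loss), best_epoch = best_epoch)
}

res <- early_stop_epoch(c(1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.9), patience = 3L)
str(res)
```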
batch_size (integer(1)): Passed to pycox.models.DeepHitSingle.fit. Number of elements in each batch.
epochs (integer(1)): Passed to pycox.models.DeepHitSingle.fit. Number of epochs.
verbose (logical(1)): Passed to pycox.models.DeepHitSingle.fit. Should information be displayed during fitting?
num_workers (integer(1)): Passed to pycox.models.DeepHitSingle.fit. Number of workers used in the dataloader.
shuffle (logical(1)): Passed to pycox.models.DeepHitSingle.fit. Should the order of the dataset be shuffled?
... (ANY): Passed to get_pycox_optim.
An object inheriting from classes deephit and survivalmodel.
Implemented from the pycox Python package via reticulate. Calls pycox.models.DeepHitSingle.
Changhee Lee, William R. Zame, Jinsung Yoon, and Mihaela van der Schaar. DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018. http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit
# \donttest{
if (requireNamespaces("reticulate") && reticulate::py_module_available("pycox")) {
  # all defaults
  deephit(data = simsurvdata(50))
  # common parameters
  deephit(data = simsurvdata(50), frac = 0.3, activation = "relu",
          num_nodes = c(4L, 8L, 4L, 2L), dropout = 0.1, early_stopping = TRUE,
          epochs = 100L, batch_size = 32L)
}
# }