Logistic-Hazard, also known as Nnet-Survival, fits a discrete-time neural network that predicts a discrete hazard function and is trained with a cross-entropy loss.

loghaz(
  formula = NULL,
  data = NULL,
  reverse = FALSE,
  time_variable = "time",
  status_variable = "status",
  x = NULL,
  y = NULL,
  frac = 0,
  cuts = 10,
  cutpoints = NULL,
  scheme = c("equidistant", "quantiles"),
  cut_min = 0,
  activation = "relu",
  custom_net = NULL,
  num_nodes = c(32L, 32L),
  batch_norm = TRUE,
  dropout = NULL,
  device = NULL,
  early_stopping = FALSE,
  best_weights = FALSE,
  min_delta = 0,
  patience = 10L,
  batch_size = 256L,
  epochs = 1L,
  verbose = FALSE,
  num_workers = 0L,
  shuffle = TRUE,
  ...
)

Arguments

formula

(formula(1))
Object specifying the model fit, left-hand-side of formula should describe a survival::Surv() object.

data

(data.frame(1))
Training data as a data.frame-like object, coerced internally with stats::model.matrix().

reverse

(logical(1))
If TRUE fits estimator on censoring distribution, otherwise (default) survival distribution.

time_variable

(character(1))
Alternative method to call the function. Name of the 'time' variable, required if formula or x and y are not given.

status_variable

(character(1))
Alternative method to call the function. Name of the 'status' variable, required if formula or x and y are not given.

x

(data.frame(1))
Alternative method to call the function. Required if formula, time_variable and status_variable not given. Data frame like object of features which is internally coerced with model.matrix.

y

([survival::Surv()])
Alternative method to call the function. Required if formula, time_variable and status_variable not given. Survival outcome of right-censored observations.
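
The three calling conventions described above (formula with data, data with time_variable and status_variable, or x with y) can be sketched as follows. This is an illustrative sketch only: the dataset is synthetic, the column names age and trt are made up for the example, and the fits run only when survivalmodels, reticulate and a Python environment with pycox are all available.

```r
library(survival)  # provides Surv()

# Small synthetic right-censored dataset (illustrative values only)
set.seed(1)
n <- 50
d <- data.frame(
  age    = rnorm(n, mean = 60, sd = 10),
  trt    = rbinom(n, 1, 0.5),
  time   = rexp(n, rate = 0.1),
  status = rbinom(n, 1, 0.7)
)
features <- d[, c("age", "trt")]
outcome  <- Surv(d$time, d$status)

# Fitting needs survivalmodels plus a Python environment with pycox
can_fit <- requireNamespace("survivalmodels", quietly = TRUE) &&
  requireNamespace("reticulate", quietly = TRUE) &&
  reticulate::py_module_available("pycox")

if (can_fit) {
  library(survivalmodels)
  # 1. formula and data
  fit1 <- loghaz(formula = Surv(time, status) ~ ., data = d)
  # 2. data with the names of the time and status columns
  fit2 <- loghaz(data = d, time_variable = "time", status_variable = "status")
  # 3. feature data frame and Surv outcome
  fit3 <- loghaz(x = features, y = outcome)
}
```

All three calls are intended to describe the same model; which one is most convenient depends on how the data are already organised.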

frac

(numeric(1))
Fraction of data to use for validation dataset, default is 0 and therefore no separate validation dataset.

cuts

(integer(1))
Number of cut-points used to discretise the time axis.

cutpoints

(numeric())
Alternative to cuts: exact cutpoints to use for discretisation. cuts is ignored if cutpoints is non-NULL.

scheme

(character(1))
Method of discretisation, either "equidistant" (default) or "quantiles". See reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.

cut_min

(integer(1))
Starting duration for discretisation, see reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
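
To give a feel for how cuts, scheme and cut_min interact, the base R sketch below builds two candidate grids from synthetic durations. It is not the pycox implementation: the equidistant grid is assumed to run from cut_min to the largest observed duration, and the quantile grid here uses raw quantiles of the observed event times as a rough stand-in (pycox is understood to base its quantile grid on a Kaplan-Meier estimate instead, and its rule for assigning durations to grid points may differ).

```r
set.seed(42)
time   <- rexp(200, rate = 0.1)   # synthetic durations
status <- rbinom(200, 1, 0.7)     # 1 = event, 0 = censored

cuts    <- 10
cut_min <- 0

# "equidistant": evenly spaced grid from cut_min to the largest duration
equi <- seq(cut_min, max(time), length.out = cuts)

# "quantiles" (rough stand-in): equal-probability grid over event times
probs <- seq(0, 1, length.out = cuts)[-1]
quant <- unique(c(cut_min, quantile(time[status == 1], probs, names = FALSE)))

# Index of the interval each duration falls into on the equidistant grid
idx <- findInterval(time, equi, rightmost.closed = TRUE)
```

With skewed durations the equidistant grid leaves most observations in the first few intervals, whereas a quantile-style grid places more cut-points where events are dense.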

activation

(character(1))
See get_pycox_activation.

custom_net

(torch.nn.modules.module.Module(1))
Optional custom network built with build_pytorch_net, otherwise default architecture used. Note that if building a custom network the number of output channels depends on cuts or cutpoints.
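
A minimal sketch of sizing a custom network follows. It assumes one output channel per cut-point (so n_out equals cuts, or the length of cutpoints when that is supplied); this is an assumption about the discretisation and should be checked against the label transform. The value of n_in and the commented-out my_data are hypothetical, and the network is built only when survivalmodels, reticulate and a Python environment with torch are available.

```r
cuts      <- 10L
cutpoints <- NULL  # e.g. c(0, 1, 2, 5, 10) to supply an explicit grid instead

# Assumption: one output channel per cut-point
n_out <- if (is.null(cutpoints)) cuts else length(cutpoints)
n_in  <- 3L  # hypothetical number of feature columns after coercion

can_build <- requireNamespace("survivalmodels", quietly = TRUE) &&
  requireNamespace("reticulate", quietly = TRUE) &&
  reticulate::py_module_available("torch")

if (can_build) {
  net <- survivalmodels::build_pytorch_net(
    n_in = n_in, n_out = n_out, nodes = c(16L, 16L), activation = "relu"
  )
  # The network would then be passed on, e.g.
  # fit <- survivalmodels::loghaz(data = my_data, custom_net = net, cuts = cuts)
}
```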

num_nodes, batch_norm, dropout

(integer()/logical(1)/numeric(1))
See build_pytorch_net.

device

(integer(1)|character(1))
Passed to pycox.models.LogisticHazard, specifies device to compute models on.

early_stopping, best_weights, min_delta, patience

(logical(1)/logical(1)/numeric(1)/integer(1))
See get_pycox_callbacks.

batch_size

(integer(1))
Passed to pycox.models.LogisticHazard.fit, number of elements in each batch.

epochs

(integer(1))
Passed to pycox.models.LogisticHazard.fit, number of epochs.

verbose

(logical(1))
Passed to pycox.models.LogisticHazard.fit, should information be displayed during fitting.

num_workers

(integer(1))
Passed to pycox.models.LogisticHazard.fit, number of workers used in the dataloader.

shuffle

(logical(1))
Passed to pycox.models.LogisticHazard.fit, should order of dataset be shuffled?

...

ANY
Passed to get_pycox_optim.

Value

An object inheriting from class loghaz.

An object of class survivalmodel.

Details

Implemented from the pycox Python package via reticulate. Calls pycox.models.LogisticHazard.
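
As a rough illustration of the quantities involved, the base R sketch below converts a vector of discrete hazards into a survival curve and evaluates a cross-entropy (negative log-likelihood) loss for a single individual. The hazard values are made up, the helper discrete_nll is hypothetical, and the formulas follow the standard discrete-time formulation described in the references rather than the pycox source.

```r
# Hypothetical predicted hazards for one individual over four intervals
hazard <- c(0.10, 0.20, 0.50, 0.25)

# Survival up to interval j is the product of (1 - hazard) over intervals 1..j
surv <- cumprod(1 - hazard)

# Cross-entropy loss for an individual observed up to interval k:
# the target is 0 in every interval survived and 1 in interval k
# if the event occurred there (all zeros if censored in interval k)
discrete_nll <- function(hazard, k, event) {
  y <- numeric(k)
  if (event) y[k] <- 1
  h <- hazard[seq_len(k)]
  -sum(y * log(h) + (1 - y) * log(1 - h))
}

loss_event    <- discrete_nll(hazard, k = 2, event = TRUE)
loss_censored <- discrete_nll(hazard, k = 2, event = FALSE)
```

In a fitted model the hazards would come from a sigmoid applied to the network outputs, one per interval, and the loss would be averaged over individuals in a batch.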

References

Gensheimer, M. F., & Narasimhan, B. (2018). A Simple Discrete-Time Survival Model for Neural Networks, 1–17. https://doi.org/arXiv:1805.00917v3

Kvamme, H., & Borgan, Ø. (2019). Continuous and discrete-time survival prediction with neural networks. https://doi.org/arXiv:1910.06724.

Examples

# \donttest{
if (requireNamespaces("reticulate")) {
  # all defaults
  loghaz(data = simsurvdata(50))

  # common parameters
  loghaz(data = simsurvdata(50), frac = 0.3, activation = "relu",
    num_nodes = c(4L, 8L, 4L, 2L), dropout = 0.1, early_stopping = TRUE, epochs = 100L,
    batch_size = 32L)
}
# }