Predicted values from a fitted pycox ANN.

# S3 method for pycox
predict(
  object,
  newdata,
  batch_size = 256L,
  num_workers = 0L,
  interpolate = FALSE,
  inter_scheme = c("const_hazard", "const_pdf"),
  sub = 10L,
  type = c("survival", "risk", "all"),
  distr6 = FALSE,
  ...
)

Arguments

object

(pycox(1))
Object of class inheriting from "pycox".

newdata

(data.frame(1))
Test data as a data.frame-like object, coerced internally with stats::model.matrix(). If missing, the training data from the fitted object is used.

batch_size

(integer(1))
Passed to pycox.models.X.fit; number of elements in each batch.

num_workers

(integer(1))
Passed to pycox.models.X.fit; number of workers used in the dataloader.

interpolate

(logical(1))
For models deephit and loghaz, should predictions be linearly interpolated? Ignored for other models.

inter_scheme

(character(1))
If interpolate is TRUE, the scheme used for interpolation. See reticulate::py_help(pycox$models$DeepHitSingle$interpolate) for further details.

sub

(integer(1))
If interpolate is TRUE or model is loghaz, the number of sub-divisions for interpolation. See reticulate::py_help(pycox$models$DeepHitSingle$interpolate) for further details.
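
As a rough illustration of what sub controls, the sketch below refines a coarse survival curve in base R by placing sub equally spaced points in each interval and interpolating survival linearly. This mirrors the idea behind the "const_pdf" scheme (a constant density between grid points implies a linearly decreasing survival function). The grid, the survival values, and the helper interpolate_surv are hypothetical and written for illustration only; this is not the pycox implementation.

```r
# Hypothetical coarse grid of time points and survival probabilities
times <- c(0, 1, 2, 3)
surv  <- c(1.0, 0.8, 0.5, 0.4)

# Illustrative helper (not part of any package): split every interval
# into `sub` pieces and interpolate survival linearly on the finer grid
interpolate_surv <- function(times, surv, sub = 10L) {
  n <- length(times)
  # Left end point plus interior points of each interval
  grid <- unlist(lapply(seq_len(n - 1L), function(i) {
    seq(times[i], times[i + 1L], length.out = sub + 1L)[-(sub + 1L)]
  }))
  # Add the final time point, which no interval contributed
  grid <- c(grid, times[n])
  fine <- stats::approx(times, surv, xout = grid)
  data.frame(time = fine$x, surv = fine$y)
}

res <- interpolate_surv(times, surv, sub = 2L)
res
```

With sub = 2L each original interval gains one midpoint, so the four original time points become seven.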

type

(character(1))
Type of predicted value. Choices are survival probabilities over all time-points in the training data ("survival"), a relative risk ranking ("risk"), or both ("all"). The risk ranking is the negative mean survival time, so a higher value implies a higher risk of event.
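
The relationship between the two prediction types can be sketched in base R. The example below assumes the mean survival time is approximated from the probability mass implied by the drops in survival at each time point, and that any mass beyond the last time point is ignored. The matrix is made up, and the exact computation inside the package may differ.

```r
# Hypothetical survival matrix: 2 observations (rows) x 3 time points (columns)
surv <- matrix(
  c(0.9, 0.6, 0.2,
    0.8, 0.3, 0.1),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("1", "2", "3"))
)
times <- as.numeric(colnames(surv))

# Probability mass at each time point = drop in survival since the
# previous point (the first column is the drop from S(0) = 1)
cdf <- 1 - surv
pmf <- cbind(cdf[, 1], t(apply(cdf, 1, diff)))

# Negative (approximate) mean survival time: a larger value means higher risk
risk <- -as.numeric(pmf %*% times)
risk
```

Here the second observation has lower survival at every time point, so its approximate mean survival time is shorter and its risk value is larger.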

distr6

(logical(1))
If FALSE (default) and type is "survival" or "all", returns a matrix of survival probabilities; otherwise returns a distr6::Matdist().

...

ANY
Currently ignored.

Value

A numeric vector if type = "risk"; a distr6::Matdist() if type = "survival" and distr6 = TRUE; a matrix if type = "survival" and distr6 = FALSE, where entries are survival probabilities, rows are observations, and columns are time-points; or a list combining the above if type = "all".
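
To show how the matrix form might be read, the sketch below looks up each observation's survival probability at a query time using the last time point at or before it. The matrix is a hypothetical stand-in for the return value, and it assumes the column names hold the time points as numbers.

```r
# Hypothetical stand-in for the matrix returned when type = "survival"
surv <- matrix(
  c(0.95, 0.70, 0.40,
    0.90, 0.50, 0.20),
  nrow = 2, byrow = TRUE,
  dimnames = list(c("obs1", "obs2"), c("0.5", "1.5", "3"))
)
times <- as.numeric(colnames(surv))

# Survival probability at t = 2, taken from the last time point <= t
t_query <- 2
idx <- findInterval(t_query, times)
surv_at_t <- surv[, idx]
surv_at_t
```

A query time earlier than the first column would give idx = 0, which this sketch does not handle; by convention the survival probability there would be 1.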

Examples

if (FALSE) {
if (requireNamespaces("reticulate")) {
  fit <- coxtime(data = simsurvdata(50))

  # predict survival matrix and relative risks
  predict(fit, simsurvdata(10), type = "all")

  # return as distribution
  if (requireNamespaces("distr6")) {
    predict(fit, simsurvdata(10), distr6 = TRUE)
  }
}
}