Helper function that returns a character string containing a populated PyTorch weight initializer call from torch.nn.init. Used in build_pytorch_net to define the weight initialization function.

get_pycox_init(
  init = "uniform",
  a = 0,
  b = 1,
  mean = 0,
  std = 1,
  val,
  gain = 1,
  mode = c("fan_in", "fan_out"),
  non_linearity = c("leaky_relu", "relu")
)

Arguments

init

(character(1))
Initialization method; see Details for the list of implemented methods.

a

(numeric(1))
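Passed to uniform, kaiming_uniform, and kaiming_normal. For uniform this is the lower bound; for the Kaiming methods it is the negative slope of the rectifier, which PyTorch uses only with leaky_relu.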

b

(numeric(1))
Passed to uniform as the upper bound.

mean, std

(numeric(1))
Passed to normal.

val

(numeric(1))
Passed to constant. It has no default, so it must be supplied when init = "constant".

gain

(numeric(1))
Passed to xavier_uniform, xavier_normal, and orthogonal.

mode

(character(1))
Passed to kaiming_uniform and kaiming_normal; one of "fan_in" (default) or "fan_out".

non_linearity

(character(1))
Passed to kaiming_uniform and kaiming_normal; one of "leaky_relu" (default) or "relu".
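
The vector defaults for mode and non_linearity follow the usual R convention in which the first element acts as the default. The sketch below illustrates that convention, assuming the arguments are resolved with match.arg(), which this page does not confirm; pick_mode is a hypothetical helper and not part of the package.

```r
# Hypothetical helper: shows how a vector default resolves to its first
# element. This is an assumption for illustration, not the package's code.
pick_mode <- function(mode = c("fan_in", "fan_out")) {
  match.arg(mode)
}

pick_mode()           # no argument: falls back to the first choice
pick_mode("fan_out")  # explicit choice is returned unchanged
```

Under this convention, a value outside the listed choices raises an error rather than being passed through to PyTorch.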

Details

The implemented methods, each with the call that opens its help page, are:

  • "uniform"
    reticulate::py_help(torch$nn$init$uniform_)

  • "normal"
    reticulate::py_help(torch$nn$init$normal_)

  • "constant"
    reticulate::py_help(torch$nn$init$constant_)

  • "xavier_uniform"
    reticulate::py_help(torch$nn$init$xavier_uniform_)

  • "xavier_normal"
    reticulate::py_help(torch$nn$init$xavier_normal_)

  • "kaiming_uniform"
    reticulate::py_help(torch$nn$init$kaiming_uniform_)

  • "kaiming_normal"
    reticulate::py_help(torch$nn$init$kaiming_normal_)

  • "orthogonal"
    reticulate::py_help(torch$nn$init$orthogonal_)
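
To make the idea of a "populated" initializer string concrete, here is a minimal sketch of how such a string could be assembled for two of the methods above. sketch_init is a hypothetical function written for illustration. The kaiming_uniform format is taken from the output shown in Examples; the uniform format is assumed by analogy and may differ from what the package actually returns.

```r
# Hypothetical re-creation of the string format; an assumption for
# illustration only, not the package source.
sketch_init <- function(init = "uniform", a = 0, b = 1,
                        mode = c("fan_in", "fan_out"),
                        non_linearity = c("leaky_relu", "relu")) {
  # Resolve the vector defaults to a single validated choice.
  mode <- match.arg(mode)
  non_linearity <- match.arg(non_linearity)

  switch(init,
    # Assumed format: tensor, lower bound, upper bound.
    uniform = sprintf("torch.nn.init.uniform_(m.weight, %s, %s)", a, b),
    # Format matching the documented example output.
    kaiming_uniform = sprintf(
      "torch.nn.init.kaiming_uniform_(m.weight, %s, '%s', '%s')",
      a, mode, non_linearity
    ),
    stop("init method not covered by this sketch: ", init)
  )
}

sketch_init("kaiming_uniform", a = 0, mode = "fan_out")
```

The result is plain text containing Python code, so nothing is executed in Python until the string is evaluated elsewhere.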

Examples

# \donttest{
if (requireNamespaces("reticulate")) {
  get_pycox_init(init = "uniform")

  get_pycox_init(init = "kaiming_uniform", a = 0, mode = "fan_out")
}
#> [1] "torch.nn.init.kaiming_uniform_(m.weight, 0, 'fan_out', 'leaky_relu')"
# }