CompoundCheckpoint

class torch_uncertainty.callbacks.CompoundCheckpoint(compound_metric_dict, dirpath=None, verbose=False, save_last=False, save_top_k=1, save_weights_only=False, mode='min', every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)

Save the checkpoints that maximize or minimize a given linear combination of the logged metric values.
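The quantity being ranked can be illustrated in plain Python. The sketch below is not the library's implementation: `compound_score` is a hypothetical helper, and the metric names `"val/acc"` and `"val/ece"` are placeholders rather than names torch_uncertainty is known to log. It only shows how a dictionary shaped like `compound_metric_dict` turns several logged metrics into one number.

```python
def compound_score(metrics: dict[str, float], compound_metric_dict: dict[str, float]) -> float:
    """Hypothetical helper: return sum_i factor_i * metric_i.

    metrics maps every logged metric name to its current value;
    compound_metric_dict maps the selected metric names to their factors.
    """
    # Every metric named in the dictionary must have been logged.
    missing = [name for name in compound_metric_dict if name not in metrics]
    if missing:
        raise KeyError(f"metrics not logged: {missing}")
    # Metrics that are logged but not named in the dictionary are ignored.
    return sum(factor * metrics[name] for name, factor in compound_metric_dict.items())


# Placeholder metric names: reward accuracy and penalize calibration error.
factors = {"val/acc": 1.0, "val/ece": -2.0}
logged = {"val/acc": 0.9, "val/ece": 0.05, "val/loss": 0.3}
score = compound_score(logged, factors)
print(round(score, 6))  # 0.8
```

With these signs, a larger score is better, so such a dictionary would be paired with `mode="max"`. To use the class itself, an instance would be passed to a Lightning `Trainer` through its `callbacks` argument, and the keys of `compound_metric_dict` would have to match the names the model actually logs.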

Parameters:
  • compound_metric_dict (dict) –

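    A dictionary mapping metric names (keys) to their corresponding factors (values). The monitored quantity is the linear form:

    \[\sum_{i} \text{factor}_i \times \text{metric}_i\]

    where metric_i is the logged value of the i-th metric and factor_i is the factor it is mapped to in the dictionary.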

  • dirpath (str | Path | None, optional) – The directory to save the checkpoints in. Defaults to None.

  • verbose (bool, optional) – Whether to print verbose output. Defaults to False.

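  • save_last (bool | Literal["link"], optional) – Whether to also save a copy of the most recent checkpoint. If "link", a symbolic link to the latest checkpoint is saved instead of a copy. Defaults to False.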

  • save_top_k (int, optional) – The number of best checkpoints to save. Defaults to 1.

  • save_weights_only (bool, optional) – Whether to save only the weights. Defaults to False.

  • mode (str, optional) – Whether to minimize ("min") or maximize ("max") the compound metric. Defaults to "min".

  • every_n_train_steps (int | None, optional) – The number of training steps to wait between saving checkpoints. Defaults to None.

  • train_time_interval (timedelta | None, optional) – The time interval to wait between saving checkpoints. Defaults to None.

  • every_n_epochs (int | None, optional) – The number of epochs to wait between saving checkpoints. Defaults to None.

  • save_on_train_epoch_end (bool | None, optional) – Whether to save the checkpoint at the end of each training epoch. Defaults to None.

  • enable_version_counter (bool, optional) – Whether to append a version suffix to the checkpoint file name when a file with the same name already exists. If False, existing checkpoint files are overwritten. Defaults to True.
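The interaction between mode and save_top_k can be sketched as a ranking over the compound scores of saved checkpoints. This is a simplified illustration, not the callback's code: `top_k_checkpoints` is a hypothetical function, the checkpoint file names are made up, and treating save_top_k=-1 as "keep all" and 0 as "keep none" is an assumption borrowed from Lightning's ModelCheckpoint conventions.

```python
def top_k_checkpoints(scores: dict[str, float], save_top_k: int, mode: str = "min") -> list[str]:
    """Hypothetical helper: return the checkpoint paths that would be kept, best first.

    scores maps a checkpoint path to the compound score it was saved with.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # Ascending order for "min", descending for "max", so the best comes first.
    ranked = sorted(scores, key=scores.__getitem__, reverse=(mode == "max"))
    # Assumed Lightning convention: -1 keeps every checkpoint, 0 keeps none.
    if save_top_k == -1:
        return ranked
    return ranked[:save_top_k]


# Made-up compound scores for three epochs.
scores = {"epoch=0.ckpt": 0.70, "epoch=1.ckpt": 0.80, "epoch=2.ckpt": 0.75}
print(top_k_checkpoints(scores, save_top_k=2, mode="max"))  # ['epoch=1.ckpt', 'epoch=2.ckpt']
print(top_k_checkpoints(scores, save_top_k=2, mode="min"))  # ['epoch=0.ckpt', 'epoch=2.ckpt']
```

A checkpointing callback would typically maintain this set incrementally, comparing each new score against the current worst retained one, rather than re-sorting the full history; sorting here just makes the selection rule explicit.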