CheckpointCollector

class torch_uncertainty.models.CheckpointCollector(model, cycle_start=None, cycle_length=None, save_schedule=None, use_final_model=True, store_on_cpu=False)

Ensemble of models at different points in the training trajectory.

CheckpointCollector can be used to collect samples of the posterior distribution, using either classical stochastic gradient optimization methods or SGLD and SGHMC, as implemented in TorchUncertainty.

Parameters:
  • model (nn.Module) – The model to train and ensemble.

  • cycle_start (int | None) – Epoch to start ensembling. Defaults to None.

  • cycle_length (int | None) – Number of epochs between model collections. Defaults to None.

  • save_schedule (list[int] | None) – The epochs at which to save the model. Defaults to None.

  • use_final_model (bool) – Whether to use the final model as a checkpoint. Defaults to True.

  • store_on_cpu (bool) – Whether to put the models on the CPU when unused. Defaults to False.

Note

The models are saved at the end of the specified epochs.

Note

If cycle_start, cycle_length and save_schedule are None, the wrapper will save the models at each epoch.
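The interaction between the three scheduling parameters can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the helper should_save is hypothetical, and both the cycle rule (save when epoch >= cycle_start and epoch - cycle_start is a multiple of cycle_length) and the precedence of save_schedule over the cycle parameters are assumptions. Only the all-None behavior is taken from the note above.

```python
from typing import List, Optional


def should_save(
    epoch: int,
    cycle_start: Optional[int] = None,
    cycle_length: Optional[int] = None,
    save_schedule: Optional[List[int]] = None,
) -> bool:
    """Hypothetical helper: decide whether to keep the model after `epoch`."""
    if cycle_start is None and cycle_length is None and save_schedule is None:
        # Documented behavior: with no schedule at all, save at every epoch.
        return True
    if save_schedule is not None:
        # Assumption: an explicit list of epochs takes precedence.
        return epoch in save_schedule
    # Assumption: periodic collection that begins at cycle_start.
    if cycle_start is None or cycle_length is None:
        raise ValueError("cycle_start and cycle_length must be set together")
    return epoch >= cycle_start and (epoch - cycle_start) % cycle_length == 0


# Epochs 0..9, collection starting at epoch 4 and repeating every 2 epochs.
saved = [e for e in range(10) if should_save(e, cycle_start=4, cycle_length=2)]
print(saved)  # [4, 6, 8]
```

Under these assumptions, passing save_schedule=[1, 5] would keep only the models from epochs 1 and 5, regardless of the cycle parameters.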

Reference:

Checkpoint Ensembles: Ensemble Methods from a Single Training Process. Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018.

eval_forward(x)

Forward pass for evaluation.

This method will return the ensemble prediction if models have already been collected.

Parameters:

x (Tensor) – The input tensor.

Returns:

The ensemble output.

Return type:

Tensor

forward(x)

Forward pass for training and evaluation mode.

If the model is in evaluation mode, this method will return the ensemble prediction. Otherwise, it will return the prediction of the current model.

Parameters:

x (Tensor) – The input tensor.

Returns:

The model or ensemble output.

Return type:

Tensor
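The train/eval dispatch described for forward can be illustrated with a small pure-Python stand-in. ToyCollector and its attributes are hypothetical and operate on floats rather than tensors. The sketch returns one output per ensemble member as a list, since how the library combines member outputs into a single Tensor is not stated here. Falling back to the current model when nothing has been collected is also an assumption.

```python
from typing import Callable, List

Model = Callable[[float], float]


class ToyCollector:
    """Hypothetical stand-in mimicking the train/eval dispatch."""

    def __init__(self, model: Model, use_final_model: bool = True) -> None:
        self.model = model                    # model currently being trained
        self.saved_models: List[Model] = []   # checkpoints collected so far
        self.use_final_model = use_final_model
        self.training = True                  # mirrors nn.Module.training

    def eval_forward(self, x: float) -> List[float]:
        members = list(self.saved_models)
        if self.use_final_model or not members:
            # Assumption: use the current model if no checkpoint was saved.
            members.append(self.model)
        return [m(x) for m in members]

    def forward(self, x: float):
        # Training mode: current model only. Evaluation mode: the ensemble.
        return self.model(x) if self.training else self.eval_forward(x)


collector = ToyCollector(model=lambda x: 3.0 * x)
collector.saved_models = [lambda x: 1.0 * x, lambda x: 2.0 * x]

print(collector.forward(2.0))  # 6.0 (training mode, current model)
collector.training = False
print(collector.forward(2.0))  # [2.0, 4.0, 6.0] (evaluation mode, all members)
```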

to(*args, **kwargs)

Move the model and change its type.

If store_on_cpu is set to True, the device is forced to “cpu” to avoid filling the GPU memory (VRAM).

update_wrapper(epoch)

Save the model at the end of the epoch, if included in the schedule.

Parameters:

epoch (int) – The current epoch.
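A rough sketch of where an update_wrapper(epoch) call would sit in a training loop, with toy objects in place of the real collector. ToyModel, saved_models and the weight update are hypothetical. Snapshotting with copy.deepcopy is also an assumption about how a checkpoint stays independent of later training, not a statement about the library's internals.

```python
import copy


class ToyModel:
    """Hypothetical model with a single scalar weight."""

    def __init__(self) -> None:
        self.weight = 0.0


model = ToyModel()
saved_models = []        # plays the role of the collector's checkpoint list
save_schedule = [1, 3]   # epochs whose end-of-epoch state should be kept

for epoch in range(4):
    model.weight += 1.0  # stand-in for one epoch of training
    # Stand-in for calling update_wrapper(epoch) at the end of the epoch.
    if epoch in save_schedule:
        # Deep copy so that later training does not alter the snapshot.
        saved_models.append(copy.deepcopy(model))

print([m.weight for m in saved_models])  # [2.0, 4.0]
print(model.weight)                      # 4.0
```

The first snapshot keeps the weight from the end of epoch 1 even though the live model continues to change afterwards.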