
CheckpointEnsemble

class torch_uncertainty.models.CheckpointEnsemble(model, save_schedule=None, use_final_checkpoint=True)

Ensemble of copies of a single model saved at different points along its training trajectory.

Parameters:
  • model (nn.Module) – The model to train and ensemble.

  • save_schedule (list[int], optional) – The epochs at which to save the model. If save_schedule is None, the model is saved at every epoch. Defaults to None.

  • use_final_checkpoint (bool, optional) – Whether to use the final model as a checkpoint. Defaults to True.
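The interaction between save_schedule and use_final_checkpoint can be illustrated with a small helper. This is a hypothetical sketch (scheduled_members is not part of torch_uncertainty): it assumes 0-based epoch numbering and a snapshot taken at the end of each scheduled epoch, and it only lists which models would make up the ensemble.

```python
def scheduled_members(num_epochs, save_schedule=None, use_final_checkpoint=True):
    """Hypothetical helper: label the models that would form the ensemble.

    Assumes epochs are numbered 0..num_epochs-1 and that a snapshot is taken
    at the end of every epoch in save_schedule (every epoch if it is None).
    """
    members = [
        f"epoch_{epoch}"
        for epoch in range(num_epochs)
        if save_schedule is None or epoch in save_schedule
    ]
    if use_final_checkpoint:
        # The final (live) model counts as one extra ensemble member.
        members.append("final")
    return members


print(scheduled_members(5, save_schedule=[2, 4]))  # ['epoch_2', 'epoch_4', 'final']
print(len(scheduled_members(3)))  # 4: three per-epoch snapshots plus the final model
```

Under these assumptions, scheduling the last epoch while keeping use_final_checkpoint=True would count the same weights twice, since the last snapshot and the final model coincide.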

Reference:

Checkpoint Ensembles: Ensemble Methods from a Single Training Process. Hugh Chen, Scott Lundberg, Su-In Lee. arXiv, 2018.

eval_forward(x)

Forward pass for evaluation.

If the model is in evaluation mode, this method will return the ensemble prediction. Otherwise, it will return the prediction of the current model.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The model or ensemble output.

Return type:

torch.Tensor
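The train/eval behavior described for eval_forward can be sketched in plain Python. All names here (ToyModel, ToyCheckpointEnsemble, save_snapshot) are hypothetical stand-ins for PyTorch modules, and aggregating members by an element-wise mean is an assumption made for illustration, not necessarily what the library does.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for an nn.Module computing y = weight * x."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return [self.weight * v for v in x]


class ToyCheckpointEnsemble:
    """Hypothetical sketch of checkpoint ensembling, not the library code."""

    def __init__(self, model, use_final_checkpoint=True):
        self.model = model
        self.use_final_checkpoint = use_final_checkpoint
        self.snapshots = []  # frozen deep copies taken during training
        self.training = True

    def save_snapshot(self):
        # Deep copy so later updates to self.model do not alter the snapshot.
        self.snapshots.append(copy.deepcopy(self.model))

    def eval_forward(self, x):
        if self.training:
            # Training mode: prediction of the current model only.
            return self.model(x)
        members = list(self.snapshots)
        if self.use_final_checkpoint or not members:
            # Include the live model (also an assumed fallback when nothing was saved).
            members.append(self.model)
        preds = [m(x) for m in members]
        # Assumed aggregation: element-wise mean over ensemble members.
        return [sum(vals) / len(preds) for vals in zip(*preds)]


model = ToyModel(weight=1.0)
ens = ToyCheckpointEnsemble(model)
ens.save_snapshot()  # snapshot with weight 1.0
model.weight = 3.0   # "training" changes the live model
ens.save_snapshot()  # snapshot with weight 3.0
model.weight = 5.0

print(ens.eval_forward([1.0, 2.0]))  # training mode: [5.0, 10.0]
ens.training = False
print(ens.eval_forward([1.0, 2.0]))  # mean of weights 1, 3, 5: [3.0, 6.0]
```

The deep copy is the essential design choice: without it, every "snapshot" would be a reference to the same live model and the ensemble would collapse to a single member.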

update_wrapper(epoch)

Save a copy of the model at the given epoch if that epoch is included in the schedule.

Parameters:

epoch (int) – The current epoch.
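A hypothetical training loop shows where a schedule-driven hook of this kind could be called. ToyModel and ToyWrapper are illustrative stand-ins that only mirror the save rule stated above (every epoch when no schedule is given, otherwise only the listed epochs); the sketch assumes the hook runs once at the end of each epoch with a 0-based epoch index, which may differ from how torch_uncertainty invokes update_wrapper.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for an nn.Module with one scalar weight."""

    def __init__(self, weight=0.0):
        self.weight = weight


class ToyWrapper:
    """Hypothetical sketch of schedule-driven snapshotting."""

    def __init__(self, model, save_schedule=None):
        self.model = model
        self.save_schedule = save_schedule
        self.saved = []  # (epoch, frozen copy) pairs

    def update_wrapper(self, epoch):
        # Snapshot every epoch when no schedule is given, else only listed epochs.
        if self.save_schedule is None or epoch in self.save_schedule:
            self.saved.append((epoch, copy.deepcopy(self.model)))


model = ToyModel()
wrapper = ToyWrapper(model, save_schedule=[1, 3])
for epoch in range(4):
    model.weight += 1.0            # stand-in for one epoch of optimization
    wrapper.update_wrapper(epoch)  # called once at the end of each epoch

print([(e, m.weight) for e, m in wrapper.saved])  # [(1, 2.0), (3, 4.0)]
```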