CheckpointCollector#
- class torch_uncertainty.models.CheckpointCollector(model, cycle_start=None, cycle_length=None, save_schedule=None, use_final_model=True, store_on_cpu=False)[source]#
Ensemble of models at different points in the training trajectory.
CheckpointCollector can be used to collect samples from the posterior distribution, either with classical stochastic gradient optimization methods or with SGLD and SGHMC as implemented in TorchUncertainty.
- Parameters:
model (nn.Module) – The model to train and ensemble.
cycle_start (int | None) – Epoch at which to start ensembling. Defaults to None.
cycle_length (int | None) – Number of epochs between model collections. Defaults to None.
save_schedule (list[int] | None) – The epochs at which to save the model. Defaults to None.
use_final_model (bool) – Whether to use the final model as a checkpoint. Defaults to True.
store_on_cpu (bool) – Whether to keep the collected models on the CPU when unused. Defaults to False.
Note
The models are saved at the end of the specified epochs.
Note
If cycle_start, cycle_length, and save_schedule are None, the wrapper will save the models at each epoch.
- Reference:
Checkpoint Ensembles: Ensemble Methods from a Single Training Process. Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018.
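Example (a minimal usage sketch based only on the constructor arguments documented above; the backbone, the schedule values, and the variable names are illustrative, and the hook that actually collects checkpoints at the end of each scheduled epoch, normally driven by the training loop or routine, is not shown):

import torch
from torch import nn

from torch_uncertainty.models import CheckpointCollector

# Illustrative backbone; any nn.Module works.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Collect checkpoints at the end of epochs 10, 20, and 30, and also keep
# the final model in the ensemble (use_final_model=True is the default).
wrapper = CheckpointCollector(
    model=backbone,
    save_schedule=[10, 20, 30],
    use_final_model=True,
    store_on_cpu=False,
)

# During training, the wrapper behaves like the wrapped model.
wrapper.train()
x = torch.randn(8, 16)
logits = wrapper(x)  # prediction of the current model only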
- eval_forward(x)[source]#
Forward pass for evaluation.
This method will return the ensemble prediction if models have already been collected.
- Parameters:
x (Tensor) – The input tensor.
- Returns:
The ensemble output.
- Return type:
Tensor
- forward(x)[source]#
Forward pass for both training and evaluation modes.
If the model is in evaluation mode, this method will return the ensemble prediction. Otherwise, it will return the prediction of the current model.
- Parameters:
x (Tensor) – The input tensor.
- Returns:
The model or ensemble output.
- Return type:
Tensor
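Example (a hedged sketch of the mode-dependent behavior described above; it assumes checkpoints have already been collected during training, and the exact format of the ensemble output, e.g. stacked versus averaged predictions, is not specified on this page):

import torch
from torch import nn

from torch_uncertainty.models import CheckpointCollector

wrapper = CheckpointCollector(nn.Linear(16, 4), save_schedule=[1, 2, 3])
x = torch.randn(8, 16)

# Training mode: forward(x) returns the prediction of the current model.
wrapper.train()
train_out = wrapper(x)

# Evaluation mode: forward(x) returns the ensemble prediction once
# checkpoints have been collected, matching the output of eval_forward(x).
wrapper.eval()
eval_out = wrapper.eval_forward(x)  # equivalently: wrapper(x) in eval mode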