CheckpointEnsemble
- class torch_uncertainty.models.CheckpointEnsemble(model, save_schedule=None, use_final_checkpoint=True)
Ensemble of models at different points in the training trajectory.
- Parameters:
model (nn.Module) – The model to train and ensemble.
save_schedule (list[int] | None, optional) – The epochs at which to save the model. If save_schedule is None, the model is saved at every epoch. Defaults to None.
use_final_checkpoint (bool, optional) – Whether to include the final model as one of the ensemble's checkpoints. Defaults to True.
- Reference:
Checkpoint Ensembles: Ensemble Methods from a Single Training Process. Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018.
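The save_schedule semantics described above can be illustrated with a small, framework-free sketch. It is not the library's implementation: the helper collect_checkpoints, the toy dictionary "model", 0-indexed epochs, and taking the snapshot after each epoch's training step are all assumptions made for illustration.

```python
import copy


def collect_checkpoints(model, num_epochs, train_one_epoch, save_schedule=None):
    """Hypothetical helper: snapshot `model` at the scheduled epochs.

    Assumes 0-indexed epochs and that the snapshot is taken after the
    epoch's training step; the library may differ on both points.
    """
    checkpoints = []
    for epoch in range(num_epochs):
        train_one_epoch(model)
        # None means "save at every epoch", as in the parameter description.
        if save_schedule is None or epoch in save_schedule:
            # Deep copy so later training does not alter the snapshot.
            checkpoints.append(copy.deepcopy(model))
    return checkpoints


def toy_step(model):
    """Stand-in for one epoch of training: bump a single weight."""
    model["w"] += 1.0


every_epoch = collect_checkpoints({"w": 0.0}, 4, toy_step)
scheduled = collect_checkpoints({"w": 0.0}, 4, toy_step, save_schedule=[1, 3])

every_weights = [ckpt["w"] for ckpt in every_epoch]
scheduled_weights = [ckpt["w"] for ckpt in scheduled]
```

With no schedule, one snapshot is kept per epoch; with save_schedule=[1, 3], only the states after those two epochs are kept.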
- eval_forward(x)
Forward pass for evaluation.
If the model is in evaluation mode, this method will return the ensemble prediction. Otherwise, it will return the prediction of the current model.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The model or ensemble output.
- Return type:
torch.Tensor
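The mode-dependent behavior of eval_forward can be pictured with the following plain-Python sketch. The class ToyCheckpointEnsemble, its attributes, and the fallback to the current model when no checkpoint exists are hypothetical. The sketch returns a list of member outputs, whereas the documented method returns a single torch.Tensor; how the library combines member predictions is not stated here.

```python
class ToyCheckpointEnsemble:
    """Hypothetical stand-in that mimics train/eval dispatch only."""

    def __init__(self, model, use_final_checkpoint=True):
        self.model = model          # the model currently being trained
        self.checkpoints = []       # saved snapshots (callables here)
        self.use_final_checkpoint = use_final_checkpoint
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def forward(self, x):
        # Training mode: only the current model is used.
        # Assumption: with no snapshot saved, also fall back to it.
        if self.training or not self.checkpoints:
            return self.model(x)
        members = list(self.checkpoints)
        if self.use_final_checkpoint:
            # The final model counts as one more ensemble member.
            members.append(self.model)
        # One output per member; aggregation is left to the caller.
        return [member(x) for member in members]


ensemble = ToyCheckpointEnsemble(lambda x: 3 * x)
ensemble.checkpoints = [lambda x: 1 * x, lambda x: 2 * x]

train_output = ensemble.train().forward(2)
eval_output = ensemble.eval().forward(2)

ensemble.use_final_checkpoint = False
eval_output_without_final = ensemble.forward(2)
```

In training mode the call yields only the current model's output; in evaluation mode it yields one output per saved checkpoint, plus the final model's output when use_final_checkpoint is True.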