# Source code for torch_uncertainty.models.wrappers.checkpoint_collector

import copy

import torch
from torch import Tensor, nn


[docs] class CheckpointCollector(nn.Module): def __init__( self, model: nn.Module, cycle_start: int | None = None, cycle_length: int | None = None, save_schedule: list[int] | None = None, use_final_model: bool = True, store_on_cpu: bool = False, ) -> None: """Ensemble of models at different points in the training trajectory. CheckpointCollector can be used to collect samples of the posterior distribution, either using classical stochastic gradient optimization methods, or SGLD and SGHMC as implemented in TorchUncertainty. Args: model (nn.Module): The model to train and ensemble. cycle_start (int): Epoch to start ensembling. Defaults to ``None``. cycle_length (int): Number of epochs between model collections. Defaults to ``None``. save_schedule (list[int] | None): The epochs at which to save the model. Defaults to ``None``. use_final_model (bool): Whether to use the final model as a checkpoint. Defaults to ``True``. store_on_cpu (bool): Whether to put the models on the CPU when unused. Defaults to ``False``. Note: The models are saved at the end of the specified epochs. Note: If :attr:`cycle_start`, :attr:`cycle_length` and :attr:`save_schedule` are ``None``, the wrapper will save the models at each epoch. Reference: Checkpoint Ensembles: Ensemble Methods from a Single Training Process. Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018. """ super().__init__() self.mode = None if cycle_start is None and cycle_length is None and save_schedule is None: self.mode = "all" elif cycle_start is not None and cycle_length is not None and save_schedule is None: self.mode = "cycle" elif save_schedule is not None and cycle_start is None and cycle_length is None: self.mode = "schedule" else: raise ValueError( f"The combination of arguments: cycle_start: {cycle_start}, cycle_length: {cycle_length}, save_schedule: {save_schedule} is not known." 
) self.core_model = model self.cycle_start = cycle_start self.cycle_length = cycle_length self.save_schedule = save_schedule self.use_final_model = use_final_model self.store_on_cpu = store_on_cpu self.register_buffer("num_estimators", torch.tensor(use_final_model, dtype=torch.long)) self.saved_models = nn.ModuleList() def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): self.saved_models = nn.ModuleList() for _ in range(state_dict["model.num_estimators"] - 1): self.saved_models.append(copy.deepcopy(self.core_model)) return super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): self.saved_models = nn.ModuleList() for _ in range(state_dict["model.num_estimators"] - 1): self.saved_models.append(copy.deepcopy(self.core_model)) return super().load_state_dict(state_dict, strict, assign)
@torch.no_grad()
def update_wrapper(self, epoch: int) -> None:
    """Collect a snapshot of the core model at the end of ``epoch``.

    Depending on the collection mode, the epoch may be skipped: ``"schedule"``
    keeps only the epochs listed in :attr:`save_schedule`, ``"cycle"`` keeps
    epochs at fixed intervals starting from :attr:`cycle_start`, and ``"all"``
    keeps every epoch.

    Args:
        epoch (int): The current epoch.
    """
    if self.mode == "schedule" and epoch not in self.save_schedule:
        return
    if self.mode == "cycle":
        offset = epoch - self.cycle_start
        if offset < 0 or offset % self.cycle_length != 0:
            return
    snapshot = copy.deepcopy(self.core_model)
    if self.store_on_cpu:
        snapshot = snapshot.cpu()
    self.saved_models.append(snapshot)
    self.num_estimators += 1
def eval_forward(self, x: Tensor) -> Tensor:
    """Forward pass for evaluation.

    Returns the concatenated ensemble prediction when snapshots have been
    collected; otherwise falls back to the core model alone. When
    :attr:`store_on_cpu` is set, each member is moved to the input's device
    for its forward pass and moved back to the CPU afterwards.

    Args:
        x (Tensor): The input tensor.

    Returns:
        Tensor: The ensemble output, concatenated along dimension 0.
    """
    if not len(self.saved_models):
        # No snapshots yet: the core model alone provides the prediction.
        if not self.store_on_cpu:
            return self.core_model.forward(x)
        out = self.core_model.to(x.device).forward(x)
        self.core_model.cpu()
        return out
    if self.store_on_cpu:
        outputs = []
        for member in self.saved_models:
            outputs.append(member.to(x.device).forward(x))
            member.cpu()
    else:
        outputs = [member.forward(x) for member in self.saved_models]
    if self.use_final_model:
        outputs.append(self.core_model.to(x.device).forward(x))
        if self.store_on_cpu:
            self.core_model.cpu()
    return torch.cat(outputs, dim=0)
def forward(self, x: Tensor) -> Tensor:
    """Forward pass for training and evaluation mode.

    In evaluation mode the collected ensemble prediction is returned via
    :meth:`eval_forward`; in training mode only the current core model runs,
    shuttled to the input's device (and back to CPU) when
    :attr:`store_on_cpu` is set.

    Args:
        x (Tensor): The input tensor.

    Returns:
        Tensor: The model or ensemble output.
    """
    if not self.training:
        return self.eval_forward(x)
    if not self.store_on_cpu:
        return self.core_model.forward(x)
    out = self.core_model.to(x.device).forward(x)
    self.core_model.cpu()
    return out
def to(self, *args, **kwargs):
    """Move the model and change its type.

    If :attr:`store_on_cpu` is set to True, we force device to "cpu" to avoid
    filling the VRAM.
    """
    # NOTE(review): relies on the private torch._C._nn._parse_to helper, which
    # mirrors nn.Module.to's own argument parsing — confirm it exists in the
    # supported torch versions.
    device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
    target = torch.device("cpu") if self.store_on_cpu else device
    return super().to(device=target, dtype=dtype, non_blocking=non_blocking)