Shortcuts

Source code for torch_uncertainty.models.wrappers.swag

import copy
from collections.abc import Mapping

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from .swa import SWA


[docs]class SWAG(SWA): swag_stats: dict[str, Tensor] prfx = "model.swag_stats." def __init__( self, model: nn.Module, cycle_start: int, cycle_length: int, scale: float = 1.0, diag_covariance: bool = False, max_num_models: int = 20, var_clamp: float = 1e-6, num_estimators: int = 16, ) -> None: """Stochastic Weight Averaging Gaussian (SWAG). Update the SWAG posterior every `cycle_length` epochs starting at `cycle_start`. Samples :attr:`num_estimators` models from the SWAG posterior after each update. Uses the SWAG posterior estimation only at test time. Otherwise, uses the base model for training. Call :meth:`update_wrapper` at the end of each epoch. It will update the SWAG posterior if the current epoch number minus :attr:`cycle_start` is a multiple of :attr:`cycle_length`. Call :meth:`bn_update` to update the batchnorm statistics of the current SWAG samples. Args: model (nn.Module): PyTorch model to be trained. cycle_start (int): Begininning of the first SWAG averaging cycle. cycle_length (int): Number of epochs between SWAG updates. The first update occurs at :attr:`cycle_start` + :attr:`cycle_length`. scale (float, optional): Scale of the Gaussian. Defaults to 1.0. diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to False. max_num_models (int, optional): Maximum number of models to store. Defaults to 0. var_clamp (float, optional): Minimum variance. Defaults to 1e-30. num_estimators (int, optional): Number of posterior estimates to use. Defaults to 16. Reference: Maddox, W. J. et al. A simple baseline for bayesian uncertainty in deep learning. In NeurIPS 2019. Note: Originates from https://github.com/wjmaddox/swa_gaussian. """ super().__init__(model, cycle_start, cycle_length) _swag_checks(scale, max_num_models, var_clamp) self.num_estimators = num_estimators self.scale = scale self.diag_covariance = diag_covariance self.max_num_models = max_num_models self.var_clamp = var_clamp self.initialize_stats() self.fit = False self.samples = []
[docs] def eval_forward(self, x: Tensor) -> Tensor: """Forward pass of the SWAG model when in eval mode.""" if not self.fit: return self.core_model.forward(x) return torch.cat([mod.to(device=x.device)(x) for mod in self.samples])
[docs] def initialize_stats(self) -> None: """Initialize the SWAG dictionary of statistics. For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when `diag_covariance` is False. """ self.swag_stats = {} for name_p, param in self.core_model.named_parameters(): mean, squared_mean = ( torch.zeros_like(param, device="cpu"), torch.zeros_like(param, device="cpu"), ) self.swag_stats[self.prfx + name_p + "_mean"] = mean self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: covariance_sqrt = torch.zeros((0, param.numel()), device="cpu") self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = covariance_sqrt
[docs] @torch.no_grad() def update_wrapper(self, epoch: int) -> None: """Update the SWAG posterior. The update is performed if the epoch is greater than the cycle start and the difference between the epoch and the cycle start is a multiple of the cycle length. Args: epoch (int): Current epoch. """ if not (epoch > self.cycle_start and (epoch - self.cycle_start) % self.cycle_length == 0): return for name_p, param in self.core_model.named_parameters(): mean = self.swag_stats[self.prfx + name_p + "_mean"] squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] new_param = param.data.detach().cpu() mean = mean * self.num_avgd_models / (self.num_avgd_models + 1) + new_param / ( self.num_avgd_models + 1 ) squared_mean = squared_mean * self.num_avgd_models / ( self.num_avgd_models + 1 ) + new_param**2 / (self.num_avgd_models + 1) self.swag_stats[self.prfx + name_p + "_mean"] = mean self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean if not self.diag_covariance: covariance_sqrt = self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] dev = (new_param - mean).view(-1, 1).t() covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) if self.num_avgd_models + 1 > self.max_num_models: covariance_sqrt = covariance_sqrt[1:, :] self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = covariance_sqrt self.num_avgd_models += 1 self.samples = [ self.sample(self.scale, self.diag_covariance) for _ in range(self.num_estimators) ] self.need_bn_update = True self.fit = True
[docs] def bn_update(self, loader: DataLoader, device: torch.device) -> None: """Update the bachnorm statistics of the current SWAG samples. Args: loader (DataLoader): DataLoader to update the batchnorm statistics. device (torch.device): Device to perform the update. """ if self.need_bn_update: for mod in self.samples: torch.optim.swa_utils.update_bn(loader, mod, device=device) self.need_bn_update = False
[docs] def sample( self, scale: float, diag_covariance: bool | None = None, block: bool = False, seed: int | None = None, ) -> nn.Module: """Sample a model from the SWAG posterior. Args: scale (float): Rescale coefficient of the Gaussian. diag_covariance (bool, optional): Whether to use a diagonal covariance. Defaults to None. block (bool, optional): Whether to sample a block diagonal covariance. Defaults to False. seed (int, optional): Random seed. Defaults to None. Returns: nn.Module: Sampled model. """ if seed is not None: torch.manual_seed(seed) if diag_covariance is None: diag_covariance = self.diag_covariance if not diag_covariance and self.diag_covariance: raise ValueError("Cannot sample full rank from diagonal covariance matrix.") if not block: return self._fullrank_sample(scale, diag_covariance) raise NotImplementedError("Raise an issue if you need this feature.")
def _fullrank_sample(self, scale: float, diagonal_covariance: bool) -> nn.Module: new_sample = copy.deepcopy(self.core_model) for name_p, param in new_sample.named_parameters(): mean = self.swag_stats[self.prfx + name_p + "_mean"] sq_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] if not diagonal_covariance: cov_mat_sqrt = self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] var = torch.clamp(sq_mean - mean**2, self.var_clamp) var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) if not diagonal_covariance: cov_sample = cov_mat_sqrt.t() @ torch.randn((cov_mat_sqrt.size(0),)) cov_sample /= (self.max_num_models - 1) ** 0.5 var_sample += cov_sample.view_as(var_sample) sample = mean + scale**0.5 * var_sample param.data = sample.to(device="cpu", dtype=param.dtype) return new_sample def _save_to_state_dict(self, destination, prefix: str, keep_vars: bool): """Add the SWAG statistics to the destination dict.""" super()._save_to_state_dict(destination, prefix, keep_vars) destination |= self.swag_stats
[docs] def state_dict(self, *args, destination=None, prefix="", keep_vars=False) -> Mapping: """Add the SWAG statistics to the state dict.""" return self.swag_stats | super().state_dict( *args, destination=destination, prefix=prefix, keep_vars=keep_vars )
def _load_swag_stats(self, state_dict: Mapping): """Load the SWAG statistics from the state dict.""" self.swag_stats = {k: v for k, v in state_dict.items() if k in self.swag_stats} for k in self.swag_stats: del state_dict[k] self.samples = [ self.sample(self.scale, self.diag_covariance) for _ in range(self.num_estimators) ] self.need_bn_update = True self.fit = True def load_state_dict(self, state_dict: Mapping, strict: bool = True, assign: bool = False): self._load_swag_stats(state_dict) return super().load_state_dict(state_dict, strict, assign) def compute_logdet(self, block=False): raise NotImplementedError("Raise an issue if you need this feature.") def compute_logprob(self, vec=None, block=False, diag=False): raise NotImplementedError("Raise an issue if you need this feature.")
def _swag_checks(scale: float, max_num_models: int, var_clamp: float) -> None: if scale < 0: raise ValueError(f"`scale` must be non-negative. Got {scale}.") if max_num_models < 0: raise ValueError(f"`max_num_models` must be non-negative. Got {max_num_models}.") if var_clamp < 0: raise ValueError(f"`var_clamp` must be non-negative. Got {var_clamp}.")