
SWAG

class torch_uncertainty.models.SWAG(model, cycle_start, cycle_length, scale=1.0, diag_covariance=False, max_num_models=20, var_clamp=1e-06, num_estimators=16)

Stochastic Weight Averaging Gaussian (SWAG).

Updates the SWAG posterior every cycle_length epochs after cycle_start and samples num_estimators models from the posterior after each update. The SWAG posterior is used only at evaluation time; during training, the base model is used.
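For reference, the Gaussian that SWAG fits can be written as below, in the notation of Maddox et al. (2019): θ_i are the weights collected at the T posterior updates, θ̄_i is the running mean after the i-th update, squares are element-wise, d is the number of parameters, and K is the number of stored deviation vectors (presumably capped by max_num_models). The mapping of these symbols onto this class's attributes, and of scale onto the factor 1/2, is an assumption; with diag_covariance=True only the diagonal term is kept.

```latex
\begin{aligned}
\bar{\theta} &= \frac{1}{T}\sum_{i=1}^{T}\theta_i,
\qquad
\overline{\theta^2} = \frac{1}{T}\sum_{i=1}^{T}\theta_i^2,
\qquad
\Sigma_{\mathrm{diag}} = \operatorname{diag}\!\left(\overline{\theta^2} - \bar{\theta}^2\right),\\
\hat{D} &= \left[D_{T-K+1}, \dots, D_T\right],
\qquad
D_i = \theta_i - \bar{\theta}_i,
\qquad
\Sigma_{\mathrm{low\text{-}rank}} = \frac{1}{K-1}\,\hat{D}\hat{D}^{\top},\\
\theta &\sim \mathcal{N}\!\left(\bar{\theta},\ \tfrac{1}{2}\left(\Sigma_{\mathrm{diag}} + \Sigma_{\mathrm{low\text{-}rank}}\right)\right),\\
\tilde{\theta} &= \bar{\theta} + \frac{1}{\sqrt{2}}\,\Sigma_{\mathrm{diag}}^{1/2} z_1 + \frac{1}{\sqrt{2(K-1)}}\,\hat{D} z_2,
\qquad z_1 \sim \mathcal{N}(0, I_d),\ z_2 \sim \mathcal{N}(0, I_K).
\end{aligned}
```

The two noise terms have covariances ½Σ_diag and ½Σ_low-rank, so a sample θ̃ follows the Gaussian on the third line.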

Call update_wrapper() at the end of each epoch: it updates the SWAG posterior when the current epoch is strictly greater than cycle_start and epoch - cycle_start is a multiple of cycle_length. Call bn_update() to update the batchnorm statistics of the current SWAG samples.
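The schedule can be sketched as a small predicate, assuming the condition given for update_wrapper() (epoch strictly greater than cycle_start). `is_swag_update_epoch` is a hypothetical helper for illustration, not part of torch_uncertainty; whether epochs are counted from 0 or 1 depends on the training loop that calls update_wrapper().

```python
def is_swag_update_epoch(epoch: int, cycle_start: int, cycle_length: int) -> bool:
    """Return True if a SWAG posterior update would happen at ``epoch``.

    Hypothetical helper mirroring the documented rule: the epoch must be
    strictly after ``cycle_start`` and a whole number of ``cycle_length``
    epochs past it.
    """
    return epoch > cycle_start and (epoch - cycle_start) % cycle_length == 0


if __name__ == "__main__":
    # With cycle_start=10 and cycle_length=5, the first update is at epoch 15.
    updates = [e for e in range(31) if is_swag_update_epoch(e, cycle_start=10, cycle_length=5)]
    print(updates)  # prints [15, 20, 25, 30]
```

Because the comparison is strict, no update happens at cycle_start itself, which is why the first update falls at cycle_start + cycle_length.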

Parameters:
  • model (nn.Module) – PyTorch model to be trained.

  • cycle_start (int) – Beginning of the first SWAG averaging cycle.

  • cycle_length (int) – Number of epochs between SWAG updates. The first update occurs at cycle_start + cycle_length.

  • scale (float, optional) – Scale of the Gaussian. Defaults to 1.0.

  • diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to False.

  • max_num_models (int, optional) – Maximum number of models to store. Defaults to 20.

  • var_clamp (float, optional) – Minimum variance. Defaults to 1e-6.

  • num_estimators (int, optional) – Number of posterior estimates to use. Defaults to 16.

Reference:

Maddox, W. J. et al. A Simple Baseline for Bayesian Uncertainty in Deep Learning. In NeurIPS 2019.

bn_update(loader, device)

Update the batchnorm statistics of the current SWAG samples.

Parameters:
  • loader (DataLoader) – DataLoader used to recompute the batchnorm statistics.

  • device (torch.device) – Device on which to perform the update.
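Sampling new weights makes the running batchnorm statistics inherited from training stale, which is why they must be recomputed. A common way to do this, used by `torch.optim.swa_utils.update_bn`, is to reset each batchnorm layer's running statistics, set its `momentum` to `None` so that PyTorch keeps a cumulative (equal-weight) average instead of an exponential one, and pass the loader through the model in training mode. Whether bn_update() follows exactly this procedure is an assumption. The pure-Python sketch below only reproduces the cumulative-average arithmetic for a single running mean; `cumulative_running_mean` is a hypothetical name.

```python
from typing import Iterable


def cumulative_running_mean(batch_means: Iterable[float]) -> float:
    """Mimic a batchnorm running mean updated with momentum=None.

    Hypothetical helper: starting from a reset state (running mean 0, no
    batches tracked), each batch is weighted by 1 / num_batches_tracked,
    so the result is the plain average of the per-batch means.
    """
    running_mean = 0.0  # value right after resetting the running statistics
    num_batches_tracked = 0
    for batch_mean in batch_means:
        num_batches_tracked += 1
        # Same as (1 - f) * running_mean + f * batch_mean with f = 1 / n.
        running_mean += (batch_mean - running_mean) / num_batches_tracked
    return running_mean


if __name__ == "__main__":
    # Four batches whose per-batch means are 1, 2, 3 and 6.
    print(cumulative_running_mean([1.0, 2.0, 3.0, 6.0]))  # prints 3.0
```

Equal weighting matters here because every batch is evaluated with the same sampled weights; an exponential average would overweight the last few batches.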

eval_forward(x)

Forward pass of the SWAG model when in eval mode.
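This entry does not say how the outputs of the num_estimators sampled models are combined. In Maddox et al. (2019), predictions are obtained by Bayesian model averaging, i.e. by averaging the predictive distributions of the sampled networks. The pure-Python sketch below illustrates that averaging for classification logits; `softmax` and `bayesian_model_average` are hypothetical helpers, and this module's eval_forward() may instead return the per-sample outputs and leave the averaging to the caller.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def bayesian_model_average(per_sample_logits: Sequence[Sequence[float]]) -> List[float]:
    """Average the class probabilities predicted by each sampled model.

    Hypothetical helper: ``per_sample_logits[s]`` holds the logits that the
    s-th SWAG sample produces for one input.
    """
    probs = [softmax(logits) for logits in per_sample_logits]
    num_samples = len(probs)
    return [sum(p[c] for p in probs) / num_samples for c in range(len(probs[0]))]


if __name__ == "__main__":
    # Two samples that disagree (0.75/0.25 vs 0.25/0.75): the average is 0.5/0.5.
    logits = [[math.log(3.0), 0.0], [0.0, math.log(3.0)]]
    print(bayesian_model_average(logits))
```

Averaging probabilities rather than logits is what lets disagreement between samples show up as reduced confidence.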

initialize_stats()

Initialize the SWAG dictionary of statistics.

For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when diag_covariance is False.
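A minimal pure-Python sketch of these statistics, following the update rule of Maddox et al. (2019): a running mean and squared mean, plus a buffer of at most max_num_models deviation vectors that play the role of the covariance square root. The class name `SwagMoments`, its fields, and applying var_clamp as a floor on the diagonal variance are assumptions rather than this module's actual attributes; parameters are flattened into plain lists of floats for readability.

```python
from collections import deque
from typing import Deque, List, Sequence


class SwagMoments:
    """Hypothetical running SWAG statistics for one flattened parameter vector."""

    def __init__(self, num_params: int, max_num_models: int = 20,
                 diag_covariance: bool = False, var_clamp: float = 1e-6) -> None:
        self.n_models = 0
        self.mean: List[float] = [0.0] * num_params
        self.sq_mean: List[float] = [0.0] * num_params
        self.diag_covariance = diag_covariance
        self.var_clamp = var_clamp
        # Deviation vectors (rows of the covariance square root); once the buffer
        # holds max_num_models entries, appending drops the oldest one.
        self.cov_sqrt: Deque[List[float]] = deque(maxlen=max_num_models)

    def collect(self, params: Sequence[float]) -> None:
        """Fold one snapshot of the weights into the running statistics."""
        n = self.n_models
        self.mean = [(n * m + p) / (n + 1) for m, p in zip(self.mean, params)]
        self.sq_mean = [(n * s + p * p) / (n + 1) for s, p in zip(self.sq_mean, params)]
        if not self.diag_covariance:
            # Deviation from the *updated* running mean, as in the paper's D_i.
            self.cov_sqrt.append([p - m for p, m in zip(params, self.mean)])
        self.n_models += 1

    def variance(self) -> List[float]:
        """Diagonal variance E[theta^2] - E[theta]^2, floored at var_clamp."""
        return [max(s - m * m, self.var_clamp) for s, m in zip(self.sq_mean, self.mean)]


if __name__ == "__main__":
    stats = SwagMoments(num_params=2, max_num_models=2)
    for snapshot in ([1.0, 0.0], [3.0, 0.0], [5.0, 0.0]):
        stats.collect(snapshot)
    print(stats.mean)           # prints [3.0, 0.0]
    print(list(stats.cov_sqrt))  # prints [[1.0, 0.0], [2.0, 0.0]]
```

The floor matters for weights that barely move between snapshots: without it, rounding can make the variance zero or slightly negative, and its square root would fail when sampling.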

sample(scale, diag_covariance=None, block=False, seed=None)

Sample a model from the SWAG posterior.

Parameters:
  • scale (float) – Rescaling coefficient of the Gaussian.

  • diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to None.

  • block (bool, optional) – Whether to sample from a block-diagonal covariance. Defaults to False.

  • seed (int, optional) – Random seed. Defaults to None.

Returns:

Sampled model.

Return type:

nn.Module
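The sampling step can be sketched as below, following the formula of Maddox et al. (2019) with K taken as the number of stored deviation vectors. We assume, as in the paper's reference code, that scale multiplies the covariance, so the noise is scaled by sqrt(scale) and scale=0.5 reproduces the paper's factor 1/2. The function `sample_swag_weights` and its list-based inputs are hypothetical, and the block option is not modelled.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_swag_weights(
    mean: Sequence[float],
    variance: Sequence[float],
    deviations: Sequence[Sequence[float]],
    scale: float = 0.5,
    diag_covariance: bool = False,
    seed: Optional[int] = None,
) -> List[float]:
    """Draw one flattened weight vector from a SWAG Gaussian.

    Hypothetical helper. ``variance`` is the (already clamped) diagonal
    variance and ``deviations`` holds the K deviation vectors D_k.
    """
    rng = random.Random(seed)  # a fixed seed makes the draw reproducible
    # Diagonal part: sqrt(var) * z1 with z1 ~ N(0, I_d).
    noise = [math.sqrt(v) * rng.gauss(0.0, 1.0) for v in variance]
    num_dev = len(deviations)
    if not diag_covariance and num_dev > 1:
        # Low-rank part: sum_k z2[k] * D_k / sqrt(K - 1) with z2 ~ N(0, I_K).
        z2 = [rng.gauss(0.0, 1.0) for _ in range(num_dev)]
        for k, dev in enumerate(deviations):
            coeff = z2[k] / math.sqrt(num_dev - 1)
            noise = [n + coeff * d for n, d in zip(noise, dev)]
    # Scaling the covariance by `scale` scales the noise by sqrt(scale).
    return [m + math.sqrt(scale) * n for m, n in zip(mean, noise)]


if __name__ == "__main__":
    mean = [1.0, -2.0, 0.5]
    variance = [0.1, 0.2, 0.3]
    deviations = [[0.1, 0.0, -0.1], [0.0, 0.2, 0.1]]
    print(sample_swag_weights(mean, variance, deviations, seed=0))
```

With diag_covariance=True, or with fewer than two stored deviations, only the diagonal term contributes; with scale=0 the sample collapses onto the mean.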

state_dict(*args, destination=None, prefix='', keep_vars=False)

Add the SWAG statistics to the state dict.

update_wrapper(epoch)

Update the SWAG posterior.

The update is performed if the epoch is greater than the cycle start and the difference between the epoch and the cycle start is a multiple of the cycle length.

Parameters:
  • epoch (int) – Current epoch.