SWAG

class torch_uncertainty.models.SWAG(model, cycle_start, cycle_length, scale=1.0, diag_covariance=False, max_num_models=20, var_clamp=1e-06, num_estimators=16)

Stochastic Weight Averaging Gaussian (SWAG) [1].

Update the SWAG posterior every cycle_length epochs, starting from cycle_start, and sample num_estimators models from the posterior after each update. The SWAG posterior is used only at test time; during training, the base model is used.
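For reference, the Gaussian approximation described in [1] can be written as below. Here θ_i are the weights collected at the T posterior updates, θ̄_i is the running mean after the i-th update, squares are element-wise, and D̂ keeps the last K deviation columns. How this class maps scale and max_num_models onto these symbols is an assumption (presumably K is bounded by max_num_models).

```latex
\theta_{\mathrm{SWA}} = \frac{1}{T}\sum_{i=1}^{T}\theta_i, \qquad
\overline{\theta^2} = \frac{1}{T}\sum_{i=1}^{T}\theta_i^2, \qquad
\Sigma_{\mathrm{diag}} = \operatorname{diag}\!\left(\overline{\theta^2} - \theta_{\mathrm{SWA}}^2\right)

\widehat{D} = \left[\theta_{T-K+1} - \bar{\theta}_{T-K+1}, \;\dots,\; \theta_T - \bar{\theta}_T\right]

p(\theta \mid \text{data}) \approx \mathcal{N}\!\left(\theta_{\mathrm{SWA}},\;
  \tfrac{1}{2}\left(\Sigma_{\mathrm{diag}} + \frac{\widehat{D}\,\widehat{D}^{\top}}{K-1}\right)\right)
```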

Call update_wrapper() at the end of each epoch: it updates the SWAG posterior when the current epoch is greater than cycle_start and the difference between the two is a multiple of cycle_length. Call bn_update() to recompute the batch-norm statistics of the current SWAG samples.
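The update rule can be checked in isolation. The helper below, swag_update_epochs, is hypothetical and not part of torch_uncertainty; it only enumerates the epochs that satisfy the documented condition, assuming epochs are counted from 0.

```python
def swag_update_epochs(cycle_start: int, cycle_length: int, num_epochs: int) -> list[int]:
    """Hypothetical helper: epochs at which the documented rule triggers an update."""
    return [
        epoch
        for epoch in range(num_epochs)
        # Only strictly after cycle_start, then once every cycle_length epochs.
        if epoch > cycle_start and (epoch - cycle_start) % cycle_length == 0
    ]


print(swag_update_epochs(cycle_start=10, cycle_length=5, num_epochs=30))  # [15, 20, 25]
```

With these settings the first update lands at cycle_start + cycle_length, as stated for the cycle_length parameter.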

Parameters:
  • model (nn.Module) – PyTorch model to be trained.

  • cycle_start (int) – Beginning of the first SWAG averaging cycle.

  • cycle_length (int) – Number of epochs between SWAG updates. The first update occurs at cycle_start + cycle_length.

  • scale (float, optional) – Scale of the Gaussian. Defaults to 1.0.

  • diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to False.

  • max_num_models (int, optional) – Maximum number of models to store. Defaults to 20.

  • var_clamp (float, optional) – Minimum variance. Defaults to 1e-6.

  • num_estimators (int, optional) – Number of models sampled from the SWAG posterior. Defaults to 16.

References

[1] A simple baseline for Bayesian uncertainty in deep learning. In NeurIPS 2019.

Note

Modified from wjmaddox/swa_gaussian.

bn_update(loader, device)

Update the batch-norm statistics of the current SWAG samples.

Parameters:
  • loader (DataLoader) – DataLoader used to recompute the batch-norm statistics.

  • device (torch.device) – Device on which to perform the update.
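Sampled weights produce different activations than the weights the stored running statistics were computed with, which is why those statistics are recomputed with a pass over the loader. The sketch below assumes the statistics are reset and then re-estimated as a size-weighted cumulative average (momentum b / (n + b) per batch); cumulative_bn_mean is a hypothetical helper that tracks only the running mean of one feature, not the library's code.

```python
def cumulative_bn_mean(batches: list[list[float]]) -> float:
    """Hypothetical: re-estimate a batch-norm running mean after a reset."""
    running_mean = 0.0  # statistics are reset before the pass
    seen = 0
    for batch in batches:
        b = len(batch)
        momentum = b / (seen + b)  # weight of this batch in the cumulative average
        batch_mean = sum(batch) / b
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
        seen += b
    return running_mean


batches = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
print(round(cumulative_bn_mean(batches), 6))  # 3.5, the mean of all six values
```

Unlike a fixed exponential momentum, this weighting makes the result equal to the mean over the whole pass, regardless of batch sizes.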

eval_forward(x)

Forward pass of the SWAG model when in eval mode.

initialize_stats()

Initialize the SWAG dictionary of statistics.

For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when diag_covariance is False.
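A minimal sketch of how these three statistics can be maintained for one flattened parameter, assuming the running-moment update of the original SWAG code: the mean and squared mean are updated first, then the deviation from the new mean is appended, keeping only the latest max_num_models rows. SwagStats and its method names are hypothetical; the library stores tensors per parameter rather than Python lists.

```python
from collections import deque


class SwagStats:
    """Hypothetical per-parameter SWAG statistics over a flat list of floats."""

    def __init__(self, dim: int, max_num_models: int = 20, diag_covariance: bool = False):
        self.mean = [0.0] * dim
        self.sq_mean = [0.0] * dim
        self.num_models = 0
        self.diag_covariance = diag_covariance
        # Covariance square root: one deviation row per collected model, oldest dropped first.
        self.cov_mat_sqrt = deque(maxlen=max_num_models)

    def collect(self, theta: list[float]) -> None:
        n = self.num_models
        self.mean = [(n * m + t) / (n + 1) for m, t in zip(self.mean, theta)]
        self.sq_mean = [(n * s + t * t) / (n + 1) for s, t in zip(self.sq_mean, theta)]
        if not self.diag_covariance:
            # Deviation from the updated running mean.
            self.cov_mat_sqrt.append([t - m for t, m in zip(theta, self.mean)])
        self.num_models += 1

    def variance(self, var_clamp: float = 1e-6) -> list[float]:
        # Diagonal variance E[theta^2] - E[theta]^2, clamped from below by var_clamp.
        return [max(s - m * m, var_clamp) for s, m in zip(self.sq_mean, self.mean)]


stats = SwagStats(dim=2, max_num_models=2)
for theta in ([1.0, 2.0], [3.0, 2.0], [5.0, 2.0]):
    stats.collect(theta)
print(stats.mean)                # [3.0, 2.0]
print(list(stats.cov_mat_sqrt))  # [[1.0, 0.0], [2.0, 0.0]]
```

The second coordinate never moves, so its raw variance is 0 and var_clamp keeps it strictly positive for sampling.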

sample(scale, diag_covariance=None, block=False, seed=None)

Sample a model from the SWAG posterior.

Parameters:
  • scale (float) – Rescale coefficient of the Gaussian.

  • diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to None.

  • block (bool, optional) – Whether to sample a block diagonal covariance. Defaults to False.

  • seed (int, optional) – Random seed. Defaults to None.

Returns:

Sampled model.

Return type:

nn.Module

state_dict(*args, destination=None, prefix='', keep_vars=False)

Add the SWAG statistics to the state dict.

update_wrapper(epoch)

Update the SWAG posterior.

The update is performed if the epoch is greater than the cycle start and the difference between the epoch and the cycle start is a multiple of the cycle length.

Parameters:

epoch (int) – Current epoch.