SWAG
- class torch_uncertainty.models.SWAG(model, cycle_start, cycle_length, scale=1.0, diag_covariance=False, max_num_models=20, var_clamp=1e-06, num_estimators=16)
Stochastic Weight Averaging Gaussian (SWAG).
Updates the SWAG posterior every cycle_length epochs starting at cycle_start, and samples num_estimators models from the SWAG posterior after each update. The SWAG posterior estimate is used only at test time; during training, the base model is used.
Call update_wrapper() at the end of each epoch. It will update the SWAG posterior if the current epoch number minus cycle_start is a multiple of cycle_length. Call bn_update() to update the batchnorm statistics of the current SWAG samples.
- Parameters:
model (nn.Module) – PyTorch model to be trained.
cycle_start (int) – Beginning of the first SWAG averaging cycle.
cycle_length (int) – Number of epochs between SWAG updates. The first update occurs at cycle_start + cycle_length.
scale (float, optional) – Scale of the Gaussian. Defaults to 1.0.
diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to False.
max_num_models (int, optional) – Maximum number of models to store. Defaults to 20.
var_clamp (float, optional) – Minimum variance. Defaults to 1e-06.
num_estimators (int, optional) – Number of posterior estimates to use. Defaults to 16.
References
[1] A simple baseline for Bayesian uncertainty in deep learning. In NeurIPS 2019.
Note
Modified from wjmaddox/swa_gaussian.
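The update schedule described above can be sketched in plain Python. This is a minimal illustration, not the library's code: the helper name is_swag_update_epoch is hypothetical, and the strict comparison against cycle_start is an assumption chosen so that the first update falls at cycle_start + cycle_length, as stated for the cycle_length parameter.

```python
def is_swag_update_epoch(epoch: int, cycle_start: int, cycle_length: int) -> bool:
    """Return True if `epoch` would trigger a SWAG posterior update.

    Hypothetical helper, not part of torch_uncertainty. Assumes updates
    happen strictly after `cycle_start`, whenever the offset from
    `cycle_start` is a multiple of `cycle_length`.
    """
    if epoch <= cycle_start:
        return False
    return (epoch - cycle_start) % cycle_length == 0


# Epochs in a 12-epoch run that would trigger an update
# with cycle_start=2 and cycle_length=3.
update_epochs = [e for e in range(12) if is_swag_update_epoch(e, 2, 3)]
```

Under these assumptions, a training loop would call update_wrapper() after every epoch and let the wrapper decide, with a rule of this kind, whether the posterior changes.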
- bn_update(loader, device)
Update the batchnorm statistics of the current SWAG samples.
- Parameters:
loader (DataLoader) – DataLoader to update the batchnorm statistics.
device (torch.device) – Device to perform the update.
- initialize_stats()
Initialize the SWAG dictionary of statistics.
For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when diag_covariance is False.
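To make the three statistics concrete, here is a small pure-Python sketch of how a running mean, a running squared mean, and a bounded buffer of deviations (the columns of the covariance square root) could be maintained for one flattened parameter vector. The class name RunningSwagStats and its update method are hypothetical and only illustrate the idea; the library stores these statistics per parameter as tensors, and its exact update rule may differ.

```python
from collections import deque


class RunningSwagStats:
    """Illustrative SWAG statistics for a flat parameter vector (hypothetical)."""

    def __init__(self, num_params, max_num_models=20, diag_covariance=False):
        self.mean = [0.0] * num_params      # running first moment
        self.sq_mean = [0.0] * num_params   # running second moment
        # The deviation buffer is only needed for the low-rank covariance term;
        # it keeps at most `max_num_models` entries and drops the oldest first.
        self.deviations = None if diag_covariance else deque(maxlen=max_num_models)
        self.num_models = 0

    def update(self, params):
        n = self.num_models
        self.mean = [(m * n + p) / (n + 1) for m, p in zip(self.mean, params)]
        self.sq_mean = [(s * n + p * p) / (n + 1) for s, p in zip(self.sq_mean, params)]
        if self.deviations is not None:
            # Deviation of the new weights from the updated running mean.
            self.deviations.append([p - m for p, m in zip(params, self.mean)])
        self.num_models = n + 1


stats = RunningSwagStats(num_params=2, max_num_models=2)
for weights in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    stats.update(weights)
```

With diag_covariance=True the deviation buffer is never allocated, which mirrors the statement that the covariance square root is only used when diag_covariance is False.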
- sample(scale, diag_covariance=None, block=False, seed=None)
Sample a model from the SWAG posterior.
- Parameters:
scale (float) – Rescale coefficient of the Gaussian.
diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to None.
block (bool, optional) – Whether to sample a block diagonal covariance. Defaults to False.
seed (int, optional) – Random seed. Defaults to None.
- Returns:
Sampled model.
- Return type:
nn.Module
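As a rough illustration of what sampling from a SWAG posterior involves, the sketch below draws one flat weight vector from a Gaussian with a clamped diagonal variance and an optional low-rank term. It follows the general recipe of the reference wjmaddox/swa_gaussian implementation as an assumption, not this library's exact code: the function sample_swag_weights and its arguments are hypothetical, and the real method returns an nn.Module rather than a list of floats.

```python
import math
import random


def sample_swag_weights(mean, sq_mean, deviations=None, scale=1.0,
                        var_clamp=1e-6, seed=None):
    """Draw one weight vector from a SWAG-style Gaussian (hypothetical sketch)."""
    rng = random.Random(seed)
    # Diagonal variance E[w^2] - E[w]^2, clamped from below by var_clamp.
    var = [max(s - m * m, var_clamp) for m, s in zip(mean, sq_mean)]
    noise = [math.sqrt(v) * rng.gauss(0.0, 1.0) for v in var]
    if deviations:
        # Low-rank term: a random combination of the stored deviations,
        # normalized by sqrt(K - 1) where K is the number of deviations.
        k = len(deviations)
        z = [rng.gauss(0.0, 1.0) for _ in range(k)]
        norm = math.sqrt(max(k - 1, 1))
        for i in range(len(mean)):
            noise[i] += sum(z[j] * deviations[j][i] for j in range(k)) / norm
    # `scale` multiplies the covariance, so the noise is scaled by sqrt(scale).
    return [m + math.sqrt(scale) * n for m, n in zip(mean, noise)]


mean = [1.0, -2.0]
sq_mean = [1.5, 4.5]
deviations = [[0.1, 0.2], [-0.1, 0.0]]
first = sample_swag_weights(mean, sq_mean, deviations, scale=0.5, seed=0)
second = sample_swag_weights(mean, sq_mean, deviations, scale=0.5, seed=0)
at_mean = sample_swag_weights(mean, sq_mean, deviations, scale=0.0, seed=1)
```

In this sketch a fixed seed makes the draw reproducible, and a scale of zero collapses the sample onto the mean, which is one way to read the role of the scale argument.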