SWAG
- class torch_uncertainty.models.SWAG(model, cycle_start, cycle_length, scale=1.0, diag_covariance=False, max_num_models=20, var_clamp=1e-06, num_estimators=16)
Stochastic Weight Averaging Gaussian (SWAG).
Update the SWAG posterior every cycle_length epochs starting at cycle_start. After each update, num_estimators models are sampled from the SWAG posterior. The SWAG posterior estimate is used only at test time; during training, the base model is used.
Call update_wrapper() at the end of each epoch. It updates the SWAG posterior if the current epoch number minus cycle_start is a multiple of cycle_length. Call bn_update() to update the batchnorm statistics of the current SWAG samples.
- Parameters:
model (nn.Module) – PyTorch model to be trained.
cycle_start (int) – Beginning of the first SWAG averaging cycle.
cycle_length (int) – Number of epochs between SWAG updates. The first update occurs at cycle_start + cycle_length.
scale (float, optional) – Scale of the Gaussian. Defaults to 1.0.
diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to False.
max_num_models (int, optional) – Maximum number of models to store. Defaults to 20.
var_clamp (float, optional) – Minimum variance. Defaults to 1e-06.
num_estimators (int, optional) – Number of posterior estimates to use. Defaults to 16.
- Reference:
Maddox, W. J. et al. A simple baseline for Bayesian uncertainty in deep learning. In NeurIPS 2019.
Note
Originates from https://github.com/wjmaddox/swa_gaussian.
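The update schedule described for update_wrapper() can be illustrated with a small standalone sketch. It reads the two statements above together (an update happens when the epoch minus cycle_start is a multiple of cycle_length, and the first update occurs at cycle_start + cycle_length). The helper name should_update is hypothetical and is not part of the library; the exact condition used internally may differ.

```python
def should_update(epoch: int, cycle_start: int, cycle_length: int) -> bool:
    """Return True when the SWAG posterior would be updated at this epoch.

    Hypothetical helper: a reading of the documented schedule, not library code.
    """
    elapsed = epoch - cycle_start
    # Strictly positive, so the first update lands at cycle_start + cycle_length.
    return elapsed > 0 and elapsed % cycle_length == 0


# Epochs in [0, 20) that trigger an update for cycle_start=10, cycle_length=3.
update_epochs = [e for e in range(20) if should_update(e, cycle_start=10, cycle_length=3)]
print(update_epochs)  # [13, 16, 19]
```

Under this reading, no update happens before or at cycle_start itself, so the base model trains undisturbed for the first cycle_start epochs.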
- bn_update(loader, device)
Update the batchnorm statistics of the current SWAG samples.
- Parameters:
loader (DataLoader) – DataLoader to update the batchnorm statistics.
device (torch.device) – Device to perform the update.
- initialize_stats()
Initialize the SWAG dictionary of statistics.
For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when diag_covariance is False.
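To make the three statistics concrete, here is a minimal pure-Python sketch of how a running mean, running squared mean, and a bounded set of deviations can be maintained, following the general recipe in the SWAG paper. The class RunningMoments and its attribute names are hypothetical; the library stores these per parameter as tensors, and its internal layout may differ.

```python
from collections import deque


class RunningMoments:
    """Toy SWAG-style statistics for a flat list of weights (illustrative only)."""

    def __init__(self, num_params, max_num_models=20, var_clamp=1e-6):
        self.mean = [0.0] * num_params
        self.sq_mean = [0.0] * num_params
        # Deviations from the running mean; the deque drops the oldest when full.
        self.deviations = deque(maxlen=max_num_models)
        self.num_avg = 0
        self.var_clamp = var_clamp

    def update(self, weights):
        """Fold one weight snapshot into the running first and second moments."""
        n = self.num_avg
        self.mean = [(m * n + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(s * n + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        self.deviations.append([w - m for w, m in zip(weights, self.mean)])
        self.num_avg = n + 1

    def variance(self):
        """Diagonal variance E[w^2] - E[w]^2, clamped from below by var_clamp."""
        return [max(s - m * m, self.var_clamp) for s, m in zip(self.sq_mean, self.mean)]


stats = RunningMoments(num_params=2, max_num_models=2)
for snapshot in ([1.0, 0.0], [3.0, 0.0], [5.0, 0.0]):
    stats.update(snapshot)
```

In this sketch the stored deviations play the role of the covariance square root, and max_num_models bounds how many of them are kept; the clamp keeps the second weight's variance at var_clamp rather than zero.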
- sample(scale, diag_covariance=None, block=False, seed=None)
Sample a model from the SWAG posterior.
- Parameters:
scale (float) – Rescale coefficient of the Gaussian.
diag_covariance (bool, optional) – Whether to use a diagonal covariance. Defaults to None.
block (bool, optional) – Whether to sample a block diagonal covariance. Defaults to False.
seed (int, optional) – Random seed. Defaults to None.
- Returns:
Sampled model.
- Return type:
nn.Module
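As a rough illustration of what sampling from a SWAG-style Gaussian involves, the sketch below draws a flat weight vector from a mean, a diagonal variance, and a set of stored deviations, in the form used by the SWAG paper. It is an assumption-laden toy: sample_weights is a hypothetical function, it returns a list of floats rather than an nn.Module, it ignores the block option, and the library's exact scaling may differ.

```python
import math
import random


def sample_weights(mean, var, deviations, scale=1.0, diag_covariance=False, seed=None):
    """Draw one weight vector: mean + sqrt(scale) * (diagonal + low-rank noise)."""
    rng = random.Random(seed)  # local generator, so a given seed is reproducible
    # Diagonal part: independent Gaussian noise scaled by each weight's std.
    noise = [math.sqrt(v) * rng.gauss(0.0, 1.0) for v in var]
    if not diag_covariance and len(deviations) > 1:
        k = len(deviations)
        coeffs = [rng.gauss(0.0, 1.0) for _ in range(k)]
        for i in range(len(mean)):
            # Low-rank part: random combination of the stored deviations.
            low_rank = sum(c * dev[i] for c, dev in zip(coeffs, deviations))
            noise[i] += low_rank / math.sqrt(k - 1)
    return [m + math.sqrt(scale) * n for m, n in zip(mean, noise)]


mean = [3.0, 0.0]
var = [2.0, 1e-6]
deviations = [[0.0, 0.0], [2.0, 0.0]]
a = sample_weights(mean, var, deviations, scale=0.5, seed=0)
b = sample_weights(mean, var, deviations, scale=0.5, seed=0)
```

With the same seed the two draws are identical, and with scale set to 0 the sample collapses to the mean, which is a quick sanity check on the role of the scale argument.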