SWAG
- class torch_uncertainty.methods.SWAG(core_model, cycle_start, cycle_length, scale=1.0, diag_covariance=False, max_num_models=20, var_clamp=1e-06, num_estimators=16)
Stochastic Weight Averaging Gaussian (SWAG).
Updates the SWAG posterior every cycle_length epochs starting at cycle_start, and samples num_estimators models from the SWAG posterior after each update. The SWAG posterior estimate is used only at test time; during training, the base model is used.
Call update_wrapper() at the end of each epoch: it updates the SWAG posterior if the current epoch number minus cycle_start is a multiple of cycle_length. Call bn_update() to update the batchnorm statistics of the current SWAG samples.
- Parameters:
  - core_model (Module) – PyTorch model to be trained.
  - cycle_start (int) – Beginning of the first SWAG averaging cycle.
  - cycle_length (int) – Number of epochs between SWAG updates. The first update occurs at cycle_start + cycle_length.
  - scale (float) – Scale of the Gaussian. Defaults to 1.0.
  - diag_covariance (bool) – Whether to use a diagonal covariance. Defaults to False.
  - max_num_models (int) – Maximum number of models to store. Defaults to 20.
  - var_clamp (float) – Minimum variance. Defaults to 1e-06.
  - num_estimators (int) – Number of posterior estimates to use. Defaults to 16.
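As a rough illustration of the schedule described above, the plain-Python sketch below lists the epochs at which update_wrapper() would update the posterior. The helper swag_update_epochs is hypothetical (it is not part of torch_uncertainty), counts epochs from 0, and assumes, following the cycle_length entry, that epoch cycle_start itself does not trigger an update, so the first update lands at cycle_start + cycle_length.

```python
from __future__ import annotations


def swag_update_epochs(num_epochs: int, cycle_start: int, cycle_length: int) -> list[int]:
    """Epochs (0-indexed) at which a SWAG posterior update would fire.

    Assumption: the first update happens at cycle_start + cycle_length,
    so epoch == cycle_start is excluded even though 0 is a multiple.
    """
    return [
        epoch
        for epoch in range(num_epochs)
        if epoch > cycle_start and (epoch - cycle_start) % cycle_length == 0
    ]


# 30 epochs, averaging starts at epoch 10, one update every 5 epochs.
print(swag_update_epochs(num_epochs=30, cycle_start=10, cycle_length=5))
# → [15, 20, 25]
```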
References
[1] A Simple Baseline for Bayesian Uncertainty in Deep Learning. In NeurIPS 2019.
Note
Modified from wjmaddox/swa_gaussian.
- bn_update(loader, device)
Update the batchnorm statistics of the current SWAG samples.
- Parameters:
  - loader (DataLoader) – DataLoader used to update the batchnorm statistics.
  - device (device | str | int | None) – Device on which to perform the update.
- Return type:
None
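Each SWAG sample has new weights, so the BatchNorm running statistics collected for the base model no longer match the activations a sampled model produces; they have to be re-estimated with forward passes over training data. The plain-Python sketch below shows, for a single feature, one common way to do this: a simple cumulative average of per-batch means and unbiased variances, which is what PyTorch BatchNorm layers do when momentum=None. It is an illustration of the idea, not torch_uncertainty's implementation, and recompute_bn_stats is a hypothetical name.

```python
from __future__ import annotations

from statistics import fmean, variance


def recompute_bn_stats(batches: list[list[float]]) -> tuple[float, float]:
    """Re-estimate the running mean/variance of one BatchNorm feature.

    Mirrors BatchNorm with momentum=None: a cumulative average of the
    per-batch statistics, starting from the reset values mean=0, var=1.
    """
    running_mean, running_var = 0.0, 1.0  # values after a statistics reset
    for t, batch in enumerate(batches, start=1):
        batch_mean = fmean(batch)
        batch_var = variance(batch)  # unbiased estimate; needs >= 2 values
        # Cumulative average: the t-th batch gets weight 1 / t.
        running_mean += (batch_mean - running_mean) / t
        running_var += (batch_var - running_var) / t
    return running_mean, running_var


# Two batches of activations for a single channel.
print(recompute_bn_stats([[1.0, 3.0], [5.0, 7.0]]))
# → (4.0, 2.0)
```

Because the first batch gets weight 1, the reset values are fully overwritten, so the result depends only on the data seen during the update pass.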
- initialize_stats()
Initialize the SWAG dictionary of statistics.
For each parameter, we create a mean, squared mean, and covariance square root. The covariance square root is only used when diag_covariance is False.
- Return type:
None
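To show how these statistics evolve once a model is collected, the following plain-Python sketch applies the update rule of the SWAG paper [1] and of wjmaddox/swa_gaussian to a single flattened weight vector: running averages of the weights and of their squares, plus a bounded list of deviation vectors that forms the low-rank covariance square root. The class SwagStats and its attribute names are hypothetical; per the description above, torch_uncertainty keeps the equivalent statistics for each parameter.

```python
from __future__ import annotations


class SwagStats:
    """Running SWAG moments for one flattened weight vector (plain Python)."""

    def __init__(self, dim: int, max_num_models: int, diag_covariance: bool = False):
        self.n = 0  # number of models collected so far
        self.mean = [0.0] * dim
        self.sq_mean = [0.0] * dim
        # Rows of the covariance square root D: one deviation vector per model.
        self.cov_mat_sqrt: list[list[float]] = []
        self.max_num_models = max_num_models
        self.diag_covariance = diag_covariance

    def update(self, weights: list[float]) -> None:
        n = self.n
        # Running averages of the weights and of their element-wise squares.
        self.mean = [(n * m + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(n * s + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        if not self.diag_covariance:
            # Deviation of the new weights from the updated mean.
            self.cov_mat_sqrt.append([w - m for w, m in zip(weights, self.mean)])
            # Keep only the most recent max_num_models deviations.
            if len(self.cov_mat_sqrt) > self.max_num_models:
                self.cov_mat_sqrt.pop(0)
        self.n += 1


stats = SwagStats(dim=2, max_num_models=2)
for w in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    stats.update(w)
print(stats.mean)          # → [3.0, 4.0]
print(stats.cov_mat_sqrt)  # → [[1.0, 1.0], [2.0, 2.0]] (oldest deviation dropped)
```

The diagonal variance is later recovered as sq_mean - mean², which is why only the first two moments need to be stored when diag_covariance is True.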
- sample(scale, diag_covariance=None, block=False, seed=None)
Sample a model from the SWAG posterior.
- Parameters:
  - scale (float) – Rescale coefficient of the Gaussian.
  - diag_covariance (bool | None) – Whether to use a diagonal covariance. Defaults to None.
  - block (bool) – Whether to sample from a block-diagonal covariance. Defaults to False.
  - seed (int | None) – Random seed. Defaults to None.
- Returns:
Sampled model.
- Return type:
nn.Module
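For intuition about what a draw from the SWAG posterior looks like, the plain-Python sketch below writes it out in the form used by wjmaddox/swa_gaussian, which the Note says this class is adapted from: a diagonal term built from the clamped variance sq_mean - mean² and, unless diag_covariance is set, a low-rank term built from the stored deviations, both multiplied by √scale. This is an assumption about the sampling rule, not a transcription of torch_uncertainty's code: dividing the low-rank term by √(K − 1), with K the number of stored deviations, follows the SWAG paper [1], and the block option is not modelled. swag_sample is a hypothetical helper.

```python
from __future__ import annotations

import math
import random


def swag_sample(
    mean: list[float],
    sq_mean: list[float],
    cov_mat_sqrt: list[list[float]],
    scale: float = 1.0,
    diag_covariance: bool = False,
    var_clamp: float = 1e-6,
    rng: random.Random | None = None,
) -> list[float]:
    """Draw one flattened weight vector from the SWAG Gaussian.

    theta = mean + sqrt(scale) * (sqrt(var) * z1 + D^T z2 / sqrt(K - 1)),
    with z1 ~ N(0, I_d), z2 ~ N(0, I_K) and D the K stored deviation rows.
    """
    if rng is None:
        rng = random.Random()
    # Diagonal part: variance from the first two moments, clamped from below.
    var = [max(s - m * m, var_clamp) for m, s in zip(mean, sq_mean)]
    noise = [math.sqrt(v) * rng.gauss(0.0, 1.0) for v in var]
    k = len(cov_mat_sqrt)
    if not diag_covariance and k >= 2:
        # Low-rank part: a random combination of the stored deviations.
        z2 = [rng.gauss(0.0, 1.0) for _ in range(k)]
        for row, z in zip(cov_mat_sqrt, z2):
            for i, d in enumerate(row):
                noise[i] += d * z / math.sqrt(k - 1)
    return [m + math.sqrt(scale) * e for m, e in zip(mean, noise)]


rng = random.Random(0)  # seeded for reproducible draws
theta = swag_sample([0.0, 1.0], [0.25, 1.25], [[0.1, -0.2], [-0.1, 0.2]], scale=0.5, rng=rng)
print(len(theta))  # → 2
```

With scale = 0.5 this expression coincides with the paper's formula, which carries explicit 1/√2 factors on both terms; setting scale = 0 collapses every draw onto the SWA mean.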