SWA
class torch_uncertainty.models.SWA(model, cycle_start, cycle_length)
Stochastic Weight Averaging.
Updates the SWA model every cycle_length epochs, starting at epoch cycle_start. The SWA model is used only at test time; the base model is used for training.
Parameters:
model (nn.Module) – PyTorch model to be trained.
cycle_start (int) – Epoch to start SWA.
cycle_length (int) – Number of epochs between SWA updates.
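The schedule set by cycle_start and cycle_length can be illustrated with a minimal, framework-free sketch. It is not the library's implementation. The helper names is_swa_update_epoch and update_running_average are hypothetical, weights are plain floats rather than tensors, and the sketch assumes that the first update happens at epoch cycle_start itself and that the SWA weights are an equal-weight running mean of the snapshots.

```python
def is_swa_update_epoch(epoch, cycle_start, cycle_length):
    """Return True when `epoch` falls on the assumed SWA update schedule."""
    return epoch >= cycle_start and (epoch - cycle_start) % cycle_length == 0


def update_running_average(avg, weights, num_averaged):
    """Fold `weights` into `avg`, the mean of `num_averaged` earlier snapshots."""
    if avg is None:
        # First snapshot: the average is just a copy of the current weights.
        return dict(weights), 1
    new_avg = {
        name: avg[name] + (weights[name] - avg[name]) / (num_averaged + 1)
        for name in avg
    }
    return new_avg, num_averaged + 1


# Toy "training" loop: the single weight w simply equals the epoch index.
cycle_start, cycle_length = 2, 2
swa_weights, num_averaged = None, 0
update_epochs = []
for epoch in range(7):
    weights = {"w": float(epoch)}  # stand-in for the base model after this epoch
    if is_swa_update_epoch(epoch, cycle_start, cycle_length):
        swa_weights, num_averaged = update_running_average(
            swa_weights, weights, num_averaged
        )
        update_epochs.append(epoch)

print(update_epochs)  # [2, 4, 6]
print(swa_weights)    # {'w': 4.0}
```

With cycle_start=2 and cycle_length=2, snapshots are taken at epochs 2, 4 and 6, and the averaged weight is their mean. In this sketch the base weights are never overwritten by the average, which mirrors the description above: training continues on the base model and the averaged copy is only read at test time.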
References
[1] Averaging Weights Leads to Wider Optima and Better Generalization. In UAI 2018.