SWA
- class torch_uncertainty.models.SWA(model, cycle_start, cycle_length)
Stochastic Weight Averaging.
Updates the SWA model every cycle_length epochs, starting at epoch cycle_start. The SWA model is used only at test time; the base model is used for training.
- Parameters:
model (nn.Module) – PyTorch model to be trained.
cycle_start (int) – Epoch to start SWA.
cycle_length (int) – Number of epochs between SWA updates.
- Reference:
Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. (2018). Averaging Weights Leads to Wider Optima and Better Generalization. In UAI 2018.
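- Example:
The update schedule and the averaging step described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper names (should_update, update_average, run_swa) are hypothetical, weights are modelled as flat lists of floats rather than nn.Module parameters, and the sketch assumes 0-indexed epochs with the first update at cycle_start itself. The averaging rule is the equal-weight running mean from Izmailov et al., w_SWA <- (n * w_SWA + w) / (n + 1), where n is the number of weight snapshots averaged so far.

```python
def should_update(epoch: int, cycle_start: int, cycle_length: int) -> bool:
    """Assumed schedule: update at cycle_start, then every cycle_length epochs."""
    return epoch >= cycle_start and (epoch - cycle_start) % cycle_length == 0


def update_average(avg: list[float], weights: list[float], num_avgd: int) -> list[float]:
    """Equal-weight running mean: avg <- (avg * n + w) / (n + 1)."""
    return [(a * num_avgd + w) / (num_avgd + 1) for a, w in zip(avg, weights)]


def run_swa(weights_per_epoch, cycle_start, cycle_length):
    """Return the averaged weights and the number of snapshots averaged."""
    avg, num_avgd = None, 0
    for epoch, weights in enumerate(weights_per_epoch):
        if should_update(epoch, cycle_start, cycle_length):
            # The first snapshot initialises the average; later ones are folded in.
            avg = list(weights) if avg is None else update_average(avg, weights, num_avgd)
            num_avgd += 1
    return avg, num_avgd


# Hypothetical one-parameter "model" whose weight equals the epoch index.
history = [[float(epoch)] for epoch in range(6)]
swa_weights, num_snapshots = run_swa(history, cycle_start=1, cycle_length=2)
```

With cycle_start=1 and cycle_length=2, this sketch averages the snapshots from epochs 1, 3 and 5. The base weights in history are never modified, which mirrors the separation between the trained base model and the averaged model used at test time.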