
SWA

class torch_uncertainty.models.SWA(model, cycle_start, cycle_length)

Stochastic Weight Averaging.

Updates the SWA model every cycle_length epochs, starting at epoch cycle_start. The SWA model is used only at test time; the base model is used for training.

Parameters:
  • model (nn.Module) – PyTorch model to be trained.

  • cycle_start (int) – Epoch at which SWA averaging starts.

  • cycle_length (int) – Number of epochs between SWA updates.
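The averaging schedule described above can be sketched as follows. This is a hypothetical, framework-free illustration, not the torch_uncertainty implementation: `SWASketch`, `should_update`, and `update` are invented names, weights are plain lists of floats standing in for model parameters, and the schedule (one update at epoch cycle_start, then one every cycle_length epochs) is an assumed reading of the description. Each update folds the current weights into a running mean, so after n snapshots the SWA weights equal their arithmetic average.

```python
from typing import List, Optional


class SWASketch:
    """Hypothetical sketch of SWA's weight-averaging schedule (not the library API).

    Weights are plain lists of floats standing in for model parameters.
    """

    def __init__(self, cycle_start: int, cycle_length: int) -> None:
        if cycle_length < 1:
            raise ValueError("cycle_length must be >= 1")
        self.cycle_start = cycle_start
        self.cycle_length = cycle_length
        # Averaged weights; in SWA these are used only at test time.
        self.swa_weights: Optional[List[float]] = None
        self.num_averaged = 0

    def should_update(self, epoch: int) -> bool:
        # Assumed schedule: update at cycle_start, then every cycle_length epochs.
        return (
            epoch >= self.cycle_start
            and (epoch - self.cycle_start) % self.cycle_length == 0
        )

    def update(self, epoch: int, weights: List[float]) -> None:
        """Fold the base model's current weights into the running average."""
        if not self.should_update(epoch):
            return
        if self.swa_weights is None:
            # The first snapshot initializes the average.
            self.swa_weights = list(weights)
        else:
            # Running mean: w_swa <- w_swa + (w - w_swa) / (n + 1)
            n = self.num_averaged
            self.swa_weights = [
                s + (w - s) / (n + 1) for s, w in zip(self.swa_weights, weights)
            ]
        self.num_averaged += 1


if __name__ == "__main__":
    swa = SWASketch(cycle_start=2, cycle_length=2)
    for epoch in range(7):
        # Stand-in for the base model's weights after training this epoch.
        swa.update(epoch, [float(epoch)])
    # Snapshots taken at epochs 2, 4 and 6, so the average is (2 + 4 + 6) / 3.
    print(swa.swa_weights, swa.num_averaged)
```

The running mean avoids storing every snapshot. The sketch omits recomputing BatchNorm statistics for the averaged weights, a step the SWA paper performs before evaluating the averaged model.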

Reference:

Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. (2018). Averaging Weights Leads to Wider Optima and Better Generalization. In UAI 2018.