SWA

class torch_uncertainty.models.SWA(model, cycle_start, cycle_length)

Stochastic Weight Averaging.

Updates the averaged (SWA) model every cycle_length epochs, starting at epoch cycle_start. The base model is used during training; the SWA model is used only at test time.

Parameters:
  • model (nn.Module) – PyTorch model to be trained.

  • cycle_start (int) – Epoch at which SWA starts.

  • cycle_length (int) – Number of epochs between SWA updates.
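
The update schedule described above can be illustrated with a minimal, framework-free sketch. This is not the library's implementation: the helper names (is_swa_update_epoch, update_running_average, run_schedule) are hypothetical, weights are represented as plain lists of floats, and the sketch assumes 0-indexed epochs, a first update exactly at cycle_start, and an equal-weight running mean over the collected snapshots.

```python
def is_swa_update_epoch(epoch, cycle_start, cycle_length):
    """Return True when the averaged weights would be refreshed.

    Assumption: epochs are 0-indexed and the first update happens
    exactly at cycle_start, then every cycle_length epochs after it.
    """
    return epoch >= cycle_start and (epoch - cycle_start) % cycle_length == 0


def update_running_average(avg_weights, new_weights, num_averaged):
    """Fold one weight snapshot into an equal-weight running mean."""
    if num_averaged == 0:
        # First snapshot: the average is simply a copy of the weights.
        return list(new_weights)
    return [
        avg + (new - avg) / (num_averaged + 1)
        for avg, new in zip(avg_weights, new_weights)
    ]


def run_schedule(snapshots, cycle_start, cycle_length):
    """Average the snapshots selected by the schedule.

    snapshots[e] holds the base model's weights at the end of epoch e.
    Returns the averaged weights and the number of snapshots averaged.
    """
    avg_weights, num_averaged = None, 0
    for epoch, weights in enumerate(snapshots):
        if is_swa_update_epoch(epoch, cycle_start, cycle_length):
            avg_weights = update_running_average(avg_weights, weights, num_averaged)
            num_averaged += 1
    return avg_weights, num_averaged
```

Under these assumptions, cycle_start=4 and cycle_length=2 over ten epochs would select the snapshots from epochs 4, 6, and 8, and the averaged weights are their mean. The base model's weights are never overwritten by the average, which matches the description that training continues on the base model while the averaged copy is reserved for evaluation.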

References

[1] Averaging Weights Leads to Wider Optima and Better Generalization. In UAI 2018.