# Source code for torch_uncertainty.models.wrappers.swa
import copy

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader


class SWA(nn.Module):
    num_avgd_models: Tensor
    def __init__(
        self,
        model: nn.Module,
        cycle_start: int,
        cycle_length: int,
    ) -> None:
        """Stochastic Weight Averaging.

        Update the SWA model every :attr:`cycle_length` epochs starting at
        :attr:`cycle_start`. Uses the SWA model only at test time; otherwise,
        uses the base model for training.

        Args:
            model (nn.Module): PyTorch model to be trained.
            cycle_start (int): Epoch to start SWA.
            cycle_length (int): Number of epochs between SWA updates.

        Reference:
            Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson,
            A. G. (2018). Averaging Weights Leads to Wider Optima and Better
            Generalization. In UAI 2018.
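
        Example:
            A minimal training-loop sketch; ``MyModel``, ``train_one_epoch``,
            and ``train_loader`` are illustrative placeholders, not part of
            this module::

                swa = SWA(MyModel(), cycle_start=10, cycle_length=5)
                for epoch in range(20):
                    train_one_epoch(swa)  # trains the wrapped core model
                    swa.update_wrapper(epoch)
                swa.bn_update(train_loader, device="cpu")
                swa.eval()  # evaluation now uses the averaged weights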
"""
super().__init__()
_swa_checks(cycle_start, cycle_length)
self.core_model = model
self.cycle_start = cycle_start
self.cycle_length = cycle_length
self.register_buffer("num_avgd_models", torch.tensor(0, device="cpu"))
self.swa_model = None
self.need_bn_update = False

    @torch.no_grad()
    def update_wrapper(self, epoch: int) -> None:
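        """Update the averaged weights every :attr:`cycle_length` epochs from
        :attr:`cycle_start` onward.

        Args:
            epoch (int): Current training epoch.
        """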
        if (
            epoch >= self.cycle_start
            and (epoch - self.cycle_start) % self.cycle_length == 0
        ):
            if self.swa_model is None:
                # First update: start the average from a copy of the current
                # weights.
                self.swa_model = copy.deepcopy(self.core_model)
                self.num_avgd_models = torch.tensor(1)
            else:
                # Incremental running mean over the collected checkpoints:
                # avg_{n+1} = avg_n + (w - avg_n) / (n + 1).
                for swa_param, param in zip(
                    self.swa_model.parameters(),
                    self.core_model.parameters(),
                    strict=False,
                ):
                    swa_param.data += (param.data - swa_param.data) / (
                        self.num_avgd_models + 1
                    )
                self.num_avgd_models += 1
            self.need_bn_update = True

    def eval_forward(self, x: Tensor) -> Tensor:
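        """Forward pass through the SWA model, falling back to the core model
        if no averaging update has occurred yet.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Model output.
        """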
        if self.swa_model is None:
            return self.core_model.forward(x)
        return self.swa_model.forward(x)

    def forward(self, x: Tensor) -> Tensor:
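        """Forward pass through the core model during training and through
        the SWA model at evaluation time.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Model output.
        """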
        if self.training:
            return self.core_model.forward(x)
        return self.eval_forward(x)

    def bn_update(
        self, loader: DataLoader, device: torch.device | str
    ) -> None:
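        """Recompute the BatchNorm running statistics of the SWA model on
        :attr:`loader`, as recommended by Izmailov et al. (2018).

        Args:
            loader (DataLoader): Training data loader.
            device (torch.device | str): Device on which to run the update.
        """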
        if self.need_bn_update and self.swa_model is not None:
            torch.optim.swa_utils.update_bn(
                loader, self.swa_model, device=device
            )
            self.need_bn_update = False


def _swa_checks(cycle_start: int, cycle_length: int) -> None:
    if cycle_start < 0:
        raise ValueError(
            f"`cycle_start` must be non-negative. Got {cycle_start}."
        )
    if cycle_length <= 0:
        raise ValueError(
            f"`cycle_length` must be strictly positive. Got {cycle_length}."
        )