EMA

class torch_uncertainty.models.EMA(model, momentum)

Exponential Moving Average (EMA).

The model given as argument is the one being trained: it is used to compute the gradients during training. A separate EMA copy of its weights is regularly updated from this inner model and is used at evaluation time.
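
As a sketch of the averaging, assuming the common convention in which the momentum μ weights the previous average, let θ denote the weights of the trained model and θ_EMA those of the EMA copy. Each update computes the first line below; unrolling it over k updates shows that older weights are discounted geometrically:

```latex
\begin{aligned}
\theta_{\mathrm{EMA}}^{(k)} &= \mu\,\theta_{\mathrm{EMA}}^{(k-1)} + (1-\mu)\,\theta^{(k)},\\
\theta_{\mathrm{EMA}}^{(k)} &= \mu^{k}\,\theta_{\mathrm{EMA}}^{(0)} + (1-\mu)\sum_{i=1}^{k}\mu^{k-i}\,\theta^{(i)}.
\end{aligned}
```

Here θ^(i) is the trained model's weights at the i-th update and θ_EMA^(0) is the initial copy. Under this convention, a momentum close to 1 averages over a long window of past weights.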

Parameters:
  • model (nn.Module) – The model to train and ensemble.

  • momentum (float) – The momentum of the moving average.
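
To illustrate what the momentum controls, the hypothetical helpers below (ema_update and track are illustrative names, not part of torch_uncertainty) apply the same kind of update to a single scalar weight, again assuming that the momentum weights the previous average. When the tracked weight jumps from 0 to 1, the average closes the gap by a factor of momentum per update, so after k updates it equals 1 - momentum**k.

```python
def ema_update(average: float, value: float, momentum: float) -> float:
    """One EMA step (hypothetical helper): momentum weights the old average."""
    return momentum * average + (1.0 - momentum) * value


def track(values: list[float], momentum: float) -> list[float]:
    """Return the EMA after each value, starting from the first value."""
    average = values[0]
    history = [average]
    for value in values[1:]:
        average = ema_update(average, value, momentum)
        history.append(average)
    return history


# A weight that jumps from 0 to 1 and stays there.
weights = [0.0, 1.0, 1.0, 1.0]
print(track(weights, momentum=0.5))  # [0.0, 0.5, 0.75, 0.875]
print(round(track(weights, momentum=0.9)[-1], 3))  # 0.271
```

Under this convention, a larger momentum makes the averaged weights lag further behind the trained model, which is what smooths out the noise of individual updates.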

update_wrapper(epoch=None)

Update the EMA model with the current weights of the inner model.

Parameters:
  • epoch (int, optional) – The current epoch. Accepted only for API consistency; defaults to None.
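
The following is a minimal sketch of a wrapper with this interface, written against plain PyTorch. It is not the torch_uncertainty implementation: the class name EMASketch and the attribute names core_model and ema_model are hypothetical, it assumes the convention that momentum weights the previous average, and it assumes the epoch argument is ignored, as the parameter description suggests.

```python
import copy
from typing import Optional

import torch
from torch import nn


class EMASketch(nn.Module):
    """Hypothetical EMA wrapper: train the inner model, predict with its average."""

    def __init__(self, model: nn.Module, momentum: float) -> None:
        super().__init__()
        self.core_model = model  # receives the gradients during training
        # Independent copy that stores the running average of the weights.
        self.ema_model = copy.deepcopy(model)
        for param in self.ema_model.parameters():
            param.requires_grad_(False)
        self.momentum = momentum

    @torch.no_grad()
    def update_wrapper(self, epoch: Optional[int] = None) -> None:
        # epoch is ignored; it only mirrors the documented signature.
        # Assumed rule: ema <- momentum * ema + (1 - momentum) * current.
        # Only parameters are averaged; buffers (e.g. BatchNorm statistics)
        # keep the values copied at construction time.
        for ema_param, param in zip(
            self.ema_model.parameters(), self.core_model.parameters()
        ):
            ema_param.mul_(self.momentum).add_(param, alpha=1.0 - self.momentum)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Training mode uses the inner model; eval mode uses the averaged copy.
        if self.training:
            return self.core_model(x)
        return self.ema_model(x)
```

In a training loop, the optimizer would be built on the inner model's parameters (for example torch.optim.SGD(model.parameters(), lr=0.1)), update_wrapper() would be called at whatever frequency is chosen, and calling eval() on the wrapper would route predictions through the averaged weights.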