EMA

class torch_uncertainty.methods.EMA(core_model, momentum)

Exponential Moving Average (EMA).

The core model passed as an argument is the one trained by gradient descent. A separate EMA model keeps an exponential moving average of the core model's weights. It is updated regularly from the core model and is the model used at evaluation time.
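The following is a minimal PyTorch sketch of this mechanism. It assumes the standard EMA update, new EMA weight = momentum * EMA weight + (1 - momentum) * core weight. `EMASketch` is a hypothetical class written for illustration. It mirrors the method names documented on this page, but it is not torch_uncertainty's implementation. For simplicity it averages only parameters, not buffers.

```python
import copy
from typing import Optional

import torch
from torch import nn


class EMASketch(nn.Module):
    """Hypothetical EMA wrapper mirroring the documented interface (not the library code)."""

    def __init__(self, core_model: nn.Module, momentum: float) -> None:
        super().__init__()
        if not 0.0 <= momentum < 1.0:
            raise ValueError("momentum must be in [0, 1)")
        self.core_model = core_model
        # Independent copy; its weights change only through update_wrapper().
        self.ema_model = copy.deepcopy(core_model)
        self.ema_model.requires_grad_(False)
        self.momentum = momentum

    @property
    def remainder(self) -> float:
        # Complement of the momentum: weight given to the core model at each update.
        return 1.0 - self.momentum

    @torch.no_grad()
    def update_wrapper(self, epoch: Optional[int] = None) -> None:
        # `epoch` is ignored. Only parameters are averaged; buffers
        # (e.g. BatchNorm running statistics) keep their copied values.
        for ema_p, p in zip(self.ema_model.parameters(), self.core_model.parameters()):
            ema_p.mul_(self.momentum).add_(p, alpha=self.remainder)

    def eval_forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ema_model(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Training mode: gradients flow through the core model only.
        if self.training:
            return self.core_model(x)
        return self.eval_forward(x)


torch.manual_seed(0)
core = nn.Linear(4, 2)
ema = EMASketch(core, momentum=0.9)
optimizer = torch.optim.SGD(core.parameters(), lr=0.1)  # optimize the core model only

x, y = torch.randn(8, 4), torch.randn(8, 2)
w0 = core.weight.detach().clone()  # EMA weights equal these before the first update

ema.train()
loss = nn.functional.mse_loss(ema(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
ema.update_wrapper()  # EMA weight is now 0.9 * w0 + 0.1 * (updated core weight)

ema.eval()
with torch.no_grad():
    preds = ema(x)  # computed by the EMA copy, not the core model
```

The optimizer is built on the core model's parameters only. This keeps gradient updates away from the EMA copy, which changes solely through the averaging step.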

Parameters:
  • core_model (Module) – The model to train and ensemble.

  • momentum (float) – The momentum of the moving average. The larger the momentum, the more slowly the EMA weights follow the core model and the more stable the EMA model.

Note

The momentum is usually close to 1, for example 0.9 or 0.95.
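The plain-Python sketch below illustrates why a larger momentum gives a more stable model. It assumes the standard EMA update applied to a single scalar weight. The helper `remaining_gap` is hypothetical and tracks how quickly the EMA closes the gap to a fixed target. The gap shrinks by a factor of `momentum` at each update, so with 0.95 each step moves the EMA half as far as with 0.9. A common rule of thumb is that the EMA averages over roughly 1 / (1 - momentum) recent updates, which gives 10 for 0.9 and 20 for 0.95.

```python
def remaining_gap(momentum: float, steps: int) -> float:
    """Gap left between an EMA and a fixed target after `steps` updates.

    The EMA starts at 0.0 and the target is 1.0, so the initial gap is 1.0
    and the result equals momentum ** steps (up to rounding).
    """
    ema, target = 0.0, 1.0
    for _ in range(steps):
        ema = momentum * ema + (1.0 - momentum) * target
    return target - ema


for m in (0.9, 0.95):
    print(f"momentum={m}: gap after 10 updates = {remaining_gap(m, 10):.3f}")
# momentum=0.9: gap after 10 updates = 0.349
# momentum=0.95: gap after 10 updates = 0.599
```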

eval_forward(x)

Run the EMA model in evaluation mode.

Return type:

Tensor

property remainder

Complement of the EMA momentum, i.e. 1 - momentum.
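The math block below assumes the standard EMA rule. Here m is the momentum, θ^(k) are the core model's weights at the k-th update, and θ_EMA are the EMA weights. The remainder 1 - m is then the weight given to the core model at each update. Unrolling t updates shows that older weights contribute with geometrically decaying coefficients, and that all coefficients sum to 1:

```latex
\theta_{\mathrm{EMA}}^{(t)} = m\,\theta_{\mathrm{EMA}}^{(t-1)} + (1 - m)\,\theta^{(t)}
\quad\Longrightarrow\quad
\theta_{\mathrm{EMA}}^{(t)} = m^{t}\,\theta_{\mathrm{EMA}}^{(0)} + (1 - m)\sum_{k=1}^{t} m^{t-k}\,\theta^{(k)},
\qquad m^{t} + (1 - m)\sum_{k=1}^{t} m^{t-k} = 1
```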

update_wrapper(epoch=None)

Update the EMA model with the current weights of the core model.

Parameters:

epoch (int | None) – The current epoch. Unused; present for API consistency. Defaults to None.

Return type:

None