MCBatchNorm

class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)

Monte Carlo Batch Normalization wrapper for 2d inputs.

Parameters:
  • model (Module | None) – model to be converted.

  • num_estimators (int) – number of estimators (stochastic forward passes). Defaults to 16.

  • convert (bool) – whether to convert the model. Defaults to True.

  • mc_batch_size (int) – Monte Carlo batch size. The smaller the batch size, the greater the variability of the predictions. Defaults to 32.

  • device (Union[Literal['cpu', 'cuda'], device, None]) – device used for the computations. Defaults to None.

Warning

The update of the batch statistics differs slightly from the method as described in the original paper but follows its implementation: instead of updating the training-based statistics with one new batch of data, we replace them directly with the statistics of that batch. See this issue/discussion.
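To make the difference concrete, here is a minimal pure-Python sketch on a single scalar channel. `momentum_update` is the exponential moving average commonly used for BatchNorm running statistics (simplified here to the biased variance), while `direct_replacement` is the behavior the warning describes. Both function names are hypothetical and are not part of torch_uncertainty.

```python
from statistics import fmean, pvariance


def batch_stats(batch):
    """Mean and (biased) variance of one batch of scalar activations."""
    return fmean(batch), pvariance(batch)


def momentum_update(running, batch, momentum=0.1):
    """Exponential moving average, as used for BatchNorm running statistics
    (simplified here to the biased variance)."""
    mean, var = batch_stats(batch)
    r_mean, r_var = running
    return (
        (1 - momentum) * r_mean + momentum * mean,
        (1 - momentum) * r_var + momentum * var,
    )


def direct_replacement(running, batch):
    """Direct replacement: `running` is ignored on purpose and the
    statistics of the new batch are used as-is."""
    return batch_stats(batch)


training_stats = (0.0, 1.0)       # statistics learned during training
new_batch = [1.0, 2.0, 3.0, 4.0]  # one batch from the post-processing set
print(momentum_update(training_stats, new_batch))     # moved only 10% toward the batch
print(direct_replacement(training_stats, new_batch))  # (2.5, 1.25)
```

With direct replacement, each set of statistics depends only on the batch it was computed from, which is what makes the predictions vary from one set to another.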

Note

This wrapper will be stochastic in eval mode only.

Note

Open an issue if you would like a wrapper for 1d or 3d inputs.

References

[1] Teye M, Azizpour H, Smith K. Bayesian uncertainty estimation for batch normalized deep networks. In ICML 2018.

fit(dataloader)

Fit the model on the dataset.

Parameters:

dataloader (DataLoader) – DataLoader with the post-processing dataset.

Return type:

None

Warning

The batch_size of the DataLoader (i.e. mc_batch_size) should be chosen carefully, as it affects the diversity of the statistics of the MC BatchNorm layers and therefore the diversity of the predictions.

Note

This method is used to populate the MC BatchNorm layers. Use the post-processing dataset.

Raises:

ValueError – If there are fewer batches than the number of estimators.
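Assuming each estimator receives the statistics of exactly one batch (which is what the ValueError condition suggests), the bookkeeping performed by fit can be sketched in pure Python for a single scalar channel. `collect_mc_statistics` is a hypothetical helper, not part of torch_uncertainty, and stopping after `num_estimators` batches is an assumption of this sketch.

```python
from statistics import fmean, pvariance


def collect_mc_statistics(batches, num_estimators):
    """Hypothetical sketch: one (mean, variance) pair per estimator,
    each computed from its own batch of scalar activations."""
    stats = []
    for batch in batches:
        if len(stats) == num_estimators:
            break  # enough batches; extra ones are ignored in this sketch
        stats.append((fmean(batch), pvariance(batch)))
    if len(stats) < num_estimators:
        raise ValueError(
            f"There are fewer batches ({len(stats)}) than estimators "
            f"({num_estimators})."
        )
    return stats


# Two estimators, three batches: the third batch is not needed.
print(collect_mc_statistics([[0.0, 2.0], [1.0, 3.0], [4.0, 4.0]], 2))
```

Because each estimator's statistics come from a single batch, the batch size directly controls how much these statistics, and hence the predictions, differ between estimators.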

raise_counters()

Increment all counters by 1.

Return type:

None

replace_layers(model)

Replace all BatchNorm2d layers with MCBatchNorm2d layers.

Parameters:

model (Module) – model to be converted.

Return type:

None
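Swapping layers in a PyTorch model is usually done with a recursive walk over `named_children()` followed by `setattr`. The following minimal sketch assumes that idiom. `HypotheticalMCBatchNorm2d` is a placeholder that only wraps the original layer (it is not the library's `MCBatchNorm2d`), and `replace_batchnorm2d` is likewise a hypothetical name.

```python
import torch
from torch import nn


class HypotheticalMCBatchNorm2d(nn.Module):
    """Placeholder for an MC BatchNorm layer; it just wraps the original one."""

    def __init__(self, bn: nn.BatchNorm2d):
        super().__init__()
        self.bn = bn  # keep the learned affine parameters and buffers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)


def replace_batchnorm2d(module: nn.Module) -> None:
    """Recursively replace every nn.BatchNorm2d child, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, HypotheticalMCBatchNorm2d(child))
        else:
            replace_batchnorm2d(child)  # descend into nested containers


model = nn.Sequential(
    nn.Conv2d(3, 4, 3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4)),
)
replace_batchnorm2d(model)
print(model)
```

Replaced layers are not visited again, since recursion only happens in the `else` branch.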

reset_counters()

Reset all counters to 0.

Return type:

None
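The raise_counters and reset_counters methods suggest that each MC BatchNorm layer keeps a counter. The pure-Python sketch below assumes that this counter selects which estimator's statistics normalize the input in eval mode, so that averaging `num_estimators` passes gives the MC prediction. `ToyMCBatchNorm`, `mc_predict` and the counter semantics are all hypothetical and are not the torch_uncertainty implementation.

```python
import math
from statistics import fmean, pvariance


class ToyMCBatchNorm:
    """Hypothetical scalar MC BatchNorm layer (not the library class).

    Stores one (mean, var) pair per estimator; in eval mode, `counter`
    selects which pair normalizes the input (assumption of this sketch)."""

    def __init__(self, stats, eps=1e-5):
        self.stats = stats
        self.eps = eps
        self.counter = 0
        self.training = False

    def __call__(self, x):
        if self.training:
            mean, var = fmean(x), pvariance(x)  # ordinary batch statistics
        else:
            mean, var = self.stats[self.counter]  # stochastic in eval mode
        return [(v - mean) / math.sqrt(var + self.eps) for v in x]


def raise_counters(layers):
    for layer in layers:
        layer.counter += 1


def reset_counters(layers):
    for layer in layers:
        layer.counter = 0


def mc_predict(layers, x, num_estimators):
    """Average num_estimators eval-mode passes, one set of statistics each."""
    reset_counters(layers)
    outputs = []
    for _ in range(num_estimators):
        out = x
        for layer in layers:
            out = layer(out)
        outputs.append(out)
        raise_counters(layers)  # move every layer to the next estimator
    reset_counters(layers)
    return [fmean(col) for col in zip(*outputs)]


layer = ToyMCBatchNorm([(0.0, 1.0), (2.0, 1.0)], eps=0.0)
print(mc_predict([layer], [2.0], num_estimators=2))  # [1.0]
```

Resetting the counters before and after the loop keeps every layer aligned on the same estimator, whatever state a previous call left them in.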

set_accumulate(accumulate)

Set the accumulate flag for all MCBatchNorm2d layers.

Parameters:

accumulate (bool) – accumulate flag.

Return type:

None