MCBatchNorm

class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)

Monte Carlo Batch Normalization (MCBN) wrapper for 2D inputs (Teye, Azizpour & Smith, ICML 2018).

Replaces the standard BatchNorm2d layers with stochastic MCBatchNorm2d layers whose running statistics are re-sampled at every forward pass from random mini-batches of the training/calibration set. The model thus becomes a stochastic predictor at test time, which provides an inexpensive approximate-Bayesian uncertainty estimate, analogous to MC Dropout but for batch-normalised networks.

Parameters:
  • model (Module | None) – Model to be converted. Defaults to None.

  • num_estimators (int) – Number of MC estimators to draw at test time. Defaults to 16.

  • convert (bool) – Whether to convert the model’s BatchNorm2d layers in place. Defaults to True.

  • mc_batch_size (int) – Monte-Carlo batch size. Smaller batches yield more variability in the predictions. Defaults to 32.

  • device (Union[Literal['cpu', 'cuda'], device, None]) – Device to use. Defaults to None.
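
The idea behind the wrapper can be illustrated with a framework-free sketch. It is not the library's implementation: the helpers batch_stats and mc_batch_norm_predict are hypothetical, features are plain scalars rather than 2D feature maps, and a single normalisation step stands in for the whole network. Each estimator normalises the same input with the statistics of a different random mini-batch; the mean of the outputs serves as the prediction and their spread as a rough uncertainty estimate.

```python
import random
import statistics


def batch_stats(batch):
    """Return the (mean, variance) of one mini-batch of scalar features."""
    return statistics.fmean(batch), statistics.pvariance(batch)


def mc_batch_norm_predict(x, calibration, num_estimators=16,
                          mc_batch_size=32, eps=1e-5, seed=0):
    """Normalise x once per estimator, each time with another mini-batch's statistics.

    Hypothetical helper for illustration only; it is not part of torch_uncertainty.
    """
    rng = random.Random(seed)
    outputs = []
    for _ in range(num_estimators):
        # Draw a random mini-batch from the calibration data.
        batch = rng.sample(calibration, mc_batch_size)
        mean, var = batch_stats(batch)
        outputs.append((x - mean) / (var + eps) ** 0.5)
    return outputs


# Toy calibration set: 1024 scalar features drawn from a Gaussian.
data_rng = random.Random(1)
calibration = [data_rng.gauss(2.0, 3.0) for _ in range(1024)]

outputs = mc_batch_norm_predict(5.0, calibration)
prediction = statistics.fmean(outputs)    # average over the estimators
uncertainty = statistics.pstdev(outputs)  # spread across the estimators
```

Under these assumptions, lowering mc_batch_size makes the per-batch statistics noisier and so widens the spread across estimators, in line with the description of that parameter.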

Warning

The update of the batch statistics slightly differs from the description in the original paper but matches its reference implementation: instead of updating the training-based statistics with one new batch, we perform a direct replacement. See this issue/discussion.

Note

This wrapper is stochastic in eval mode only.

Note

Raise an issue if you need a wrapper for 1D or 3D inputs.

References

[1] Teye, M., Azizpour, H., & Smith, K. (2018). Bayesian uncertainty estimation for batch normalized deep networks. ICML 2018.

fit(dataloader)

Fit the MC BatchNorm layers on the dataset.

Parameters:

dataloader (DataLoader) – DataLoader with the post-processing dataset.

Return type:

None

Warning

The batch_size of the DataLoader (i.e. mc_batch_size) should be chosen carefully, as it affects the diversity of the statistics of the MC BatchNorm layers and therefore of the predictions.

Note

This method is used to populate the MC BatchNorm layers. Use the post-processing dataset.

Raises:

ValueError – If there are fewer batches than the number of estimators.
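
The condition behind this ValueError can be checked before calling fit. The helper below is a sketch under stated assumptions, not part of torch_uncertainty: check_enough_batches is a hypothetical name, and it relies on a PyTorch DataLoader over a map-style dataset yielding ceil(num_samples / batch_size) batches, or floor(num_samples / batch_size) when drop_last=True.

```python
import math


def check_enough_batches(num_samples, batch_size, num_estimators, drop_last=False):
    """Return the number of batches, or raise if it is below num_estimators.

    Hypothetical pre-flight check; the library performs its own validation in fit.
    """
    if drop_last:
        num_batches = num_samples // batch_size
    else:
        num_batches = math.ceil(num_samples / batch_size)
    if num_batches < num_estimators:
        raise ValueError(
            f"Only {num_batches} batches available for {num_estimators} estimators; "
            "use more data, a smaller batch size, or fewer estimators."
        )
    return num_batches


# 1000 samples in batches of 32 give enough batches for 16 estimators.
num_batches = check_enough_batches(1000, batch_size=32, num_estimators=16)
```

With the defaults (num_estimators=16, mc_batch_size=32), the post-processing dataset therefore needs more than 480 samples when the last partial batch is kept.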

raise_counters()

Raise all counters by 1.

Return type:

None

replace_layers(model)

Replace all BatchNorm2d layers with MCBatchNorm2d layers.

Parameters:

model (Module) – Model to be converted.

Return type:

None

reset_counters()

Reset all counters to 0.

Return type:

None
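
How the counters might interact with stored statistics can be sketched as follows. This is a guess at the mechanism, not the library's code: ToyMCLayer, its stats list, and the predict loop are hypothetical, and the only behaviour taken from the documentation is that raise_counters adds 1 to every counter and reset_counters sets them back to 0.

```python
class ToyMCLayer:
    """Hypothetical stand-in for one MC BatchNorm layer (scalar features only)."""

    def __init__(self, stats, eps=1e-5):
        self.stats = stats  # one (mean, variance) pair per estimator
        self.counter = 0    # index of the pair used by the next forward pass
        self.eps = eps

    def forward(self, x):
        mean, var = self.stats[self.counter % len(self.stats)]
        return (x - mean) / (var + self.eps) ** 0.5


def reset_counters(layers):
    """Reset all counters to 0."""
    for layer in layers:
        layer.counter = 0


def raise_counters(layers):
    """Raise all counters by 1."""
    for layer in layers:
        layer.counter += 1


def predict(layers, x, num_estimators):
    """Run num_estimators passes, moving every layer to its next statistics each time."""
    reset_counters(layers)
    outputs = []
    for _ in range(num_estimators):
        h = x
        for layer in layers:
            h = layer.forward(h)
        outputs.append(h)
        raise_counters(layers)
    return outputs


# Two layers, each holding statistics for two estimators.
layers = [
    ToyMCLayer([(0.0, 1.0), (1.0, 4.0)]),
    ToyMCLayer([(0.5, 1.0), (-0.5, 0.25)]),
]
outputs = predict(layers, 3.0, num_estimators=2)
```

Keeping all counters in step means that, in this toy model, estimator i uses the i-th stored statistics in every layer, so each pass corresponds to one consistent set of statistics.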

set_accumulate(accumulate)

Set the accumulate flag for all MCBatchNorm2d layers.

Parameters:

accumulate (bool) – Value of the accumulate flag.

Return type:

None