MCBatchNorm
- class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)
Monte Carlo Batch Normalization (MCBN) wrapper for 2D inputs (Teye, Azizpour & Smith, ICML 2018).
Replaces the standard BatchNorm2d layers with stochastic MCBatchNorm2d layers whose running statistics are overwritten with statistics computed on random mini-batches of the training/calibration set. This turns the model into a stochastic predictor at test time and provides a cheap Bayesian-style uncertainty estimate, analogous to MC Dropout but for batch-normalised networks.
- Parameters:
  - model (Module | None) – Model to be converted.
  - num_estimators (int) – Number of MC estimators to draw at test time. Defaults to 16.
  - convert (bool) – Whether to convert the model's BatchNorm2d layers in place. Defaults to True.
  - mc_batch_size (int) – Monte Carlo batch size. Smaller batches yield more variability in the predictions. Defaults to 32.
  - device (Union[Literal['cpu', 'cuda'], device, None]) – Device to use. Defaults to None.
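The mechanism can be sketched in plain PyTorch. This is a minimal illustration under our own assumptions, not the MCBatchNorm implementation: the helpers `collect_bn_stats` and `mc_predict` are hypothetical, the toy network and random tensors stand in for a real model and calibration set, and each set of statistics is obtained by running the BatchNorm2d layers in training mode with `momentum=1.0`, so that their running buffers are overwritten by the statistics of the current mini-batch.

```python
import copy

import torch
from torch import nn

# One (running_mean, running_var) pair per BatchNorm2d layer.
BNStats = list[tuple[torch.Tensor, torch.Tensor]]


def _bn_layers(model: nn.Module) -> list[nn.BatchNorm2d]:
    return [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]


def collect_bn_stats(model: nn.Module, batches: list[torch.Tensor]) -> list[BNStats]:
    """Hypothetical helper: store one set of BatchNorm2d statistics per mini-batch."""
    model = copy.deepcopy(model).eval()  # work on a copy; non-BN layers stay in eval mode
    bns = _bn_layers(model)
    for bn in bns:
        bn.train()         # normalize with batch statistics and update running buffers
        bn.momentum = 1.0  # running stat := current batch stat (direct replacement)
    stats = []
    with torch.no_grad():
        for batch in batches:
            model(batch)
            stats.append([(bn.running_mean.clone(), bn.running_var.clone()) for bn in bns])
    return stats


def mc_predict(model: nn.Module, x: torch.Tensor, stats: list[BNStats]) -> torch.Tensor:
    """Hypothetical helper: one softmax prediction per stored set of statistics."""
    model = copy.deepcopy(model).eval()  # BatchNorm2d now normalizes with its running buffers
    bns = _bn_layers(model)
    preds = []
    with torch.no_grad():
        for layer_stats in stats:
            for bn, (mean, var) in zip(bns, layer_stats):
                bn.running_mean.copy_(mean)
                bn.running_var.copy_(var)
            preds.append(model(x).softmax(dim=-1))
    return torch.stack(preds)  # shape: (num_estimators, batch, num_classes)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
calib_x = torch.randn(256, 3, 16, 16)  # stand-in for the calibration set
num_estimators, mc_batch_size = 16, 32
# One random mini-batch (without replacement inside the batch) per estimator.
batches = [calib_x[torch.randperm(len(calib_x))[:mc_batch_size]] for _ in range(num_estimators)]

stats = collect_bn_stats(model, batches)
probs = mc_predict(model, torch.randn(4, 3, 16, 16), stats)
mean_probs = probs.mean(dim=0)  # MC-averaged prediction
spread = probs.std(dim=0)       # disagreement across estimators
```

The deep copies keep the caller's model untouched. With the library itself, the two steps roughly correspond to calling `fit(dataloader)` and then calling the wrapper in eval mode; the sketch only illustrates the idea.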
Warning
The update of the batch statistics slightly differs from the description in the original paper but matches its reference implementation: instead of updating the training-based statistics with one new batch, we perform a direct replacement. See this issue/discussion.
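In plain PyTorch, the difference between updating and replacing stored statistics is visible through the `momentum` argument of `nn.BatchNorm2d`: in training mode a running statistic becomes `(1 - momentum) * running + momentum * batch_stat`, so the default `momentum=0.1` blends a new batch into the stored value, whereas `momentum=1.0` overwrites it. The snippet only illustrates these two behaviours; it does not claim that MCBatchNorm2d is implemented this way.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch = torch.randn(32, 4, 8, 8) + 3.0  # one mini-batch, per-channel mean near 3

ema_bn = nn.BatchNorm2d(4)                    # default momentum=0.1: blended update
replace_bn = nn.BatchNorm2d(4, momentum=1.0)  # momentum=1.0: direct replacement

with torch.no_grad():
    for bn in (ema_bn, replace_bn):
        bn.train()
        bn(batch)

batch_mean = batch.mean(dim=(0, 2, 3))
batch_var = batch.var(dim=(0, 2, 3))  # unbiased variance, as stored in running_var

print(ema_bn.running_mean)      # 0.9 * 0 + 0.1 * batch_mean (buffers start at 0)
print(replace_bn.running_mean)  # equal to batch_mean
```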
Note
This wrapper is stochastic in eval mode only.
Note
Raise an issue if you need a wrapper for 1D or 3D inputs.
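References
- Teye, M., Azizpour, H. & Smith, K. Bayesian Uncertainty Estimation for Batch Normalized Deep Networks. ICML 2018.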
- fit(dataloader)
Fit the model on the dataset.
- Parameters:
  - dataloader (DataLoader) – DataLoader with the post-processing dataset.
- Return type:
  None
Warning
The batch_size of the DataLoader (i.e. mc_batch_size) should be chosen carefully, as it affects the diversity of the statistics stored in the MC BatchNorm layers and therefore the diversity of the predictions.
Note
This method is used to populate the MC BatchNorm layers. Use the post-processing dataset.
- Raises:
ValueError – If the DataLoader yields fewer batches than the number of estimators.
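Because fit raises a ValueError when the DataLoader yields fewer batches than there are estimators, it can help to check the loader length beforehand. A torch DataLoader with `drop_last=True` yields floor(N / batch_size) batches, so with the defaults `num_estimators=16` and `mc_batch_size=32` the post-processing set needs at least 16 × 32 = 512 samples. The sketch below is a hypothetical pre-check (`enough_batches` is not part of the library), and `drop_last=True` is our own choice so that every estimator's statistics come from a full batch of `mc_batch_size` samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_estimators = 16  # MCBatchNorm default
mc_batch_size = 32   # MCBatchNorm default


def enough_batches(loader: DataLoader, num_estimators: int) -> bool:
    """Hypothetical pre-check mirroring the condition under which fit() raises."""
    return len(loader) >= num_estimators


# (inputs, targets) pairs; the targets are placeholders.
small = TensorDataset(torch.randn(400, 3, 8, 8), torch.zeros(400, dtype=torch.long))
large = TensorDataset(torch.randn(600, 3, 8, 8), torch.zeros(600, dtype=torch.long))

# drop_last=True discards the final partial batch, so len(loader) == N // batch_size.
small_loader = DataLoader(small, batch_size=mc_batch_size, shuffle=True, drop_last=True)
large_loader = DataLoader(large, batch_size=mc_batch_size, shuffle=True, drop_last=True)

print(len(small_loader), enough_batches(small_loader, num_estimators))  # 12 False
print(len(large_loader), enough_batches(large_loader, num_estimators))  # 18 True
```

A loader that passes this check could then be passed to fit.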