MCBatchNorm#

class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)[source]#

Monte Carlo Batch Normalization wrapper for 2d inputs.

Parameters:
  • model (nn.Module) – model to be converted.

  • num_estimators (int) – number of estimators.

  • convert (bool) – whether to convert the model. Defaults to True.

  • mc_batch_size (int, optional) – Monte Carlo batch size. The smaller the batch size, the more variability there is in the predictions. Defaults to 32.

  • device (Literal["cpu", "cuda"] | torch.device | None, optional) – device. Defaults to None.

Warning

The update of the batch statistics differs slightly from the method as worded in the original paper but matches its reference implementation: instead of updating the training-time statistics with one new batch of data, we replace them directly. See this issue/discussion.

Note

This wrapper will be stochastic in eval mode only.

Note

Raise an issue if you would like a wrapper for 1d and 3d inputs.

References

[1] Teye M, Azizpour H, Smith K. Bayesian uncertainty estimation for batch normalized deep networks. In ICML 2018.

fit(dataloader)[source]#

Fit the model on the dataset.

Parameters:

dataloader (DataLoader) – DataLoader with the post-processing dataset.

Warning

The batch_size of the DataLoader (i.e. mc_batch_size) should be chosen carefully: it affects the diversity of the statistics of the MC BatchNorm layers, and therefore the diversity of the predictions.

Note

This method populates the statistics of the MC BatchNorm layers. Use the post-processing dataset.

Raises:

ValueError – If there are fewer batches than the number of estimators.
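The direct-replacement behavior described in the warning above can be sketched in plain PyTorch. This is not the library's code, only an illustration of the idea: each estimator's normalization statistics are taken directly from one Monte Carlo batch rather than blended into running statistics with momentum.

```python
import torch

num_estimators, mc_batch_size, channels = 4, 32, 8

# One (mean, variance) pair per estimator, each computed from a single
# Monte Carlo batch — a direct replacement, not a momentum-based update.
means, variances = [], []
for _ in range(num_estimators):
    batch = torch.randn(mc_batch_size, channels, 16, 16)
    # Per-batch statistics over the (N, H, W) dimensions, one value per channel.
    means.append(batch.mean(dim=(0, 2, 3)))
    variances.append(batch.var(dim=(0, 2, 3)))

# At inference, estimator i would normalize with (means[i], variances[i]),
# which is what makes the wrapper stochastic across estimators in eval mode.
```

Because each batch yields different statistics, smaller Monte Carlo batches give noisier per-estimator statistics and hence more variable predictions, which is why mc_batch_size matters.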

raise_counters()[source]#

Increment all counters by 1.

replace_layers(model)[source]#

Replace all BatchNorm2d layers with MCBatchNorm2d layers.

Parameters:

model (nn.Module) – model to be converted.

reset_counters()[source]#

Reset all counters to 0.

set_accumulate(accumulate)[source]#

Set the accumulate flag for all MCBatchNorm2d layers.

Parameters:

accumulate (bool) – accumulate flag.