MCBatchNorm
- class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)
Monte Carlo Batch Normalization wrapper for 2d inputs.
- Parameters:
model (nn.Module) – model to be converted.
num_estimators (int) – number of estimators. Defaults to 16.
convert (bool) – whether to convert the model. Defaults to True.
mc_batch_size (int, optional) – Monte Carlo batch size. The smaller it is, the more variability in the predictions. Defaults to 32.
device (Literal["cpu", "cuda"] | torch.device | None, optional) – device. Defaults to None.
Warning
The update of the batch statistics differs slightly from the method as worded in the original paper but follows its reference implementation. Instead of updating the training-time statistics with one new batch of data, we perform a direct replacement, as sketched below. See this issue/discussion.
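To illustrate the difference, here is a schematic contrast on a single layer's running mean (the variance is handled the same way). This is an illustrative sketch with placeholder values, not the wrapper's internal code:

```python
import torch

momentum = 0.1
running_mean = torch.zeros(16)  # training-time statistic of one BatchNorm layer
batch_mean = torch.randn(16)    # statistic computed on one post-processing batch

# Update as worded in the paper: blend the new batch into the running statistic.
ema_mean = (1 - momentum) * running_mean + momentum * batch_mean

# What the implementation (and this wrapper) does: replace the statistic outright.
replaced_mean = batch_mean.clone()
```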
Note
This wrapper will be stochastic in eval mode only.
Note
Raise an issue if you would like a wrapper for 1d and 3d inputs.
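A minimal usage sketch, assuming a toy convolutional classifier and random tensors in place of a trained model and a real post-processing dataset:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch_uncertainty.post_processing import MCBatchNorm

# Toy classifier with BatchNorm2d layers, standing in for a trained model.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

# Random tensors standing in for the post-processing dataset.
calib_set = TensorDataset(torch.randn(512, 3, 32, 32), torch.randint(0, 10, (512,)))
calib_loader = DataLoader(calib_set, batch_size=32, shuffle=True)  # batch_size = mc_batch_size

mc_model = MCBatchNorm(net, num_estimators=8, convert=True, mc_batch_size=32)
mc_model.fit(calib_loader)

mc_model.eval()  # the wrapper is stochastic in eval mode only
with torch.no_grad():
    out = mc_model(torch.randn(4, 3, 32, 32))
print(out.shape)  # predictions drawn from the stored BatchNorm statistics
```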
References
Teye, M., Azizpour, H., & Smith, K. Bayesian uncertainty estimation for batch normalized deep networks. In ICML 2018.
- fit(dataloader)
Fit the model on the dataset.
- Parameters:
dataloader (DataLoader) – DataLoader with the post-processing dataset.
Warning
The batch_size of the DataLoader (i.e. mc_batch_size) should be chosen carefully, as it impacts the diversity of the statistics of the MC BatchNorm layers and therefore of the predictions.
Note
This method is used to populate the MC BatchNorm layers. Use the post-processing dataset.
- Raises:
ValueError – If there are fewer batches than the number of estimators.
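For instance, with a hypothetical setup of 256 samples and the default values of num_estimators and mc_batch_size doubled to 64, the loader yields too few batches to populate every estimator:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_estimators = 16
mc_batch_size = 64

# 256 samples at batch_size=64 give only 4 batches, fewer than 16 estimators,
# so fit() would raise ValueError with this loader.
dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
loader = DataLoader(dataset, batch_size=mc_batch_size)
print(len(loader) >= num_estimators)  # False: lower mc_batch_size or add data
```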