MCBatchNorm
- class torch_uncertainty.post_processing.MCBatchNorm(model=None, num_estimators=16, convert=True, mc_batch_size=32, device=None)
Monte Carlo Batch Normalization wrapper.
- Parameters:
model (nn.Module) – model to be converted.
num_estimators (int) – number of estimators.
convert (bool) – whether to convert the model. Defaults to True.
mc_batch_size (int, optional) – Monte Carlo batch size. The smaller the batch size, the more variability in the predictions. Defaults to 32.
device (Literal["cpu", "cuda"] | torch.device | None, optional) – device. Defaults to None.
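Why a smaller mc_batch_size yields more variable predictions can be seen from batch statistics alone: the mean of a small mini-batch fluctuates more from batch to batch than the mean of a large one. The NumPy sketch below is illustrative only; `spread_of_batch_means` is a hypothetical helper, not part of torch_uncertainty, and it assumes normally distributed activations drawn without replacement.

```python
import numpy as np

def spread_of_batch_means(data, mc_batch_size, num_estimators, seed=0):
    """Std of the batch means across num_estimators random mini-batches (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Each "estimator" sees the mean of its own random mini-batch.
    means = [
        data[rng.choice(len(data), size=mc_batch_size, replace=False)].mean()
        for _ in range(num_estimators)
    ]
    return float(np.std(means))

# Simulated activations of a single channel.
data = np.random.default_rng(42).normal(size=10_000)
small = spread_of_batch_means(data, mc_batch_size=4, num_estimators=64)
large = spread_of_batch_means(data, mc_batch_size=256, num_estimators=64)
print(small > large)  # smaller batches give noisier statistics
```

For standard-normal data, the spread of batch means scales roughly as 1/sqrt(mc_batch_size), so the gap between the two settings is large.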
Note
This wrapper will be stochastic in eval mode only.
- Reference:
Teye M, Azizpour H, Smith K. Bayesian uncertainty estimation for batch normalized deep networks. In ICML 2018.
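The idea behind MC BatchNorm (Teye et al., 2018) can be sketched for a single normalization layer: instead of one pair of running statistics, several (mean, var) pairs are kept, each estimated from a different mini-batch, and the input is normalized once per pair. This yields one prediction per estimator, and their spread serves as an uncertainty signal. This is a minimal NumPy sketch under those assumptions, not the library's implementation; `mc_batchnorm_predict` and the statistics values are hypothetical.

```python
import numpy as np

def mc_batchnorm_predict(x, stats, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize x once per stored (mean, var) pair; returns shape (K, *x.shape)."""
    return np.stack([
        gamma * (x - mean) / np.sqrt(var + eps) + beta
        for mean, var in stats
    ])

# Hypothetical statistics, as if recorded from 3 different mini-batches.
stats = [(0.0, 1.0), (0.5, 1.2), (-0.3, 0.8)]
x = np.array([1.0, 2.0])
preds = mc_batchnorm_predict(x, stats)
print(preds.shape)         # (3, 2): one output per estimator
print(preds.mean(axis=0))  # aggregated prediction
print(preds.std(axis=0))   # disagreement across estimators
```

With identical statistics in every pair, all estimators would agree and the spread would be zero; the stochasticity comes entirely from the differing mini-batch statistics.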
- fit(dataloader)
Fit the model on the dataset.
- Parameters:
dataloader (DataLoader) – DataLoader with the post-processing dataset.
Note
This method populates the statistics of the MC BatchNorm layers. Call it with the post-processing dataset.
Warning
The batch_size of the DataLoader should be chosen carefully, as it affects the statistics of the MC BatchNorm layers.
- Raises:
ValueError – If there are fewer batches than the number of estimators.
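The contract of fit can be illustrated with a simplified stand-in: one set of statistics is recorded per batch, so at least num_estimators batches are needed. The sketch below assumes one scalar (mean, var) pair per batch and mirrors only the documented ValueError; `fit_mc_statistics` is hypothetical and does not reproduce how torch_uncertainty iterates the DataLoader.

```python
import numpy as np

def fit_mc_statistics(batches, num_estimators):
    """Record one (mean, var) pair per batch for the first num_estimators batches.

    Hypothetical helper: raises ValueError when there are fewer batches
    than estimators, as documented for MCBatchNorm.fit.
    """
    stats = []
    for batch in batches:
        if len(stats) == num_estimators:
            break  # enough statistics collected
        batch = np.asarray(batch, dtype=float)
        # Biased variance, as used by batch normalization in training mode.
        stats.append((batch.mean(), batch.var()))
    if len(stats) < num_estimators:
        raise ValueError(
            f"Got {len(stats)} batches but need at least {num_estimators}."
        )
    return stats

batches = np.split(np.arange(12.0), 4)  # 4 batches of size 3
stats = fit_mc_statistics(batches, num_estimators=4)
print(len(stats))  # 4
try:
    fit_mc_statistics(batches, num_estimators=8)
except ValueError as err:
    print(err)
```

Because each estimator's statistics come from a single batch, the batch size directly sets how noisy, and therefore how diverse, the estimators are.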