Shortcuts

Source code for torch_uncertainty.models.wrappers.mc_dropout

import torch
from torch import Tensor, nn
from torch.nn.modules.dropout import _DropoutNd


class MCDropout(nn.Module):
    """Wrap a model so that selected dropout modules stay active at evaluation."""

    def __init__(
        self,
        model: nn.Module,
        num_estimators: int,
        last_layer: bool,
        on_batch: bool,
    ) -> None:
        """MC Dropout wrapper for a model containing nn.Dropout modules.

        Args:
            model (nn.Module): model to wrap
            num_estimators (int): number of estimators to use during the
                evaluation
            last_layer (bool): whether to apply dropout to the last layer only.
            on_batch (bool): Perform the MC-Dropout on the batch-size.
                Otherwise in a for loop. Useful when constrained in memory.

        Raises:
            ValueError: if the model contains no dropout module, if none of
                the selected dropout modules has a rate > 0.0, or if
                ``num_estimators`` is not strictly positive.

        Warning:
            This module will work only if you apply dropout through modules
            declared in the constructor (__init__).

        Warning:
            With ``last_layer=True`` only the last dropout module returned by
            ``model.modules()`` is kept active during evaluation; every other
            dropout module follows the regular train/eval mode. Make sure
            that this last dropout is a module of its own.
        """
        super().__init__()
        # Only module-based dropouts can be found; functional dropout calls
        # are invisible to `model.modules()`.
        filtered_modules = [
            module for module in model.modules() if isinstance(module, _DropoutNd)
        ]
        if last_layer:
            # Slicing (rather than indexing) keeps an empty list empty so that
            # the check below reports the missing dropout with a clear message.
            filtered_modules = filtered_modules[-1:]

        _dropout_checks(filtered_modules, num_estimators)
        self.last_layer = last_layer
        self.on_batch = on_batch
        self.core_model = model
        self.num_estimators = num_estimators
        # Stored as a plain list: these modules are already registered through
        # `core_model`, the list only remembers which ones to keep stochastic.
        self.filtered_modules = filtered_modules

    def train(self, mode: bool = True) -> nn.Module:
        """Override the default train method to set the training mode of each
        submodule to be the same as the module itself except for the selected
        dropout modules.

        Args:
            mode (bool, optional): whether to set the module to training mode.
                Defaults to True.

        Returns:
            nn.Module: the wrapper itself.

        Raises:
            TypeError: if ``mode`` is not a boolean.
        """
        if not isinstance(mode, bool):
            raise TypeError("Training mode is expected to be boolean")
        self.training = mode
        for module in self.children():
            module.train(mode)
        # The selected dropout modules are forced back to training mode so that
        # they keep sampling masks even when the wrapper is in eval mode.
        for module in self.filtered_modules:
            module.train()
        return self

    def forward(
        self,
        x: Tensor,
    ) -> Tensor:
        """Forward pass of the model.

        During training, the forward pass is the same as of the core model.
        During evaluation, the forward pass is repeated `num_estimators` times
        either on the batch size or in a for loop depending on
        :attr:`on_batch`.

        Args:
            x (Tensor): input tensor of shape (B, ...)

        Returns:
            Tensor: during training, the output of the core model. During
            evaluation, the outputs of the :attr:`num_estimators` passes
            stacked along the first dimension, i.e. of shape
            (:attr:`num_estimators` * B, ...) provided the core model keeps
            the batch dimension first.
        """
        if self.training:
            return self.core_model(x)
        if self.on_batch:
            # Tile along the batch dimension only. The repeat pattern is built
            # from the input rank so that inputs of any rank (not only
            # 4-dimensional images) are supported.
            x = x.repeat(self.num_estimators, *([1] * (x.ndim - 1)))
            return self.core_model(x)
        # Else, for loop: one pass per estimator to limit the memory footprint
        return torch.cat(
            [self.core_model(x) for _ in range(self.num_estimators)], dim=0
        )
def mc_dropout(
    model: nn.Module,
    num_estimators: int,
    last_layer: bool = False,
    on_batch: bool = True,
) -> MCDropout:
    """Build an :class:`MCDropout` wrapper around ``model``.

    Args:
        model (nn.Module): model to wrap
        num_estimators (int): number of estimators to use
        last_layer (bool, optional): whether to apply dropout to the last
            layer only. Defaults to False.
        on_batch (bool, optional): Increase the batch_size to perform
            MC-Dropout. Otherwise in a for loop to reduce memory footprint.
            Defaults to True.

    Returns:
        MCDropout: the wrapper built from the given arguments.
    """
    wrapper_kwargs = {
        "model": model,
        "num_estimators": num_estimators,
        "last_layer": last_layer,
        "on_batch": on_batch,
    }
    return MCDropout(**wrapper_kwargs)
def _dropout_checks(filtered_modules: list[nn.Module], num_estimators: int) -> None:
    """Validate the dropout modules and the number of estimators.

    The checks run in this order: presence of dropout modules, presence of a
    non-zero dropout rate, then positivity of ``num_estimators``.

    Args:
        filtered_modules (list[nn.Module]): dropout modules selected for MC
            Dropout; each one is read through its ``p`` attribute.
        num_estimators (int): number of estimators requested.

    Raises:
        ValueError: if ``filtered_modules`` is empty, if no module has a
            dropout rate > 0.0, or if ``num_estimators`` is <= 0.
    """
    if not filtered_modules:
        message = (
            "No dropout module found in the model. "
            "Please use `nn.Dropout`-like modules to apply dropout."
        )
        raise ValueError(message)

    # Check that at least one module has > 0.0 dropout rate: the `else`
    # branch only runs when the loop never hit `break`.
    for dropout_module in filtered_modules:
        if dropout_module.p > 0.0:
            break
    else:
        raise ValueError("At least one dropout module must have a dropout rate > 0.0.")

    if num_estimators <= 0:
        raise ValueError("`num_estimators` must be strictly positive to use MC Dropout.")