Source code for torch_uncertainty.methods.stochastic

import torch
from torch import Tensor, nn

from torch_uncertainty.layers.bayesian import bayesian_modules


[docs] class StochasticModel(nn.Module): def __init__( self, core_model: nn.Module, num_samples: int, probabilistic: bool = False, ) -> None: super().__init__() self.core_model = core_model self.num_samples = num_samples self.probabilistic = probabilistic def eval_forward(self, x: Tensor) -> Tensor | dict[str, Tensor]: out = [self.core_model(x) for _ in range(self.num_samples)] if self.probabilistic: key_set = {tuple(o.keys()) for o in out} return {k: torch.cat([o[k] for o in out], dim=0) for k in key_set.pop()} return torch.cat(out, dim=0) def forward(self, x: Tensor) -> Tensor | dict[str, Tensor]: if self.training: return self.core_model.forward(x) return self.eval_forward(x) def _inner_sample(self, module: nn.Module) -> tuple: # coverage: ignore """Check that the module has a sampling method, then sample it. Args: module: The module to sample. Raises: TypeError: Triggered when the module doesn't implement sample. TypeError: Triggered when the module sampling function does not return a tuple. Returns: tuple: The weight and biases. Should be both torch Tensors. """ sample_fn = getattr(module, "sample", None) if not callable(sample_fn): raise TypeError("Bayesian module must implement `sample()`.") weight_bias = sample_fn() if not isinstance(weight_bias, tuple) or len(weight_bias) != 2: raise TypeError("`sample()` must return (weight, bias).") return weight_bias
[docs] def sample(self, num_samples: int = 1) -> list[dict[str, Tensor]]: """Sample the wrapped model multiple times. Args: num_samples: Number of samples to generate. Defaults to ``1``. Returns: list[dict[str, Tensor]]: Sampled model states. """ sampled_models = [{}] * num_samples for module_name in self.core_model._modules: module = self.core_model._modules[module_name] if module is None: # coverage: ignore continue if isinstance(module, bayesian_modules): for model in sampled_models: weight, bias = self._inner_sample(module) model[module_name + ".weight"] = weight if bias is not None: # coverage: ignore model[module_name + ".bias"] = bias else: for model in sampled_models: state = module.state_dict() if not len(state): # no parameter break # TODO: fix this model |= { module_name + "." + key: val for key, val in module.state_dict().items() } return sampled_models
[docs] def freeze(self) -> None: """Freeze all Bayesian submodules in the wrapped model.""" for module in self.core_model.modules(): if isinstance(module, bayesian_modules): freeze_fn = getattr(module, "freeze", None) if callable(freeze_fn): freeze_fn()
[docs] def unfreeze(self) -> None: """Unfreeze all Bayesian submodules in the wrapped model.""" for module in self.core_model.modules(): if isinstance(module, bayesian_modules): unfreeze_fn = getattr(module, "unfreeze", None) if callable(unfreeze_fn): unfreeze_fn()