Shortcuts

Source code for torch_uncertainty.losses.bayesian

import torch
from torch import Tensor, nn

from torch_uncertainty.layers.bayesian import bayesian_modules


class KLDiv(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        """KL divergence loss for Bayesian Neural Networks.

        Gathers the KL from the modules computed in the forward passes.

        Args:
            model (nn.Module): Bayesian Neural Network
        """
        super().__init__()
        self.model = model

    def forward(self) -> Tensor:
        """Return the KL divergence averaged over the Bayesian modules of :attr:`model`.

        Returns:
            Tensor: A one-element tensor holding the averaged KL divergence.
        """
        return self._kl_div()

    def _kl_div(self) -> Tensor:
        """Gathers pre-computed KL-Divergences from :attr:`model`.

        For every submodule that is an instance of one of ``bayesian_modules``,
        adds ``module.lvposterior - module.lprior``. These attributes are read
        as-is and must already have been computed by a previous forward pass
        (presumably the log variational posterior and log prior of the sampled
        weights; confirm against the layer implementations). The sum is then
        averaged over the number of such modules.

        Returns:
            Tensor: A one-element tensor. It lives on the device of the last
            Bayesian module visited, or on the CPU when the model contains no
            Bayesian module, in which case the KL term is zero.
        """
        kl_divergence = torch.zeros(1)
        count = 0
        for module in self.model.modules():
            if isinstance(module, bayesian_modules):
                # Follow the module's device so the in-place sum is valid.
                kl_divergence = kl_divergence.to(device=module.lvposterior.device)
                kl_divergence += module.lvposterior - module.lprior
                count += 1
        # Without any Bayesian module the KL term vanishes. Dividing by a zero
        # count would return NaN and silently poison the training loss.
        return kl_divergence / max(count, 1)
[docs]class ELBOLoss(nn.Module): def __init__( self, model: nn.Module | None, inner_loss: nn.Module, kl_weight: float, num_samples: int, ) -> None: """The Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks. ELBO loss for Bayesian Neural Networks. Use this loss function with the objective that you seek to minimize as :attr:`inner_loss`. Args: model (nn.Module): The Bayesian Neural Network to compute the loss for inner_loss (nn.Module): The loss function to use during training kl_weight (float): The weight of the KL divergence term num_samples (int): The number of samples to use for the ELBO loss Note: Set the model to None if you use the ELBOLoss within the ClassificationRoutine. It will get filled automatically. """ super().__init__() _elbo_loss_checks(inner_loss, kl_weight, num_samples) self.set_model(model) self.inner_loss = inner_loss self.kl_weight = kl_weight self.num_samples = num_samples
[docs] def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: """Gather the KL divergence from the Bayesian modules and aggregate the ELBO loss for a given network. Args: inputs (Tensor): The inputs of the Bayesian Neural Network targets (Tensor): The target values Returns: Tensor: The aggregated ELBO loss """ aggregated_elbo = torch.zeros(1, device=inputs.device) for _ in range(self.num_samples): logits = self.model(inputs) aggregated_elbo += self.inner_loss(logits, targets) # TODO: This shouldn't be necessary aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) return aggregated_elbo / self.num_samples
def set_model(self, model: nn.Module | None) -> None: self.model = model if model is not None: self._kl_div = KLDiv(model)
def _elbo_loss_checks(inner_loss: nn.Module, kl_weight: float, num_samples: int) -> None: if isinstance(inner_loss, type): raise TypeError("The inner_loss should be an instance of a class." f"Got {inner_loss}.") if kl_weight < 0: raise ValueError(f"The KL weight should be non-negative. Got {kl_weight}.") if num_samples < 1: raise ValueError("The number of samples should not be lower than 1." f"Got {num_samples}.") if not isinstance(num_samples, int): raise TypeError("The number of samples should be an integer. " f"Got {type(num_samples)}.")