# Source code for torch_uncertainty.losses.bayesian
import torch
from torch import Tensor, nn
from torch.distributions import Independent
from torch_uncertainty.layers.bayesian import bayesian_modules
from torch_uncertainty.utils.distributions import get_dist_class
class KLDiv(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        r"""KL divergence loss for Bayesian Neural Networks.

        Aggregates the per-layer Kullback-Leibler divergences between the
        variational posterior :math:`q_\phi(\mathbf{w})` and the prior
        :math:`p(\mathbf{w})`:

        .. math::

            \mathrm{KL}\!\left[ q_\phi(\mathbf{w}) \;\|\; p(\mathbf{w}) \right]
            = \frac{1}{L} \sum_{\ell=1}^{L} \mathrm{KL}\!\left[ q_\phi^{(\ell)} \;\|\; p^{(\ell)} \right].

        Each Bayesian layer caches a single-sample Monte Carlo estimate of its
        KL term during the forward pass via two scalars:

        - ``log_variational_posterior`` — :math:`\log q_\phi(\mathbf{w}^{(s)})`,
          the **log variational posterior** evaluated at the sampled weight,
        - ``log_prior`` — :math:`\log p(\mathbf{w}^{(s)})`, the **log prior**
          at the same sample,

        so that ``log_variational_posterior - log_prior`` is a one-sample
        estimate of the layer's KL.

        ``KLDiv`` simply averages these contributions over the :math:`L`
        Bayesian layers of :attr:`model`; when :attr:`model` contains no
        Bayesian layer the KL term is zero. The result is intended to be added
        to the data-fit term of an ELBO objective — see :class:`ELBOLoss`.

        Args:
            model: The Bayesian Neural Network whose layers expose
                ``log_variational_posterior`` and ``log_prior``
                log-probabilities.
        """
        super().__init__()
        self.model = model

    def forward(self) -> Tensor:
        """Return the KL divergence averaged over the Bayesian layers of :attr:`model`."""
        return self._kl_div()

    def _kl_div(self) -> Tensor:
        """Gather the pre-computed KL-divergence terms from :attr:`model`.

        Returns:
            Tensor: A one-element tensor holding the mean of
            ``log_variational_posterior - log_prior`` over the Bayesian
            modules of :attr:`model`, or zero if there are none.
        """
        kl_divergence = torch.zeros(1)
        count = 0
        for module in self.model.modules():
            if isinstance(module, bayesian_modules):
                # Move the accumulator next to this layer's cached terms before
                # adding them, so layers need not share one device.
                kl_divergence = kl_divergence.to(device=module.log_variational_posterior.device)
                kl_divergence += module.log_variational_posterior - module.log_prior
                count += 1
        if count == 0:
            # No Bayesian layer means no variational parameters, hence a zero
            # KL term. Dividing by ``count`` here would produce NaN and
            # silently poison the ELBO.
            return kl_divergence
        return kl_divergence / count
[docs]
class ELBOLoss(nn.Module):
model: nn.Module
def __init__(
self,
model: nn.Module | None,
inner_loss: nn.Module,
kl_weight: float,
num_samples: int,
dist_family: str | None = None,
) -> None:
r"""The (negative) Evidence Lower Bound (ELBO) loss for Bayesian Neural Networks.
Combines an inner data-fit loss (e.g., cross-entropy or a distribution NLL)
with the Kullback-Leibler regulariser :math:`\mathrm{KL}[q_\phi(\mathbf{w}) \|
p(\mathbf{w})]`, estimated by Monte Carlo over :attr:`num_samples` weight
samples:
.. math::
\mathcal{L}_{\text{ELBO}} = \frac{1}{S} \sum_{s=1}^{S}
\mathcal{L}_\text{inner}\!\left(f_{\mathbf{w}^{(s)}}(\mathbf{x}), y\right)
+ \beta_\text{KL} \cdot \mathrm{KL}\!\left[ q_\phi(\mathbf{w}) \;\|\; p(\mathbf{w}) \right],
with :math:`\mathbf{w}^{(s)} \sim q_\phi`. The KL weight :math:`\beta_\text{KL}`
is typically set to the inverse of the number of training points (or a
manually annealed schedule).
Args:
model: The Bayesian Neural Network to compute the loss for.
inner_loss: The data-fit loss to use during training.
kl_weight: The weight :math:`\beta_\text{KL}` of the KL-divergence term.
num_samples: The number of weight samples :math:`S` used to estimate the
expectation.
dist_family: The distribution family to use for the output of the model.
``None`` means point-wise prediction. Defaults to ``None``.
Note:
Set :attr:`model` to ``None`` when using ``ELBOLoss`` inside a
``ClassificationRoutine`` — it will be filled in automatically.
"""
super().__init__()
_elbo_loss_checks(inner_loss, kl_weight, num_samples)
if model is not None:
self.set_model(model)
self.inner_loss = inner_loss
self.kl_weight = kl_weight
self.num_samples = num_samples
self.dist_family = dist_family
[docs]
def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
"""Gather the KL divergence from the Bayesian modules and aggregate
the ELBO loss for a given network.
Args:
inputs: The inputs of the Bayesian Neural Network
targets: The target values
Returns:
Tensor: The aggregated ELBO loss
"""
aggregated_elbo = torch.zeros(1, device=inputs.device)
dist_class = get_dist_class(self.dist_family) if self.dist_family is not None else None
for _ in range(self.num_samples):
out = self.model(inputs)
if dist_class is not None:
# Wrap the distribution in an Independent distribution for log_prob computation.
out = Independent(dist_class(**out), 1)
aggregated_elbo += self.inner_loss(out, targets)
# TODO: This shouldn't be necessary
aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device)
return aggregated_elbo / self.num_samples
def set_model(self, model: nn.Module) -> None:
self.model = model
if model is not None:
self._kl_div = KLDiv(model)
def _elbo_loss_checks(inner_loss: nn.Module, kl_weight: float, num_samples: int) -> None:
    """Validate the constructor arguments of :class:`ELBOLoss`.

    Args:
        inner_loss: The data-fit loss; must be an instance, not a class.
        kl_weight: The weight of the KL term; must be non-negative.
        num_samples: The number of Monte Carlo samples; must be an integer
            greater than or equal to 1.

    Raises:
        TypeError: If ``inner_loss`` is a class or ``num_samples`` is not an
            integer.
        ValueError: If ``kl_weight`` is negative or ``num_samples`` is lower
            than 1.
    """
    if isinstance(inner_loss, type):
        raise TypeError(f"The inner_loss should be an instance of a class. Got {inner_loss}.")
    if kl_weight < 0:
        raise ValueError(f"The KL weight should be non-negative. Got {kl_weight}.")
    # Check the type before comparing: ``None < 1`` or ``"3" < 1`` would
    # otherwise raise an unrelated TypeError before the explicit check runs.
    if not isinstance(num_samples, int):
        raise TypeError(f"The number of samples should be an integer. Got {type(num_samples)}.")
    if num_samples < 1:
        raise ValueError(f"The number of samples should not be lower than 1. Got {num_samples}.")