Source code for torch_uncertainty.metrics.regression.nll
from torch import Tensor, distributions
from torchmetrics.utilities.data import dim_zero_cat
from torch_uncertainty.metrics import CategoricalNLL
[docs]
class DistributionNLL(CategoricalNLL):
    r"""Negative Log-Likelihood under a predictive distribution.

    Evaluates a probabilistic regression model by computing the negative log-likelihood
    of the targets under the model's predictive distribution
    :math:`p_\theta(y \mid x)`:

    .. math::

        \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log p_\theta(y_i \mid x_i).

    Here :math:`N` is the number of log-probability entries that are not masked out
    by ``padding_mask``.

    For multi-variate targets, the underlying distribution is typically wrapped in a
    :class:`torch.distributions.Independent` so that ``log_prob`` correctly sums the
    log-density over the event dimensions.

    Args:
        reduction: How to reduce the per-sample losses (``"mean"``, ``"sum"``,
            ``"none"`` or ``None``). Defaults to ``"mean"``.
        kwargs: Additional keyword arguments, see `Advanced metric settings
            <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

    Inputs:
        - :attr:`dist`: a :class:`torch.distributions.Distribution` over the targets.
        - :attr:`target`: ground-truth targets of compatible shape.
        - :attr:`padding_mask`: optional boolean mask of positions to ignore (``True`` for padding).
    """

    def update(  # pyrefly: ignore[bad-override]
        self,
        dist: distributions.Distribution,
        target: Tensor,
        padding_mask: Tensor | None = None,
    ) -> None:
        """Update state with the predicted distributions and the targets.

        Args:
            dist: Predicted distributions.
            target: Ground truth labels.
            padding_mask: Optional boolean padding mask (``True`` for padding), broadcastable
                to the shape of ``dist.log_prob(target)``. Padded entries are replaced by NaN
                so that they are skipped by the ``"sum"`` and ``"mean"`` reductions; with
                ``reduction`` set to ``None`` or ``"none"`` they stay NaN in the stored values.
                Defaults to ``None``.
        """
        nlog_prob = -dist.log_prob(target)
        if padding_mask is not None:
            # NaN (rather than 0) marks padded entries so ``nansum`` can ignore them.
            nlog_prob = nlog_prob.masked_fill(padding_mask, float("nan"))

        if self.reduction is None or self.reduction == "none":
            self.values.append(nlog_prob)
            return

        self.values += nlog_prob.nansum()
        # The denominator of the mean must count the entries that were actually summed:
        # the non-padded log-probabilities. ``log_prob`` may have fewer elements than
        # ``target`` (e.g. with ``Independent``), so count on ``nlog_prob``.
        if padding_mask is None:
            num_valid = nlog_prob.numel()
        else:
            # Expand first so a broadcast mask counts every entry it covers.
            num_valid = (~padding_mask.expand_as(nlog_prob)).sum()
        self.total += num_valid

    def compute(self) -> Tensor:
        """Compute NLL based on inputs passed to ``update``.

        Returns:
            The NaN-ignoring sum of the accumulated values for ``"sum"``, that sum divided
            by the accumulated count for ``"mean"``, and the concatenated per-entry values
            when the reduction is ``None`` or ``"none"``.
        """
        values = dim_zero_cat(self.values)
        if self.reduction == "sum":
            return values.nansum()
        if self.reduction == "mean":
            return values.nansum() / self.total
        # reduction is None or "none"
        return values