# Source code for torch_uncertainty.metrics.regression.nll
from torch import Tensor, distributions
from torchmetrics.utilities.data import dim_zero_cat
from torch_uncertainty.metrics import CategoricalNLL
class DistributionNLL(CategoricalNLL):
    """Negative log-likelihood of targets under predicted ``torch`` distributions.

    Relies on the ``reduction`` attribute and the ``values`` / ``total``
    accumulators provided by :class:`CategoricalNLL`, but scores the targets
    with ``Distribution.log_prob`` instead of categorical probabilities.
    """

    def update(
        self,
        dist: distributions.Distribution,
        target: Tensor,
        padding_mask: Tensor | None = None,
    ) -> None:
        """Update state with the predicted distributions and the targets.

        Args:
            dist (torch.distributions.Distribution): Predicted distributions.
            target (Tensor): Ground-truth targets.
            padding_mask (Tensor, optional): Boolean mask that is ``True`` at
                padded positions. It must be broadcastable to the shape of
                ``dist.log_prob(target)``. Padded entries are replaced by NaN,
                so they are excluded from the summed loss and from the element
                count used by the ``"mean"`` reduction. With reduction
                ``None``/``"none"`` they appear as NaN in the output of
                ``compute``. Defaults to None.
        """
        nlog_prob = -dist.log_prob(target)
        if padding_mask is not None:
            # masked_fill writes NaN where the mask is True, i.e. at padded entries.
            nlog_prob = nlog_prob.masked_fill(padding_mask, float("nan"))

        if self.reduction is None or self.reduction == "none":
            self.values.append(nlog_prob)
            return

        self.values += nlog_prob.nansum()
        # The count must match the terms that contributed to the sum.
        # log_prob already reduces event dimensions, so target.numel() would
        # over-count for multivariate distributions. Padded entries (mask True)
        # are skipped by nansum and must not be counted either.
        if padding_mask is None:
            self.total += nlog_prob.numel()
        else:
            self.total += padding_mask.logical_not().expand_as(nlog_prob).sum()

    def compute(self) -> Tensor:
        """Computes NLL based on inputs passed in to ``update`` previously.

        Returns:
            Tensor: The NaN-ignoring sum of the negative log-likelihoods for
            reduction ``"sum"``. That sum divided by the accumulated count of
            non-padded terms for ``"mean"``. Otherwise, the per-element values
            concatenated along dimension 0, with NaN at padded positions.
        """
        values = dim_zero_cat(self.values)
        if self.reduction == "sum":
            return values.nansum()
        if self.reduction == "mean":
            return values.nansum() / self.total
        # reduction is None or "none"
        return values