Shortcuts

Source code for torch_uncertainty.metrics.regression.nll

from torch import Tensor, distributions
from torchmetrics.utilities.data import dim_zero_cat

from torch_uncertainty.metrics import CategoricalNLL


class DistributionNLL(CategoricalNLL):
    """Negative log-likelihood of targets under predicted distributions.

    The per-element NLL is obtained as ``-dist.log_prob(target)``. Accumulation
    uses the ``values`` / ``total`` states and the ``reduction`` attribute,
    which are presumably registered by :class:`CategoricalNLL` (not visible
    here — confirm against the parent class).
    """

    def update(
        self,
        dist: distributions.Distribution,
        target: Tensor,
        padding_mask: Tensor | None = None,
    ) -> None:
        """Update state with the predicted distributions and the targets.

        Args:
            dist (torch.distributions.Distribution): Predicted distributions.
            target (Tensor): Ground truth values.
            padding_mask (Tensor, optional): Boolean mask, broadcastable to the
                shape of ``dist.log_prob(target)``, that is ``True`` at padded
                positions. Padded values are excluded from the sum and from
                the element count used by ``reduction="mean"``, and appear as
                ``NaN`` in the output when ``reduction`` is ``None``/``"none"``.
                Defaults to None.
        """
        nlog_prob = -dist.log_prob(target)
        if padding_mask is not None:
            # NaN marks padded entries so that ``nansum`` skips them, both
            # here and in ``compute``.
            nlog_prob = nlog_prob.masked_fill(padding_mask, float("nan"))

        if self.reduction is None or self.reduction == "none":
            self.values.append(nlog_prob)
            return

        self.values += nlog_prob.nansum()
        if padding_mask is None:
            # NOTE(review): for distributions whose log_prob reduces event
            # dims (e.g. ``Independent``), ``target.numel()`` differs from
            # ``nlog_prob.numel()`` — confirm which normalization is intended.
            self.total += target.numel()
        else:
            # ``padding_mask`` is True where values were padded, so the mean
            # must be taken over the *unpadded* entries. The mask is expanded
            # exactly as ``masked_fill`` broadcast it above.
            self.total += (~padding_mask.expand_as(nlog_prob)).sum()

    def compute(self) -> Tensor:
        """Compute the NLL from the inputs passed to ``update`` so far.

        Returns:
            Tensor: for ``reduction="sum"``, the NaN-ignoring sum of the NLL;
            for ``reduction="mean"``, that sum divided by the accumulated
            element count; otherwise (``None``/``"none"``), the per-element
            NLL values concatenated along dim 0, with padded entries as NaN.
        """
        values = dim_zero_cat(self.values)
        if self.reduction == "sum":
            return values.nansum()
        if self.reduction == "mean":
            return values.nansum() / self.total
        # reduction is None or "none"
        return values