Source code for torch_uncertainty.metrics.classification.mutual_information

from typing import Any, Literal

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.data import dim_zero_cat


class MutualInformation(Metric):
    is_differentiable = False
    higher_is_better = None
    full_state_update = False

    # Accumulated scores. Registered in ``__init__`` either as a scalar
    # running sum (``reduction`` in {"mean", "sum"}) or as a list holding one
    # per-sample score tensor per ``update`` call ("none" / ``None``).
    values: list[Tensor]
    # Number of samples seen so far (sum of the batch sizes).
    total: Tensor

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] | None = "mean",
        **kwargs: Any,
    ) -> None:
        r"""Compute the Mutual Information Metric.

        The Mutual Information Metric estimates the epistemic uncertainty of
        an ensemble of estimators. Given per-estimator predicted probabilities
        :math:`\hat{\mathbf{p}}_n \in \Delta^{C-1}` for
        :math:`n = 1, \dots, N`, it decomposes the total predictive
        uncertainty into an epistemic term:

        .. math::
            \text{MI} = H\!\left(\frac{1}{N}\sum_{n=1}^{N}
            \hat{\mathbf{p}}_n\right)
            - \frac{1}{N}\sum_{n=1}^{N} H(\hat{\mathbf{p}}_n)

        where :math:`H(\mathbf{p}) = -\sum_{c=1}^{C} p_c \log p_c` is the
        Shannon entropy. The first term is the entropy of the ensemble mean
        (total uncertainty) and the second is the mean entropy of individual
        estimators (aleatoric uncertainty), so their difference captures the
        epistemic uncertainty due to model disagreement.

        Args:
            reduction: Determines how to reduce over the :math:`B`/batch
                dimension:

                - ``'mean'`` [default]: Averages score across samples
                - ``'sum'``: Sum score across samples
                - ``'none'`` or ``None``: Returns score per sample

            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Inputs:
            - :attr:`probs`: :math:`(B, N, C)`

            where :math:`B` is the batch size, :math:`N` is the number of
            estimators, and :math:`C` is the number of classes.

        Note:
            A higher mutual information can be interpreted as a higher
            epistemic uncertainty. The Mutual Information is also
            computationally equivalent to the Generalized Jensen-Shannon
            Divergence (GJSD).

            The implementation clamps every per-sample mutual information to
            zero to avoid negative values that could appear due to numerical
            instabilities. The clamping happens before any reduction, so all
            reduction modes aggregate the same per-sample scores.

        Warning:
            Make sure that the probabilities in :attr:`probs` are normalized
            to sum to one.

        Raises:
            ValueError: If :attr:`reduction` is not one of ``'mean'``,
                ``'sum'``, ``'none'`` or ``None``.
        """
        super().__init__(**kwargs)

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            # Build ONE message string. Separating the two pieces with a comma
            # would pass them as two exception args, and the error would be
            # rendered as a tuple instead of a readable sentence.
            raise ValueError(
                "Expected argument `reduction` to be one of "
                f"{allowed_reduction} but got {reduction}"
            )

        self.reduction = reduction

        if self.reduction in ["mean", "sum"]:
            # Aggregated modes only need a scalar running sum.
            self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum")
        else:
            # Per-sample mode keeps every batch of scores for concatenation.
            self.add_state("values", default=[], dist_reduce_fx="cat")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, probs: torch.Tensor) -> None:  # pyrefly: ignore[bad-override]
        r"""Update the current mutual information with a new tensor of probabilities.

        Args:
            probs: Likelihoods from the ensemble of shape :math:`(B, N, C)`,
                where :math:`B` is the batch size, :math:`N` is the number of
                estimators and :math:`C` is the number of classes.
        """
        batch_size = probs.size(0)

        # Entropy of the ensemble-averaged prediction (first term of MI).
        ens_probs = probs.mean(dim=1)
        entropy_mean = torch.special.entr(ens_probs).sum(dim=-1)
        # Average of the individual estimators' entropies (second term of MI).
        mean_entropy = torch.special.entr(probs).sum(dim=-1).mean(dim=1)

        # Clamp per sample, before accumulating: the difference is
        # non-negative in exact arithmetic, so negative values are only
        # round-off noise. Clamping here (rather than only on the aggregate)
        # keeps "mean"/"sum" consistent with the per-sample "none" output.
        mutual_information = torch.clamp(entropy_mean - mean_entropy, min=0)

        if self.reduction is None or self.reduction == "none":
            self.values.append(mutual_information)
        else:
            self.values += mutual_information.sum()
            self.total += batch_size

    def compute(self) -> torch.Tensor:
        r"""Compute mutual information based on inputs passed to ``update``.

        Returns:
            The summed score for ``reduction='sum'``, the summed score divided
            by the number of samples seen for ``reduction='mean'``, and the
            concatenated per-sample scores for ``'none'`` / ``None``.
        """
        # Scores are already clamped per sample in ``update``; this clamp is
        # kept as a cheap safeguard on the accumulated state.
        values = torch.clamp(dim_zero_cat(self.values), min=0)
        if self.reduction == "sum":
            return values.sum(dim=-1)
        if self.reduction == "mean":
            return values.sum(dim=-1) / self.total
        # reduction is None or "none"
        return values