Source code for torch_uncertainty.metrics.classification.mutual_information
from typing import Any, Literal
import torch
from torchmetrics import Metric
from torchmetrics.utilities.data import dim_zero_cat
class MutualInformation(Metric):
    # Class-level metric flags (see the torchmetrics ``Metric`` base class).
    is_differentiable: bool = False
    higher_is_better: bool | None = None
    full_state_update: bool = False

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none", None] = "mean",
        **kwargs: Any,
    ) -> None:
        """The Mutual Information Metric to estimate the epistemic uncertainty of
        an ensemble of estimators.

        Args:
            reduction (str, optional): Determines how to reduce over the
                :math:`B`/batch dimension:

                - ``'mean'`` [default]: Averages score across samples
                - ``'sum'``: Sum score across samples
                - ``'none'`` or ``None``: Returns score per sample

            kwargs: Additional keyword arguments, see `Advanced metric settings
                <https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-kwargs>`_.

        Inputs:
            - :attr:`probs`: the likelihoods provided by the ensemble as a Tensor
              of shape :math:`(B, N, C)`,

            where :math:`B` is the batch size, :math:`N` is the number of
            estimators, and :math:`C` is the number of classes.

        Raises:
            ValueError:
                If :attr:`reduction` is not one of ``'mean'``, ``'sum'``,
                ``'none'`` or ``None``.

        Note:
            A higher mutual information can be interpreted as a higher epistemic
            uncertainty. The Mutual Information is also computationally equivalent
            to the Generalized Jensen-Shannon Divergence (GJSD).

            The implementation of the mutual information clamps results to zero to
            avoid negative values that could appear due to numerical
            instabilities. With ``'mean'`` and ``'sum'`` the clamp is applied to
            the accumulated total; with ``'none'``/``None`` it is applied to each
            per-sample value.

        Warning:
            Make sure that the probabilities in :attr:`probs` are normalized to sum
            to one.
        """
        super().__init__(**kwargs)

        allowed_reduction = ("sum", "mean", "none", None)
        if reduction not in allowed_reduction:
            # Adjacent literals are concatenated into ONE message string. A
            # comma between them would hand ValueError two arguments and the
            # error would be displayed as a tuple instead of a sentence.
            raise ValueError(
                "Expected argument `reduction` to be one of "
                f"{allowed_reduction} but got {reduction}"
            )
        self.reduction = reduction

        if self.reduction in ["mean", "sum"]:
            # Reduced modes only need a running scalar sum.
            self.add_state("values", default=torch.tensor(0.0), dist_reduce_fx="sum")
        else:
            # Unreduced mode keeps every per-batch tensor for later concatenation.
            self.add_state("values", default=[], dist_reduce_fx="cat")
        # Number of samples seen so far; only incremented in the reduced modes.
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, probs: torch.Tensor) -> None:
        """Update the current mutual information with a new tensor of
        probabilities.

        Args:
            probs (torch.Tensor): Likelihoods from the ensemble of shape
                :math:`(B, N, C)`, where :math:`B` is the batch size,
                :math:`N` is the number of estimators and :math:`C` is the
                number of classes.
        """
        batch_size = probs.size(0)

        # Entropy of the ensemble-averaged prediction.
        ens_probs = probs.mean(dim=1)
        entropy_mean = torch.special.entr(ens_probs).sum(dim=-1)
        # Average of the individual estimators' entropies.
        mean_entropy = torch.special.entr(probs).sum(dim=-1).mean(dim=1)

        # Per-sample mutual information (one value per batch element).
        mutual_information = entropy_mean - mean_entropy

        if self.reduction is None or self.reduction == "none":
            self.values.append(mutual_information)
        else:
            self.values += mutual_information.sum()
            self.total += batch_size

    def compute(self) -> torch.Tensor:
        """Computes Mutual Information based on inputs passed in to ``update``
        previously.

        Returns:
            torch.Tensor: The accumulated sum for ``'sum'``, the accumulated
            sum divided by the number of samples seen for ``'mean'``, or the
            concatenated per-sample values for ``'none'``/``None``. Values are
            clamped to be non-negative before the reduction is applied.
        """
        # ``self.values`` is either the running scalar (reduced modes) or a
        # list of per-batch tensors (unreduced mode); dim_zero_cat handles both.
        values = torch.clamp(dim_zero_cat(self.values), min=0)

        if self.reduction == "sum":
            return values.sum(dim=-1)
        if self.reduction == "mean":
            return values.sum(dim=-1) / self.total
        # reduction is None or "none"
        return values