MutualInformation

class torch_uncertainty.metrics.classification.MutualInformation(reduction='mean', **kwargs)

Compute the Mutual Information Metric.

The Mutual Information Metric estimates the epistemic uncertainty of an ensemble of estimators.

Parameters:
  • reduction (str, optional) –

    Determines how to reduce over the batch dimension \(B\):

    • 'mean' [default]: Averages the scores across samples

    • 'sum': Sums the scores across samples

    • 'none' or None: Returns one score per sample

  • kwargs – Additional keyword arguments, see Advanced metric settings.

Inputs:
  • probs: \((B, N, C)\)

    where \(B\) is the batch size, \(N\) is the number of estimators, and \(C\) is the number of classes.
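
The \((B, N, C)\) layout can be illustrated with a minimal, dependency-free sketch. It assumes each estimator produced its own \((B, C)\) output and regroups those outputs so that the estimator axis comes second. The variable names are hypothetical and nested Python lists stand in for tensors; with PyTorch tensors, `torch.stack(outputs, dim=1)` would play the same role.

```python
# Hypothetical outputs of N=2 estimators, each of shape (B=2, C=2).
estimator_0 = [[0.7, 0.3], [0.2, 0.8]]
estimator_1 = [[0.6, 0.4], [0.5, 0.5]]
per_estimator = [estimator_0, estimator_1]  # shape (N, B, C)

# Regroup to (B, N, C): for each sample, collect the prediction of every estimator.
probs = [list(members) for members in zip(*per_estimator)]

batch_size = len(probs)            # B
num_estimators = len(probs[0])     # N
num_classes = len(probs[0][0])     # C
```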

Note

A higher mutual information can be interpreted as a higher epistemic uncertainty. The Mutual Information is also computationally equivalent to the Generalized Jensen-Shannon Divergence (GJSD).

The implementation clamps the mutual information to a minimum of zero, since numerical instabilities could otherwise produce slightly negative values.
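
The note above can be made concrete with a small, dependency-free sketch. It assumes the standard decomposition of the mutual information as the entropy of the ensemble-averaged prediction minus the average entropy of the individual estimators, computed with natural logarithms. Both choices are assumptions made for illustration, and this is not the library's implementation. The helper names `entropy` and `mutual_information` are hypothetical, and `probs` is a nested Python list rather than a torch.Tensor.

```python
import math


def entropy(p):
    """Shannon entropy (in nats) of one categorical distribution."""
    # Zero-probability classes contribute nothing (0 * log 0 is taken as 0).
    return -sum(q * math.log(q) for q in p if q > 0.0)


def mutual_information(probs, reduction="mean"):
    """Sketch for probs given as a nested list of shape (B, N, C)."""
    if reduction not in ("mean", "sum", "none", None):
        raise ValueError(f"Unsupported reduction: {reduction!r}")

    scores = []
    for sample in probs:  # sample has shape (N, C)
        n_estimators = len(sample)
        n_classes = len(sample[0])
        # Ensemble-averaged prediction for this sample, shape (C,).
        mean_pred = [
            sum(member[c] for member in sample) / n_estimators
            for c in range(n_classes)
        ]
        # Entropy of the mean minus the mean of the member entropies.
        total = entropy(mean_pred)
        expected = sum(entropy(member) for member in sample) / n_estimators
        # Clamp at zero to absorb small negative values caused by rounding.
        scores.append(max(total - expected, 0.0))

    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    return scores  # 'none' or None: one score per sample


agree = [[0.9, 0.1], [0.9, 0.1]]     # two estimators, identical predictions
disagree = [[1.0, 0.0], [0.0, 1.0]]  # two estimators, opposite predictions
per_sample = mutual_information([agree, disagree], reduction="none")
```

In this sketch the two agreeing estimators yield a score of zero, while the two fully opposed estimators yield log 2, the largest value two equally weighted estimators can reach.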

Warning

Make sure that the probabilities in probs are normalized to sum to one along the class dimension \(C\).
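
One way to satisfy this requirement is sketched below, assuming the estimators output raw scores (logits) that still need a softmax over the class dimension. The helpers `softmax` and `is_normalized` are hypothetical and operate on nested Python lists; with PyTorch tensors, `torch.softmax(logits, dim=-1)` performs the same normalization.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def is_normalized(probs, tol=1e-6):
    """Return True if every (C,) vector in a (B, N, C) nested list sums to one."""
    return all(
        abs(sum(member) - 1.0) <= tol
        for sample in probs
        for member in sample
    )


# Hypothetical raw scores: B=1 sample, N=2 estimators, C=3 classes.
logits = [[[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]]
probs = [[softmax(member) for member in sample] for sample in logits]
```

A check such as `is_normalized(probs)` before calling update can catch the case where logits were passed by mistake.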

Raises:

ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

compute()

Compute the mutual information from the inputs previously passed to update.

update(probs)

Update the current mutual information with a new tensor of probabilities.

Parameters:

probs (torch.Tensor) – Probabilities predicted by the ensemble, of shape \((B, N, C)\), where \(B\) is the batch size, \(N\) is the number of estimators, and \(C\) is the number of classes.