MutualInformationCriterion

class torch_uncertainty.ood_criteria.MutualInformationCriterion

OOD criterion based on mutual information (BALD).

This criterion computes the mutual information between the predicted label and the model parameters, estimated from the ensemble's predictions. It is a classical estimator of epistemic uncertainty. Higher mutual information means the ensemble members disagree more, which indicates greater epistemic uncertainty and thus a higher likelihood that the input is out-of-distribution.

Given ensemble predictions \(\{\mathbf{p}^{(k)}\}_{k=1}^{K}\), the mutual information is

\[I(y; \theta) = H\!\left(\frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\right) - \frac{1}{K}\sum_{k=1}^{K} H(\mathbf{p}^{(k)}),\]

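i.e., the entropy of the averaged prediction (the total predictive entropy) minus the average entropy of the individual estimators, where \(H(\mathbf{p}) = -\sum_{c} p_c \log p_c\) is the Shannon entropy over classes.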
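As a small worked illustration (the numbers are chosen for the example; natural logarithm, with the convention \(0 \log 0 = 0\), and \(\bar{\mathbf{p}} = \frac{1}{K}\sum_{k} \mathbf{p}^{(k)}\)), take \(K = 2\) estimators over two classes. Confident disagreement yields a large value, while agreement on the same uncertain prediction yields zero even though the predictive entropy is high:

```latex
\begin{aligned}
&\text{Disagreement: } \mathbf{p}^{(1)} = (1, 0),\ \mathbf{p}^{(2)} = (0, 1)
  \;\Rightarrow\; \bar{\mathbf{p}} = (\tfrac12, \tfrac12),\\
&\quad I = H(\bar{\mathbf{p}}) - \tfrac12\big(H(\mathbf{p}^{(1)}) + H(\mathbf{p}^{(2)})\big) = \ln 2 - 0 \approx 0.693.\\[4pt]
&\text{Agreement: } \mathbf{p}^{(1)} = \mathbf{p}^{(2)} = (\tfrac12, \tfrac12)
  \;\Rightarrow\; \bar{\mathbf{p}} = (\tfrac12, \tfrac12),\\
&\quad I = \ln 2 - \tfrac12(\ln 2 + \ln 2) = 0.
\end{aligned}
```

Only the first case scores high under this criterion; in the second, the uncertainty is shared by every member (aleatoric), so the mutual information term removes it.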

Variables:
  • ensemble_only – Requires ensemble predictions.

  • input_type – The expected input is estimated probabilities.

forward(inputs)

Compute mutual information from ensemble predictions.

Parameters:

inputs (Tensor) – Tensor of ensemble probabilities with shape (ensemble_size, batch_size, num_classes).

Returns:

Mutual information for each sample.

Return type:

Tensor
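The following is a minimal PyTorch sketch of the computation that forward describes, assuming inputs laid out as (ensemble_size, batch_size, num_classes) with probabilities along the last dimension. It is not the library's source: the function name mutual_information and the eps clamp (used so that terms with p = 0 contribute 0 instead of NaN) are illustrative assumptions.

```python
import torch


def mutual_information(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper: per-sample mutual information from ensemble probabilities.

    probs: shape (ensemble_size, batch_size, num_classes), rows summing to 1.
    Returns a tensor of shape (batch_size,).
    """
    # Entropy of the ensemble-averaged prediction (total predictive entropy).
    mean_probs = probs.mean(dim=0)
    total_entropy = -(mean_probs * torch.log(mean_probs.clamp_min(eps))).sum(dim=-1)
    # Entropy of each member's prediction, then averaged over the ensemble.
    member_entropy = -(probs * torch.log(probs.clamp_min(eps))).sum(dim=-1)
    expected_entropy = member_entropy.mean(dim=0)
    return total_entropy - expected_entropy


# Usage: logits from a hypothetical ensemble of 5 members, batch of 8, 10 classes.
torch.manual_seed(0)
logits = torch.randn(5, 8, 10)
probs = logits.softmax(dim=-1)  # the criterion expects probabilities, not logits
scores = mutual_information(probs)
print(scores.shape)  # torch.Size([8])
```

Higher scores would then be read as more likely out-of-distribution, for instance by comparing them against a threshold chosen on validation data.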