MutualInformationCriterion

class torch_uncertainty.ood_criteria.MutualInformationCriterion

OOD criterion based on mutual information.

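This criterion computes the mutual information between the predicted label and the ensemble member, estimated from the ensemble predictions. Higher mutual information values indicate greater disagreement among ensemble members and therefore higher (epistemic) uncertainty.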

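Given ensemble predictions \(\{\mathbf{p}^{(k)}\}_{k=1}^{K}\), the mutual information is computed as the entropy of the mean prediction minus the mean entropy of the individual predictions, where \(H\) denotes the Shannon entropy: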

\[I(y, \theta) = H\Big(\frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\Big) - \frac{1}{K}\sum_{k=1}^{K} H(\mathbf{p}^{(k)})\]
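
One way to read this formula, stated here as a standard identity and not as a description of the library's implementation: writing \(\bar{\mathbf{p}} = \frac{1}{K}\sum_{k=1}^{K} \mathbf{p}^{(k)}\) for the mean prediction (notation introduced here for illustration only), the mutual information equals the average Kullback-Leibler divergence from each member to the mean:

\[I(y, \theta) = \frac{1}{K}\sum_{k=1}^{K} \mathrm{KL}\big(\mathbf{p}^{(k)} \,\|\, \bar{\mathbf{p}}\big) \geq 0\]

It is therefore zero exactly when all members predict the same distribution, and it grows as their predictions diverge.
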
Variables:
  • ensemble_only (bool) – Requires ensemble predictions.

  • input_type (OODCriterionInputType) – Expected input type is estimated probabilities.

forward(inputs)

Compute mutual information from ensemble predictions.

Parameters:

inputs (Tensor) – Tensor of ensemble probabilities with shape (ensemble_size, batch_size, num_classes).

Returns:

Mutual information for each sample.

Return type:

Tensor
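
The computation described above can be sketched in plain PyTorch. This is a minimal standalone illustration of the formula, not the library's source: the function name `mutual_information`, the `eps` clamp, the use of natural logarithms, and the `(batch_size,)` output shape are all assumptions made for this example.

```python
import torch


def mutual_information(probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Illustrative re-implementation (hypothetical helper, not the library's code).

    probs: (ensemble_size, batch_size, num_classes), each row sums to 1.
    Returns a (batch_size,) tensor of mutual information values in nats.
    """

    def entropy(p: torch.Tensor) -> torch.Tensor:
        # Shannon entropy over the class axis; the clamp avoids log(0).
        return -(p * p.clamp_min(eps).log()).sum(dim=-1)

    mean_probs = probs.mean(dim=0)                  # (batch_size, num_classes)
    entropy_of_mean = entropy(mean_probs)           # (batch_size,)
    mean_of_entropies = entropy(probs).mean(dim=0)  # (batch_size,)
    return entropy_of_mean - mean_of_entropies


# Two ensemble members, two samples, two classes.
probs = torch.tensor([
    [[0.5, 0.5], [1.0, 0.0]],  # member 1
    [[0.5, 0.5], [0.0, 1.0]],  # member 2
])
mi = mutual_information(probs)
# Sample 0: members agree (both uniform), so the value is ~0.
# Sample 1: members are confident but contradict each other, so the value is ~log(2).
```

Note that sample 0 has maximal predictive entropy but zero mutual information: the score responds to disagreement between members, not to how flat each individual prediction is.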