Source code for torch_uncertainty.ood_criteria
from abc import ABC, abstractmethod
from enum import Enum
import torch
from torch import Tensor, nn
from torch_uncertainty.metrics import MutualInformation, VariationRatio
class OODCriterionInputType(Enum):
    """Kind of model output that an OOD (Out-of-Distribution) criterion consumes.

    Attributes:
        LOGIT: The criterion is fed logits (pre-softmax values).
        PROB: The criterion is fed probabilities (post-softmax values), also
            called likelihoods.
        ESTIMATOR_PROB: The criterion is fed probabilities estimated by an
            ensemble or another probabilistic model.
        POST_PROCESSING: The criterion is fed the prediction score produced by
            a post-processing method.
    """

    # Explicit integer values are kept stable so lookups by value keep working.
    LOGIT = 1
    PROB = 2
    ESTIMATOR_PROB = 3
    POST_PROCESSING = 4
[docs]
class TUOODCriterion(ABC, nn.Module):
    # Kind of model output the criterion consumes; concrete subclasses assign it.
    input_type: OODCriterionInputType
    # Applicability flags, both off by default; subclasses switch one on.
    # NOTE(review): `single_only` presumably restricts a criterion to
    # single-model outputs, mirroring `ensemble_only` — confirm with callers.
    single_only = False
    ensemble_only = False

    def __init__(self) -> None:
        """Abstract base class for Out-of-Distribution (OOD) criteria.

        Provides the shared interface of every OOD detection criterion: a
        criterion is an ``nn.Module`` whose `forward` maps model outputs to a
        per-sample OOD score. Subclasses must implement `forward`.

        Attributes:
            input_type (OODCriterionInputType): Type of input expected by the
                criterion.
            single_only (bool): Class-level flag, ``False`` by default.
            ensemble_only (bool): Whether the criterion requires ensemble
                outputs, ``False`` by default.
        """
        super().__init__()

    @abstractmethod
    def forward(self, inputs: Tensor) -> Tensor:
        """Score the given model outputs.

        Args:
            inputs (Tensor): The input tensor representing model outputs.

        Returns:
            Tensor: OOD score computed according to the criterion.
        """
[docs]
class MaxLogitCriterion(TUOODCriterion):
    single_only = True
    input_type = OODCriterionInputType.LOGIT

    def __init__(self) -> None:
        """OOD criterion based on the maximum logit value.

        This criterion computes the negative of the highest logit value across
        the class dimension. Lower maximum logits indicate greater uncertainty,
        so a higher score means "more likely OOD".

        Attributes:
            input_type (OODCriterionInputType): Expected input type is logits.
        """
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the negative of the maximum logit value.

        Args:
            inputs (Tensor): Logits of shape (batch_size, N, num_classes); dim 1
                is averaged away before the maximum over classes is taken
                (presumably the estimator axis — confirm against the caller).
                A 2-D tensor of shape (batch_size, num_classes) is also
                accepted and treated as having a single entry along dim 1.

        Returns:
            Tensor: Negative of the maximum logit value for each sample, of
            shape (batch_size,).
        """
        if inputs.ndim == 2:
            # Without this, mean(dim=1) would average over the classes and the
            # following max would collapse the whole batch into one scalar.
            inputs = inputs.unsqueeze(1)
        return -inputs.mean(dim=1).max(dim=-1).values
[docs]
class EnergyCriterion(TUOODCriterion):
    single_only = True
    input_type = OODCriterionInputType.LOGIT

    def __init__(self) -> None:
        r"""OOD criterion based on the energy function.

        This criterion computes the negative log-sum-exp of the logits.
        Higher energy values indicate greater uncertainty.

        .. math::
            E(\mathbf{z}) = -\log\left(\sum_{i=1}^{C} \exp(z_i)\right)

        where :math:`\mathbf{z} = [z_1, z_2, \dots, z_C]` is the logit vector.

        Attributes:
            input_type (OODCriterionInputType): Expected input type is logits.
        """
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the energy score, i.e. the negative log-sum-exp of the logits.

        Args:
            inputs (Tensor): Logits of shape (batch_size, N, num_classes); dim 1
                is averaged away before the log-sum-exp over classes is taken
                (presumably the estimator axis — confirm against the caller).
                A 2-D tensor of shape (batch_size, num_classes) is also
                accepted and treated as having a single entry along dim 1.

        Returns:
            Tensor: Energy for each sample, of shape (batch_size,).
        """
        if inputs.ndim == 2:
            # Without this, mean(dim=1) would average over the classes and the
            # following logsumexp would collapse the whole batch into one scalar.
            inputs = inputs.unsqueeze(1)
        return -inputs.mean(dim=1).logsumexp(dim=-1)
[docs]
class MaxSoftmaxCriterion(TUOODCriterion):
    input_type = OODCriterionInputType.PROB

    def __init__(self) -> None:
        r"""OOD criterion based on maximum softmax probability.

        The score is the negated largest probability of the predictive
        distribution: the less confident the top prediction, the higher the
        score. Probabilities are also called likelihoods in a more formal
        context.

        .. math::
            \text{score} = -\max_{i}(p_i)

        where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector.

        Attributes:
            input_type (OODCriterionInputType): Expected input type is probabilities.
        """
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        """Return the negated maximum over the last dimension of ``inputs``.

        Args:
            inputs (Tensor): Tensor of probabilities with shape (batch_size, num_classes).

        Returns:
            Tensor: Negative of the highest softmax probability for each sample.
        """
        top_prob = inputs.max(dim=-1).values
        return -top_prob
[docs]
class PostProcessingCriterion(MaxSoftmaxCriterion):
    """OOD criterion applied to the score of a post-processing method.

    Inherits `forward` from `MaxSoftmaxCriterion` unchanged (negated maximum
    over the last dimension); only the declared input type differs.
    """

    input_type = OODCriterionInputType.POST_PROCESSING
[docs]
class EntropyCriterion(TUOODCriterion):
    input_type = OODCriterionInputType.ESTIMATOR_PROB

    def __init__(self) -> None:
        r"""OOD criterion based on entropy.

        This criterion computes the mean entropy of the predicted probability distribution.
        Higher entropy values indicate greater uncertainty.

        .. math::
            H(\mathbf{p}) = -\sum_{i=1}^{C} p_i \log(p_i)

        where :math:`\mathbf{p} = [p_1, p_2, \dots, p_C]` is the probability vector.

        Attributes:
            input_type (OODCriterionInputType): Expected input type is estimated probabilities.
        """
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute the entropy of the predicted probability distribution.

        The entropy is computed over the last (class) dimension and then
        averaged over dim 1.

        Args:
            inputs (Tensor): Estimated probabilities of shape
                (batch_size, N, num_classes), where dim 1 is averaged over
                (presumably the estimator axis — confirm against the caller).
                A 2-D tensor of shape (batch_size, num_classes) is also
                accepted; its per-sample entropy is returned as is.

        Returns:
            Tensor: Mean entropy value for each sample.
        """
        per_distribution = torch.special.entr(inputs).sum(dim=-1)
        if per_distribution.ndim == 1:
            # 2-D input: there is no dim 1 left to average over, and calling
            # mean(dim=1) here would raise an IndexError.
            return per_distribution
        return per_distribution.mean(dim=1)
[docs]
class VariationRatioCriterion(TUOODCriterion):
    ensemble_only = True
    input_type = OODCriterionInputType.ESTIMATOR_PROB

    def __init__(self) -> None:
        r"""OOD criterion based on variation ratio.

        The score is the variation ratio of the ensemble predictions; a higher
        variation ratio indicates greater uncertainty.

        Given ensemble predictions where :math:`n_{\text{mode}}` is the count of the most frequently
        predicted class among :math:`K` predictions, the variation ratio is computed as:

        .. math::
            \text{VR} = 1 - \frac{n_{\text{mode}}}{K}

        Attributes:
            ensemble_only (bool): Requires ensemble predictions.
            input_type (OODCriterionInputType): Expected input type is estimated probabilities.
        """
        super().__init__()
        # Unreduced, non-probabilistic variation-ratio metric does the actual work.
        self.vr_metric = VariationRatio(reduction="none", probabilistic=False)

    def forward(self, inputs: Tensor) -> Tensor:
        """Compute variation ratio from ensemble predictions.

        The first two axes of ``inputs`` are swapped before the tensor is
        handed to the underlying ``VariationRatio`` metric.

        Args:
            inputs (Tensor): Tensor of ensemble probabilities with three axes
                (estimators, samples and classes).
                NOTE(review): the previous docstring gave the layout as
                (ensemble_size, batch_size, num_classes); confirm the axis
                order against the layout ``VariationRatio`` expects.

        Returns:
            Tensor: Variation ratio for each sample.
        """
        swapped = inputs.transpose(0, 1)
        return self.vr_metric(swapped)
def get_ood_criterion(
    ood_criterion: type[TUOODCriterion] | TUOODCriterion | str,
) -> TUOODCriterion:
    """Get an OOD criterion instance based on a string identifier or class type.

    Args:
        ood_criterion (str, type or TUOODCriterion): A string identifier for a
            predefined OOD criterion (``"logit"``, ``"energy"``, ``"msp"``,
            ``"post_processing"``, ``"entropy"``, ``"mutual_information"`` or
            ``"variation_ratio"``), a subclass of `TUOODCriterion` (instantiated
            without arguments), or an already-built `TUOODCriterion` instance
            (returned unchanged).

    Returns:
        TUOODCriterion: An instance of the requested OOD criterion.

    Raises:
        ValueError: If the input string or class type is invalid.
    """
    if isinstance(ood_criterion, str):
        if ood_criterion == "logit":
            return MaxLogitCriterion()
        if ood_criterion == "energy":
            return EnergyCriterion()
        if ood_criterion == "msp":
            return MaxSoftmaxCriterion()
        if ood_criterion == "post_processing":
            return PostProcessingCriterion()
        if ood_criterion == "entropy":
            return EntropyCriterion()
        if ood_criterion == "mutual_information":
            return MutualInformationCriterion()
        if ood_criterion == "variation_ratio":
            return VariationRatioCriterion()
        # The message lists every identifier accepted above, including
        # 'post_processing'.
        raise ValueError(
            "The OOD criterion must be one of 'msp', 'logit', 'energy', 'post_processing',"
            f" 'entropy', 'mutual_information' or 'variation_ratio'. Got {ood_criterion}."
        )
    if isinstance(ood_criterion, TUOODCriterion):
        return ood_criterion
    # Only instantiate genuine criterion classes; an unrelated class (e.g. `int`)
    # falls through to the ValueError below instead of being silently built.
    if isinstance(ood_criterion, type) and issubclass(ood_criterion, TUOODCriterion):
        return ood_criterion()
    raise ValueError(
        f"The OOD criterion should be a string or a subclass of TUOODCriterion. Got {ood_criterion}."
    )