EnergyCriterion

class torch_uncertainty.ood_criteria.EnergyCriterion

OOD criterion based on the energy function.

This criterion computes the negative log-sum-exp of the logits. Higher energy values indicate greater uncertainty, i.e. the sample is more likely to be out-of-distribution.

\[E(\mathbf{z}) = -\log\left(\sum_{i=1}^{C} \exp(z_i)\right)\]

where \(\mathbf{z} = [z_1, z_2, \dots, z_C]\) is the logit vector and \(C\) is the number of classes.
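To make the formula concrete, here is a minimal pure-Python sketch of \(E(\mathbf{z})\) for a single logit vector. The helper `energy` is hypothetical and not part of torch_uncertainty. It uses the standard max-shift identity \(\log\sum_i e^{z_i} = m + \log\sum_i e^{z_i - m}\) with \(m = \max_i z_i\), so large logits do not overflow `exp`. It also illustrates the claim above: a flat (uncertain) logit vector gets a higher energy than a confident one.

```python
import math


def energy(logits):
    """Energy E(z) = -log(sum_i exp(z_i)) for one logit vector (hypothetical helper).

    Shifting by m = max(z) before exponentiating keeps every exp() argument <= 0,
    which avoids overflow without changing the result.
    """
    m = max(logits)
    return -(m + math.log(sum(math.exp(z - m) for z in logits)))


confident = [10.0, 0.0, 0.0]  # one dominant logit
uncertain = [0.0, 0.0, 0.0]   # flat logits

print(round(energy(confident), 4))  # -10.0001 (low energy)
print(round(energy(uncertain), 4))  # -1.0986, i.e. -log(3) (higher energy)

# Large logits: a naive exp(1000.0) would overflow, the shifted form does not.
print(round(energy([1000.0, 999.0]), 4))  # -1000.3133
```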

Variables:

input_type (OODCriterionInputType) – Expected input type is logits.

forward(inputs)

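Compute the energy score \(E(\mathbf{z})\), i.e. the negative log-sum-exp of the logits.

Parameters:

inputs (Tensor) – Tensor of logits with shape (batch_size, num_classes).

Returns:

Energy score (negative log-sum-exp of the logits) for each sample, with shape (batch_size,). Higher values indicate greater uncertainty.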

Return type:

Tensor
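The batched computation can be sketched in plain PyTorch as below. This is an illustrative re-implementation, not the library source: `energy_scores` and the `threshold` value are hypothetical, and the sketch assumes the returned score is \(E(\mathbf{z})\) from the formula above, so that larger scores mean "more likely OOD". `torch.logsumexp` reduces over the class dimension in a numerically stable way, turning a `(batch_size, num_classes)` tensor into one score per sample.

```python
import torch
from torch import Tensor


def energy_scores(inputs: Tensor) -> Tensor:
    """Hypothetical sketch: per-sample energy -logsumexp(z) over the class dimension.

    Expects logits of shape (batch_size, num_classes); returns shape (batch_size,).
    """
    if inputs.dim() != 2:
        raise ValueError("expected logits of shape (batch_size, num_classes)")
    return -torch.logsumexp(inputs, dim=-1)


logits = torch.tensor([
    [8.0, 0.5, -1.0],  # confident sample: energy close to -8
    [0.2, 0.1, 0.3],   # nearly uniform sample: energy close to -1.3
])
scores = energy_scores(logits)
print(scores.shape)  # torch.Size([2])

# Hypothetical decision rule: flag samples whose energy exceeds a threshold.
threshold = -5.0
is_ood = scores > threshold  # first sample kept, second flagged as OOD
```

In practice the threshold would be chosen on held-out in-distribution data; the value here is only for illustration.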