ClasswiseCalibrationError

class torch_uncertainty.metrics.classification.ClasswiseCalibrationError(num_classes, num_bins=15, norm='l1', reduction='mean', **kwargs)[source]

Compute the Classwise Expected Calibration Error (ECE).

The Classwise ECE measures the expected calibration error for each class independently in a one-vs-all manner, and then reduces the scores. It is used to evaluate the calibration of individual classes, where a lower score indicates better calibration quality.

Parameters:
  • num_classes (int) – Number of classes.

  • num_bins (int, optional) – Number of calibration bins. Defaults to 15.

  • norm (Literal["l1", "l2", "max"]) – Norm used to aggregate the calibration gaps across bins. Defaults to 'l1'.

  • reduction (Literal["mean", "sum", "none"] | None) –

    Determines how to reduce the score across the classes:

    • 'mean' [default]: Averages the ECE across classes.

    • 'sum': Sums the ECE across classes.

    • 'none' or None: Returns the ECE for each class.

  • kwargs – Additional keyword arguments, see Advanced metric settings.

Inputs:
  • probs: \((B, C)\)

    Predicted probabilities for each class.

  • target: \((B)\) or \((B, C)\)

    Ground truth class labels or one-hot encoded targets.

where:

\(B\) is the batch size, \(C\) is the number of classes.
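As a rough illustration of the one-vs-all computation described above, the per-class ECE can be sketched in plain PyTorch. This is a simplified sketch, not the library implementation: it assumes equal-width confidence bins, the 'l1' norm, and 'mean' reduction, and `classwise_ece_sketch` is a hypothetical helper name.

```python
import torch


def classwise_ece_sketch(
    probs: torch.Tensor, target: torch.Tensor, num_bins: int = 15
) -> torch.Tensor:
    """Simplified one-vs-all ECE per class (l1 norm, equal-width bins, mean reduction)."""
    num_classes = probs.shape[1]
    onehot = torch.nn.functional.one_hot(target, num_classes).float()
    boundaries = torch.linspace(0, 1, num_bins + 1)[1:-1]
    eces = []
    for c in range(num_classes):
        conf = probs[:, c]      # confidence that each sample belongs to class c
        correct = onehot[:, c]  # 1.0 if the true label is c, else 0.0
        bin_idx = torch.bucketize(conf, boundaries)
        ece = torch.tensor(0.0)
        for b in range(num_bins):
            mask = bin_idx == b
            if mask.any():
                # |average confidence - empirical accuracy| weighted by bin mass
                gap = (conf[mask].mean() - correct[mask].mean()).abs()
                ece = ece + mask.float().mean() * gap
        eces.append(ece)
    # 'mean' reduction: average the per-class scores
    return torch.stack(eces).mean()
```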

Warning

Ensure that the probabilities in probs are normalized to sum to one over the class dimension before passing them to the metric.
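For instance, raw model logits can be normalized with a softmax before being passed to the metric (a minimal sketch; the logit values here are made up):

```python
import torch

# Hypothetical raw model outputs (logits), not yet probabilities.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])

# Softmax over the class dimension makes each row sum to one,
# as required before calling metric.update(probs, target).
probs = logits.softmax(dim=-1)
```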

Raises:
  • ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

  • ValueError – If norm is not one of 'l1', 'l2', or 'max'.

References

[1] Kull et al. Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration. In NeurIPS, 2019.

Examples

>>> import torch
>>> from torch_uncertainty.metrics.classification import ClasswiseCalibrationError
>>> # Example: Multi-Class Classification
>>> probs = torch.tensor([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])
>>> target = torch.tensor([0, 2])
>>> metric = ClasswiseCalibrationError(num_classes=3, reduction="mean")
>>> metric.update(probs, target)
>>> score = metric.compute()
>>> print(score)
tensor(...)
compute()[source]

Compute the final Classwise ECE based on inputs passed to update.

Returns:

The final value(s) for the Classwise ECE.

Return type:

Tensor

update(probs, target)[source]

Update the metric state with a new batch of probabilities and targets.

Parameters:
  • probs (Tensor) – A probability tensor of shape (batch, num_classes).

  • target (Tensor) – A tensor of ground truth labels of shape (batch, num_classes) or (batch).