CalibrationError
- class torch_uncertainty.metrics.classification.CalibrationError(task, adaptive=False, num_bins=10, norm='l1', num_classes=None, ignore_index=None, validate_args=True, **kwargs)
Computes the Calibration Error for classification tasks.
This metric evaluates how well a model’s predicted probabilities match the observed frequency of correct predictions. Calibration is crucial in assessing the reliability of probabilistic predictions, especially when they feed downstream decision-making tasks.
Given top-class confidences \(\hat{p}_i\) and accuracies \(a_i = \mathbf{1}[\hat{y}_i = y_i]\), the \(N\) samples are assigned to \(M\) bins \(B_1, \dots, B_M\) uniformly spaced in \([0, 1]\). Three norms are available:
Expected Calibration Error (ECE):
\[\text{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{N} \left| \operatorname{acc}(B_m) - \operatorname{conf}(B_m) \right|\]
Maximum Calibration Error (MCE):
\[\text{MCE} = \max_{m} \left| \operatorname{acc}(B_m) - \operatorname{conf}(B_m) \right|\]
Root Mean Square Calibration Error (RMSCE):
\[\text{RMSCE} = \sqrt{\sum_{m=1}^{M} \frac{|B_m|}{N} \left( \operatorname{acc}(B_m) - \operatorname{conf}(B_m) \right)^2}\]
where \(\operatorname{acc}(B_m) = \tfrac{1}{|B_m|}\sum_{i \in B_m} a_i\) is the fraction of correct predictions in bin \(m\), \(\operatorname{conf}(B_m) = \tfrac{1}{|B_m|}\sum_{i \in B_m} \hat{p}_i\) is the mean predicted confidence in bin \(m\), and \(|B_m|/N\) is the fraction of total samples in bin \(m\).
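The three formulas above can be sketched in plain Python. This is an illustrative re-implementation, not the library's code: the function name `calibration_errors`, the bin-assignment rule, and the sample data are assumptions made for the example. The three returned values presumably correspond to the `"l1"`, `"max"`, and `"l2"` settings of `norm`, in that order.

```python
import math


def calibration_errors(confidences, accuracies, num_bins=10):
    """Return (ECE, MCE, RMSCE) using uniform bins over [0, 1]."""
    n = len(confidences)
    # Per-bin running totals: sample count, confidence sum, accuracy sum.
    counts = [0] * num_bins
    conf_sums = [0.0] * num_bins
    acc_sums = [0.0] * num_bins
    for p, a in zip(confidences, accuracies):
        # Assumed rule: bin m covers [m/M, (m+1)/M); p == 1.0 goes to the last bin.
        m = min(int(p * num_bins), num_bins - 1)
        counts[m] += 1
        conf_sums[m] += p
        acc_sums[m] += a

    ece, mce, weighted_sq = 0.0, 0.0, 0.0
    for count, conf_sum, acc_sum in zip(counts, conf_sums, acc_sums):
        if count == 0:
            continue  # empty bins contribute nothing
        gap = abs(acc_sum / count - conf_sum / count)  # |acc(B_m) - conf(B_m)|
        weight = count / n  # |B_m| / N
        ece += weight * gap
        mce = max(mce, gap)
        weighted_sq += weight * gap**2
    return ece, mce, math.sqrt(weighted_sq)


# Made-up top-class confidences and correctness indicators.
confidences = [0.9, 0.85, 0.7, 0.65, 0.3, 0.25]
accuracies = [1, 1, 1, 0, 0, 1]
ece, mce, rmsce = calibration_errors(confidences, accuracies, num_bins=5)
print(f"ECE={ece:.3f}  MCE={mce:.3f}  RMSCE={rmsce:.3f}")
```

With these six samples, three bins are occupied, each holding a third of the data, so the ECE is the plain average of the three per-bin gaps and the MCE is the largest of them.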
Bins are constructed either uniformly in the range \([0, 1]\) or adaptively (if adaptive=True).
- Parameters:
  - task – Specifies the task type, either "binary" or "multiclass".
  - adaptive – Whether to use adaptive binning. Defaults to False.
  - num_bins – Number of bins to divide the probability space. Defaults to 10.
  - norm – Specifies the type of norm to use: "l1", "l2", or "max". Defaults to "l1".
  - num_classes – Number of classes for "multiclass" tasks. Required when task is "multiclass". Defaults to None.
  - ignore_index – Index to ignore during calculations. Defaults to None.
  - validate_args – Whether to validate input arguments. Defaults to True.
  - **kwargs – Additional keyword arguments for the metric.
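To make the `adaptive` option more concrete, here is a minimal sketch of one common form of adaptive binning, in which samples are sorted by confidence and split into bins holding roughly equal numbers of samples. Whether the library uses exactly this scheme is an assumption; the helper names `equal_mass_bins` and `adaptive_ece` and the data are hypothetical.

```python
def equal_mass_bins(confidences, accuracies, num_bins):
    """Split samples, sorted by confidence, into bins of near-equal size."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    n = len(order)
    bins = []
    for m in range(num_bins):
        # Integer arithmetic spreads any remainder across the bins.
        start = m * n // num_bins
        stop = (m + 1) * n // num_bins
        idx = order[start:stop]
        if idx:  # skip empty bins when num_bins > n
            bins.append(
                ([confidences[i] for i in idx], [accuracies[i] for i in idx])
            )
    return bins


def adaptive_ece(confidences, accuracies, num_bins):
    """L1 calibration error over equal-mass bins (assumed adaptive scheme)."""
    n = len(confidences)
    total = 0.0
    for confs, accs in equal_mass_bins(confidences, accuracies, num_bins):
        gap = abs(sum(accs) / len(accs) - sum(confs) / len(confs))
        total += len(confs) / n * gap
    return total


# Made-up data: most confidences cluster near 1.0, which uniform bins
# would pack into a single bin.
confidences = [0.99, 0.97, 0.95, 0.93, 0.6, 0.55]
accuracies = [1, 1, 1, 0, 1, 0]
print(adaptive_ece(confidences, accuracies, num_bins=3))
```

Equal-mass bins keep every bin populated when confidences are concentrated in a narrow range, at the cost of bin edges that depend on the data.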
Example:
    import torch

    from torch_uncertainty.metrics.classification.calibration_error import (
        CalibrationError,
    )

    # Example for binary classification
    predicted_probs = torch.tensor([0.9, 0.8, 0.3, 0.2])
    true_labels = torch.tensor([1, 1, 0, 0])

    metric = CalibrationError(
        task="binary",
        num_bins=5,
        norm="l1",
        adaptive=False,
    )

    calibration_error = metric(predicted_probs, true_labels)
    print(f"Calibration Error: {calibration_error}")
    # Output: Calibration Error: 0.199
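The reported value can be checked by hand with a short standard-library sketch. It assumes that binary predictions are thresholded at 0.5 and that the confidence is the probability of the predicted class; the library may handle the binary case differently internally, so treat this as an arithmetic check rather than a description of its code.

```python
probs = [0.9, 0.8, 0.3, 0.2]  # predicted probability of the positive class
labels = [1, 1, 0, 0]
num_bins = 5

# Assumed conversion to top-class confidences and correctness indicators.
preds = [int(p > 0.5) for p in probs]
confidences = [p if y_hat == 1 else 1 - p for p, y_hat in zip(probs, preds)]
accuracies = [int(y_hat == y) for y_hat, y in zip(preds, labels)]

# Group samples into uniform bins over [0, 1].
bins = {}
for p, a in zip(confidences, accuracies):
    m = min(int(p * num_bins), num_bins - 1)
    bins.setdefault(m, []).append((p, a))

# L1 calibration error: weighted mean of |acc(B_m) - conf(B_m)|.
ece = 0.0
for members in bins.values():
    conf = sum(p for p, _ in members) / len(members)
    acc = sum(a for _, a in members) / len(members)
    ece += len(members) / len(probs) * abs(acc - conf)

print(f"ECE = {ece:.3f}")
```

Every prediction here is correct while every confidence is below 1, so the gap has the same sign in every bin and the result reduces to the mean of \(1 - \hat{p}_i\), that is \((0.1 + 0.2 + 0.3 + 0.2)/4 = 0.2\), however the bin edges fall. This agrees with the printed value up to floating-point rounding.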
Note
Bins are either uniformly distributed in \([0, 1]\) or adaptively sized (if adaptive=True).
Warning
If task="multiclass", num_classes must be an integer; otherwise, a TypeError is raised.
References
[1] Naeini et al. Obtaining well calibrated probabilities using Bayesian binning. In AAAI, 2015.
See also
See the torchmetrics CalibrationError for details. This implementation wraps that metric and provides improved plotting functionality.