SmoothCalibrationError
- class torch_uncertainty.metrics.classification.SmoothCalibrationError(kernel_type='logit', bandwidth='auto', eps=0.001, mesh_pts=200, refine_steps=10, **kwargs)
Smooth Expected Calibration Error (SmECE).
This metric implements the kernel-density-estimation-based ECE proposed by Błasiok & Nakkiran (2023). It addresses limitations of the standard binned ECE, such as bin-edge effects and poor resolution for overconfident models, by using a continuous kernel and an adaptive bandwidth selection strategy. The metric is computed on the top label.
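To make the definition concrete: as we read Błasiok & Nakkiran, for top-label confidences $f_i$ and correctness indicators $y_i \in \{0, 1\}$, the smoothed error at bandwidth $h$ is $\mathrm{smECE}_h = \int_0^1 \left| \frac{1}{n} \sum_{i=1}^{n} K_h(t, f_i)\,(y_i - f_i) \right| dt$. The dependency-free sketch below evaluates this integral on a uniform grid with the trapezoidal rule, using a plain Gaussian kernel in probability space (no reflection, no logit transform). The function names are hypothetical, and this is not the torch_uncertainty implementation.

```python
import math


def gaussian(x, h):
    """Normalized Gaussian density with standard deviation h, evaluated at x."""
    return math.exp(-0.5 * (x / h) ** 2) / (h * math.sqrt(2.0 * math.pi))


def smooth_ece_fixed(confidences, correct, h, mesh_pts=200):
    """Sketch of SmECE at a fixed bandwidth h (plain Gaussian, no reflection).

    confidences: top-label confidences f_i in [0, 1].
    correct: y_i = 1 if the top label matched the target, else 0.
    """
    n = len(confidences)
    residuals = [y - f for f, y in zip(confidences, correct)]
    step = 1.0 / (mesh_pts - 1)
    total = 0.0
    for k in range(mesh_pts):
        t = k * step
        # Kernel-weighted sum of residuals y_i - f_i at grid point t, divided by n.
        smoothed = sum(gaussian(t - f, h) * r for f, r in zip(confidences, residuals)) / n
        # Trapezoidal rule: the two endpoints get half weight.
        weight = 0.5 if k in (0, mesh_pts - 1) else 1.0
        total += weight * abs(smoothed) * step
    return total
```

For example, four samples predicted at confidence 0.9 and all wrong give a value close to 0.9 at h = 0.01, while residuals that cancel exactly give 0. Because the plain Gaussian leaks mass past 0 and 1, this simplified version underestimates the error for confidences near the boundaries unless $h$ is small, which is the boundary bias the 'reflected' kernel addresses.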
- Parameters:
kernel_type (str, optional) – The kernel to use. Choose between:
  - 'logit': Applies a Gaussian kernel in log-odds space. This effectively yields an adaptive bandwidth (in probability space) that is narrower near 1.0, making it well suited to modern overconfident models. (Default)
  - 'reflected': Applies a Gaussian kernel in probability space with reflections at 0 and 1 to prevent boundary bias. Note that relplot's original implementation uses 'reflected' as the default.
bandwidth (Literal['auto'] | float, optional) – The kernel bandwidth $h$. If set to 'auto', a fixed-point binary search selects a bandwidth consistent with the error level. Defaults to 'auto'.
eps (float, optional) – The tolerance for the binary search when bandwidth is 'auto'. Defaults to 0.001.
mesh_pts (int, optional) – The base number of points for the grid discretization. The actual number may be higher depending on the bandwidth. Defaults to 200.
refine_steps (int, optional) – The number of binary search iterations for the 'auto' bandwidth. Defaults to 10.
**kwargs – Additional arguments for the torchmetrics.Metric base class.
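The 'auto' bandwidth and the 'reflected' kernel can be sketched together. The paper selects the bandwidth $h^*$ at which $\mathrm{smECE}_{h^*} = h^*$; the sketch below finds it by bisection, assuming (as in the paper's analysis) that $\mathrm{smECE}_h$ is non-increasing in $h$. How eps, mesh_pts and refine_steps map onto the search here (eps as both the lower end of the search interval and the stopping width, and a grid refined to roughly four points per bandwidth) is our assumption, not documented library behavior. The reflection keeps only first-order mirror images, and all function names are hypothetical.

```python
import math


def reflected_kernel(t, f, h):
    """Gaussian kernel with first-order reflections at 0 and 1.

    Mirror images at -f and 2 - f put back mass that a plain Gaussian would
    leak past the boundaries; higher-order images are ignored in this sketch.
    """
    norm = 1.0 / (h * math.sqrt(2.0 * math.pi))
    return norm * sum(math.exp(-0.5 * ((t - c) / h) ** 2) for c in (f, -f, 2.0 - f))


def smooth_ece_reflected(confidences, correct, h, mesh_pts=200):
    """SmECE at bandwidth h with the reflected kernel (trapezoidal rule on [0, 1])."""
    n = len(confidences)
    # Hypothetical refinement rule: about four grid points per bandwidth.
    pts = max(mesh_pts, math.ceil(4.0 / h) + 1)
    step = 1.0 / (pts - 1)
    total = 0.0
    for k in range(pts):
        t = k * step
        smoothed = sum(
            reflected_kernel(t, f, h) * (y - f) for f, y in zip(confidences, correct)
        ) / n
        weight = 0.5 if k in (0, pts - 1) else 1.0
        total += weight * abs(smoothed) * step
    return total


def smooth_ece_auto(confidences, correct, eps=1e-3, mesh_pts=200, refine_steps=10):
    """Bisect for the fixed point h = SmECE_h; return (SmECE, h)."""
    lo, hi = eps, 1.0
    for _ in range(refine_steps):
        if hi - lo < eps:  # interval already narrower than the tolerance
            break
        mid = 0.5 * (lo + hi)
        # Assumes SmECE_h is non-increasing in h, so SmECE_h - h is decreasing:
        # if the error still exceeds the bandwidth, the fixed point lies above mid.
        if smooth_ece_reflected(confidences, correct, mid, mesh_pts) > mid:
            lo = mid
        else:
            hi = mid
    h = 0.5 * (lo + hi)
    return smooth_ece_reflected(confidences, correct, h, mesh_pts), h
```

For instance, four samples predicted at confidence 0.9 and all wrong settle at a bandwidth of around 0.8, where the error and the bandwidth coincide. A strongly miscalibrated model therefore gets heavy smoothing, while a well-calibrated one is evaluated at a fine scale.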
Note
In the multiclass case, this metric evaluates the calibration of the maximum probability (top-label calibration). In the binary case, it evaluates the calibration of the predicted class, i.e., it uses $\max(p, 1-p)$, where $p$ is the predicted probability of the positive class.
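The top-label reduction described in this note can be sketched in plain Python. The function below is hypothetical; in particular, treating inputs as logits whenever any value falls outside [0, 1] is an assumption modeled on the convention torchmetrics documents for its classification metrics, not documented behavior of this class. Binary predictions are expected as a flat list, so an (N, 1) tensor would need to be flattened first.

```python
import math


def _softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def top_label(preds, target):
    """Return (confidences, correct) for top-label calibration.

    preds: list of per-class rows (multiclass, shape (N, C)) or a flat list
    of scalars (binary, shape (N,)). Assumption: values outside [0, 1] are logits.
    """
    confidences, correct = [], []
    if isinstance(preds[0], (list, tuple)):  # multiclass
        is_logit = any(v < 0.0 or v > 1.0 for row in preds for v in row)
        for row, y in zip(preds, target):
            probs = _softmax(row) if is_logit else list(row)
            k = max(range(len(probs)), key=probs.__getitem__)  # argmax
            confidences.append(probs[k])
            correct.append(1 if k == y else 0)
    else:  # binary
        is_logit = any(v < 0.0 or v > 1.0 for v in preds)
        for p, y in zip(preds, target):
            if is_logit:
                p = 1.0 / (1.0 + math.exp(-p))  # sigmoid
            pred_class = 1 if p >= 0.5 else 0
            confidences.append(max(p, 1.0 - p))  # confidence of the predicted class
            correct.append(1 if pred_class == y else 0)
    return confidences, correct
```

In the binary branch, a prediction of 0.3 for a negative sample counts as a correct prediction of class 0 with confidence 0.7, which is what $\max(p, 1-p)$ expresses.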
Note
This implementation was tested on one use case, where it produced the same values as relplot's original implementation to six significant figures.
References
Błasiok, J. & Nakkiran, P. Smooth ECE: Principled Reliability Diagrams. ICLR 2024.
- compute()
Compute the Smooth ECE based on the accumulated state.
- Returns:
The scalar SmECE value.
- Return type:
Tensor
- update(preds, target)
Update the state with predictions and targets.
- Parameters:
preds (Tensor) – Predictions from the model.
  - Multiclass: shape (N, C) (logits or probabilities).
  - Binary: shape (N,) or (N, 1) (logits or probabilities).
target (Tensor) – Ground-truth labels.
  - Multiclass: shape (N,) containing class indices.
  - Binary: shape (N,) containing 0 or 1.
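The update()/compute() split suggests that samples are accumulated across batches and SmECE is evaluated once over all of them. The hypothetical accumulator below (plain Gaussian kernel at a fixed bandwidth, not the library's code) illustrates why that matters: the absolute value sits inside the integral, so SmECE is not a sum over samples, and averaging per-batch values generally gives a different number.

```python
import math


class SmECEStateSketch:
    """Hypothetical accumulator mirroring the update()/compute() split."""

    def __init__(self, h=0.05, mesh_pts=200):
        self.h = h
        self.mesh_pts = mesh_pts
        self.confidences = []  # top-label confidences f_i
        self.correct = []      # 1 if the top label matched the target, else 0

    def update(self, confidences, correct):
        # Only store the samples; the error is evaluated later in compute().
        self.confidences.extend(confidences)
        self.correct.extend(correct)

    def compute(self):
        # SmECE over all accumulated samples (trapezoidal rule on [0, 1]).
        n, h, m = len(self.confidences), self.h, self.mesh_pts
        norm = 1.0 / (h * math.sqrt(2.0 * math.pi))
        step = 1.0 / (m - 1)
        total = 0.0
        for k in range(m):
            t = k * step
            smoothed = sum(
                norm * math.exp(-0.5 * ((t - f) / h) ** 2) * (y - f)
                for f, y in zip(self.confidences, self.correct)
            ) / n
            weight = 0.5 if k in (0, m - 1) else 1.0
            total += weight * abs(smoothed) * step
        return total
```

Feeding the same samples in one batch or in several yields the same compute() result, whereas the mean of single-batch results generally does not, because residuals of opposite sign from different batches cancel only when they are smoothed together.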