SmoothCalibrationError
- class torch_uncertainty.metrics.classification.SmoothCalibrationError(kernel_type='logit', bandwidth='auto', eps=0.001, mesh_pts=200, refine_steps=10, **kwargs)
Smooth Expected Calibration Error (SmECE).
This metric implements the kernel-density-estimation-based ECE proposed by Błasiok & Nakkiran (2023). It addresses limitations of the standard binned ECE, such as bin-edge effects and poor resolution for overconfident models, by replacing the histogram with a kernel density estimate of the residuals.
Given top-class confidences \(\hat{p}_i \in [0, 1]\) and accuracies \(a_i \in \{0, 1\}\), define the kernel-smoothed conditional gap
\[r_h(t) = \frac{\sum_{i=1}^{N} K_h(t, \hat{p}_i) (\hat{p}_i - a_i)} {\sum_{i=1}^{N} K_h(t, \hat{p}_i)},\]
where \(K_h\) is a kernel of bandwidth \(h\). The Smooth ECE is then
\[\text{SmECE} = \int_0^1 |r_h(t)| \, \hat{f}(t) \, \mathrm{d}t,\]
with \(\hat{f}(t) = \tfrac{1}{N}\sum_i K_h(t, \hat{p}_i)\) the kernel density of confidences. The bandwidth \(h\) can be fixed or selected adaptively via a fixed-point binary search.
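The two formulas above can be evaluated numerically on a mesh. The following is a minimal pure-Python sketch, not the library's implementation: it assumes a fixed bandwidth, a Gaussian kernel reflected at 0 and 1, and a midpoint rule over `mesh_pts` cells. The helper names `gaussian`, `reflected_kernel` and `smooth_ece` are hypothetical.

```python
import math


def gaussian(u: float) -> float:
    """Standard normal density."""
    return math.exp(-0.5 * u * u) / math.sqrt(2.0 * math.pi)


def reflected_kernel(t: float, p: float, h: float) -> float:
    """Gaussian kernel in probability space, reflected once at 0 and once at 1."""
    return (
        gaussian((t - p) / h)
        + gaussian((t + p) / h)          # mirror image across 0
        + gaussian((t - (2.0 - p)) / h)  # mirror image across 1
    ) / h


def smooth_ece(confs, accs, h: float, mesh_pts: int = 200) -> float:
    """Midpoint-rule approximation of the SmECE integral for a fixed bandwidth h."""
    n = len(confs)
    dt = 1.0 / mesh_pts
    total = 0.0
    for j in range(mesh_pts):
        t = (j + 0.5) * dt
        weights = [reflected_kernel(t, p, h) for p in confs]
        # |r_h(t)| * f_hat(t) simplifies to |sum_i K_h(t, p_i) (p_i - a_i)| / N,
        # because the denominator of r_h(t) cancels against N * f_hat(t).
        num = sum(w * (p - a) for w, p, a in zip(weights, confs, accs))
        total += abs(num) / n * dt
    return total


# Two predictions at confidence 0.9, one correct and one wrong.
print(smooth_ece([0.9, 0.9], [1, 0], h=0.05))
```

In this toy case every confidence is 0.9 while the accuracy is 0.5, so the result should be close to the gap of 0.4. Using the simplified integrand also avoids dividing by a kernel sum that can underflow to zero far from the data.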
- Parameters:
kernel_type (Literal['logit', 'reflected']) – The kernel to use. Choose between:
  - 'logit' – Applies a Gaussian kernel in log-odds space. This effectively uses an adaptive bandwidth that is narrower near 1.0, making it ideal for modern overconfident models. (Default)
  - 'reflected' – Applies a Gaussian kernel in probability space with reflections at 0 and 1 to prevent boundary bias.
  Note that relplot’s original implementation uses 'reflected' by default.
bandwidth (Union[float, Literal['auto']]) – The kernel bandwidth \(h\). If set to 'auto', it uses a fixed-point binary search to find a bandwidth consistent with the error level. Defaults to 'auto'.
eps (float) – The tolerance for the binary search when bandwidth is 'auto'. Defaults to 0.001.
mesh_pts (int) – The base number of points for the grid discretization. The actual number may be higher depending on the bandwidth. Defaults to 200.
refine_steps (int) – Number of binary search iterations for the 'auto' bandwidth. Defaults to 10.
**kwargs – Additional arguments for the torchmetrics.Metric base.
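To illustrate how `eps` and `refine_steps` could interact in the 'auto' bandwidth search, here is a hedged sketch of a fixed-point bisection. It assumes the target is the bandwidth \(h\) at which the SmECE evaluated at \(h\) equals \(h\) itself, and that the SmECE does not increase as \(h\) grows. The function `fixed_point_bandwidth`, its `lo`/`hi` bounds and its stopping rule are hypothetical, and the library's actual search may differ.

```python
def fixed_point_bandwidth(smece_at, eps=1e-3, refine_steps=10, lo=0.0, hi=1.0):
    """Bisect for h such that smece_at(h) == h.

    Assumes smece_at is non-increasing in h, so smece_at(h) - h changes sign
    exactly once on [lo, hi].
    """
    for _ in range(refine_steps):
        if hi - lo <= eps:
            break  # bracket already tighter than the tolerance
        mid = 0.5 * (lo + hi)
        if smece_at(mid) > mid:
            lo = mid  # error still exceeds the bandwidth: fixed point lies above mid
        else:
            hi = mid  # bandwidth exceeds the error: fixed point lies at or below mid
    return 0.5 * (lo + hi)


# Toy stand-in for the SmECE-vs-bandwidth curve (hypothetical, decreasing in h).
toy_curve = lambda h: 0.2 / (1.0 + h)
print(fixed_point_bandwidth(toy_curve))
```

Starting from the bracket [0, 1], ten halvings shrink it to a width of 1/1024, which is just under the default tolerance of 0.001.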
Note
In the multiclass case, this metric evaluates the calibration of the maximum probability (top-label calibration). In the binary case, it evaluates the calibration of the predicted class (i.e., using \(\max(p, 1-p)\)).
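The reduction to top-label confidences and accuracies described in this note can be sketched in plain Python as follows. This is an illustration only: it assumes the inputs are already probabilities, the helper names are hypothetical, and the tie-breaking at \(p = 0.5\) in the binary case is an assumption.

```python
def top_label_pairs(probs, targets):
    """Multiclass: confidence is the largest probability, accuracy is 1 if that class is correct."""
    confs, accs = [], []
    for row, y in zip(probs, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        confs.append(row[pred])
        accs.append(1.0 if pred == y else 0.0)
    return confs, accs


def binary_pairs(p_pos, targets):
    """Binary: confidence is max(p, 1 - p) for the predicted class."""
    confs, accs = [], []
    for p, y in zip(p_pos, targets):
        pred = 1 if p >= 0.5 else 0  # assumption: ties are resolved towards class 1
        confs.append(max(p, 1.0 - p))
        accs.append(1.0 if pred == y else 0.0)
    return confs, accs


print(top_label_pairs([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [1, 2]))
print(binary_pairs([0.9, 0.2], [1, 1]))
```

The resulting pairs play the role of \(\hat{p}_i\) and \(a_i\) in the definition of the metric.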
Note
This implementation has been checked against relplot’s original implementation on one use case, where the two agree to 6 significant figures.
References
Błasiok, J. & Nakkiran, P. Smooth ECE: Principled Reliability Diagrams via Kernel Smoothing. ICLR 2024.
- compute()
Compute the Smooth ECE based on the accumulated state.
- Returns:
The scalar SmECE value.
- Return type:
Tensor
- update(preds, target)
Update the state with predictions and targets.
- Parameters:
preds (Tensor) – Predictions from the model.
  - Multiclass: Shape (N, C) (logits or probabilities).
  - Binary: Shape (N,) or (N, 1) (logits or probabilities).
target (Tensor) – Ground truth labels.
  - Multiclass: Shape (N,) containing class indices.
  - Binary: Shape (N,) containing 0 or 1.
- Return type:
None