Conformal
- class torch_uncertainty.post_processing.Conformal(alpha, model, ts_init_val, ts_lr, ts_max_iter, enable_ts, device)
Abstract base class for split-conformal classification predictors.
Builds prediction sets \(\mathcal{C}(\mathbf{x}) \subseteq \{1, \dots, C\}\) with the marginal coverage guarantee
\[\mathbb{P}\!\left[ Y \in \mathcal{C}(X) \right] \geq 1 - \alpha,\]
provided that the calibration and test points are exchangeable. At fit time, non-conformity scores are computed on a held-out calibration set, and their empirical \((1 - \alpha)\)-quantile \(\hat{q}\) is stored in q_hat. At test time, the prediction set contains every class whose non-conformity score does not exceed \(\hat{q}\). See ConformalClsTHR, ConformalClsAPS, and ConformalClsRAPS for concrete scores.
- Parameters:
  - alpha (float) – Target mis-coverage level \(\alpha \in (0, 1)\). A smaller \(\alpha\) yields larger prediction sets.
  - model (Module | None) – Underlying classifier.
  - ts_init_val (float) – Initial temperature used when enable_ts is True.
  - ts_lr (float) – Learning rate for the temperature optimizer.
  - ts_max_iter (int) – Maximum number of iterations for the temperature optimizer.
  - enable_ts (bool) – If True, wraps the model in a TemperatureScaler fit on the calibration set before the conformal scores are computed.
  - device (Union[Literal['cpu', 'cuda'], device, None]) – Device to run the post-processing on.
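As a minimal sketch of the split-conformal recipe described above, the plain-PyTorch code below uses a threshold score \(s(\mathbf{x}, y) = 1 - \hat{p}_y(\mathbf{x})\), assumed here to be what the THR variant refers to. The helpers calibrate_thr and predict_sets are hypothetical and not part of torch_uncertainty. The quantile level \(\lceil (n+1)(1-\alpha) \rceil / n\) is the usual finite-sample correction from the split-conformal literature; how Conformal itself computes q_hat may differ.

```python
import math

import torch


def calibrate_thr(cal_probs: torch.Tensor, cal_labels: torch.Tensor, alpha: float) -> torch.Tensor:
    """Compute q_hat from calibration softmax probabilities with the score 1 - p_y(x)."""
    n = cal_labels.numel()
    # Non-conformity score of the true label for each calibration point.
    scores = 1.0 - cal_probs.gather(1, cal_labels.unsqueeze(1)).squeeze(1)
    # Finite-sample corrected level ceil((n + 1)(1 - alpha)) / n, capped at 1.
    level = min(math.ceil((n + 1) * (1 - alpha)) / n, 1.0)
    return torch.quantile(scores, level, interpolation="higher")


def predict_sets(test_probs: torch.Tensor, q_hat: torch.Tensor) -> torch.Tensor:
    """Boolean (N, C) mask: class c is in the set of sample i iff 1 - p_c(x_i) <= q_hat."""
    return (1.0 - test_probs) <= q_hat


# Synthetic exchangeable data: labels are sampled from the model's own softmax.
torch.manual_seed(0)
logits = 3.0 * torch.randn(2000, 10)
labels = torch.distributions.Categorical(logits=logits).sample()
probs = logits.softmax(dim=-1)

# First half calibrates q_hat, second half plays the role of the test set.
q_hat = calibrate_thr(probs[:1000], labels[:1000], alpha=0.1)
sets = predict_sets(probs[1000:], q_hat)
coverage = sets[torch.arange(1000), labels[1000:]].float().mean().item()
avg_size = sets.sum(dim=1).float().mean().item()
print(f"q_hat={q_hat.item():.3f}  coverage={coverage:.3f}  mean set size={avg_size:.2f}")
```

Because the synthetic labels are drawn from the same distribution the probabilities describe, the calibration and test points are exchangeable, and the empirical coverage should sit near the 0.9 target.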
Warning
This implementation only works in the multiclass setting. Raise an issue if binary support is needed.
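When enable_ts is True, the logits are temperature-scaled on the calibration set before the scores are computed. The sketch below shows one common way to fit a single temperature by minimizing the negative log-likelihood with L-BFGS. The helper fit_temperature is hypothetical, and treating its init_val, lr, and max_iter arguments as counterparts of ts_init_val, ts_lr, and ts_max_iter is an assumption, not a description of TemperatureScaler's internals.

```python
import math

import torch
import torch.nn.functional as F


def fit_temperature(cal_logits: torch.Tensor, cal_labels: torch.Tensor,
                    init_val: float = 1.0, lr: float = 0.1, max_iter: int = 100) -> float:
    """Fit one temperature T > 0 minimizing the NLL of softmax(logits / T)."""
    # Optimize log(T) rather than T so the temperature always stays positive.
    log_t = torch.full((1,), math.log(init_val), requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], lr=lr, max_iter=max_iter)

    def closure() -> torch.Tensor:
        optimizer.zero_grad()
        loss = F.cross_entropy(cal_logits / log_t.exp(), cal_labels)
        loss.backward()
        return loss

    # A single LBFGS step runs up to max_iter inner iterations.
    optimizer.step(closure)
    return log_t.exp().item()


# Overconfident logits: labels are drawn from softmax(z) but the "model" outputs 3 * z.
torch.manual_seed(0)
z = 2.0 * torch.randn(5000, 10)
labels = torch.distributions.Categorical(logits=z).sample()
temperature = fit_temperature(3.0 * z, labels)
print(f"fitted temperature: {temperature:.2f}")
```

Since the logits are three times sharper than the distribution the labels come from, the fitted temperature should land close to 3. Parametrizing by \(\log T\) keeps the temperature positive without needing a constrained optimizer.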
- abstract conformal(inputs)
Apply the conformal prediction rule to the inputs.
- Return type:
Tensor
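A concrete subclass implements conformal by scoring every candidate class and thresholding at q_hat. As an illustration, the sketch below uses a non-randomized variant of the adaptive prediction sets (APS) score, in which the score of a class is the cumulative probability mass of all classes ranked at or above it. ConformalClsAPS may use a randomized or otherwise different variant, aps_scores and conformal_aps are hypothetical names, and q_hat = 0.85 is picked by hand rather than calibrated.

```python
import torch


def aps_scores(probs: torch.Tensor) -> torch.Tensor:
    """Non-randomized APS score for every (sample, class) pair, shape (N, C).

    The score of class c is the probability mass of all classes ranked at or
    above c when the classes are sorted by decreasing probability.
    """
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    cum_mass = sorted_probs.cumsum(dim=-1)
    # Write each cumulative mass back to the position of its original class.
    return torch.empty_like(probs).scatter_(-1, order, cum_mass)


def conformal_aps(probs: torch.Tensor, q_hat: float) -> torch.Tensor:
    """Boolean (N, C) mask: class c is in the set iff its APS score <= q_hat."""
    return aps_scores(probs) <= q_hat


probs = torch.tensor([[0.50, 0.30, 0.20],
                      [0.15, 0.05, 0.80]])
mask = conformal_aps(probs, q_hat=0.85)
print(mask.tolist())  # [[True, True, False], [False, False, True]]
```

In the first row, classes 0 and 1 together carry 0.8 of the mass and stay under the threshold, while class 2 brings the cumulative mass to 1.0 and is excluded. In the second row, the dominant class 2 alone already accounts for 0.8, so it is the only member of the set.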
- abstract fit(dataloader)
Fit the post-processing module on a calibration dataloader.
- Parameters:
dataloader (DataLoader) – A dataloader yielding (inputs, targets) pairs from a held-out calibration set, disjoint from both the training and the test sets.
- Return type:
None
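The fit contract above can be sketched end to end: split one labelled pool into disjoint training, calibration, and test subsets with random_split, then accumulate non-conformity scores batch by batch from the calibration DataLoader. ThrConformalSketch is a hypothetical class that mirrors the fit/conformal split of this page with a threshold-style score. It is not a subclass of Conformal and omits temperature scaling and device handling. The untrained nn.Linear model and random labels are placeholders for a real classifier and dataset.

```python
import math

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split


class ThrConformalSketch:
    """Hypothetical THR-style predictor mirroring fit/conformal; not the library class."""

    def __init__(self, model: nn.Module, alpha: float) -> None:
        self.model = model
        self.alpha = alpha
        self.q_hat = None  # set by fit()

    @torch.no_grad()
    def fit(self, dataloader: DataLoader) -> None:
        self.model.eval()
        batch_scores = []
        for inputs, targets in dataloader:
            probs = self.model(inputs).softmax(dim=-1)
            # Non-conformity score of the true class: 1 - p_y(x).
            batch_scores.append(1.0 - probs.gather(1, targets.unsqueeze(1)).squeeze(1))
        scores = torch.cat(batch_scores)
        n = scores.numel()
        # Finite-sample corrected quantile level, capped at 1.
        level = min(math.ceil((n + 1) * (1 - self.alpha)) / n, 1.0)
        self.q_hat = torch.quantile(scores, level, interpolation="higher").item()

    @torch.no_grad()
    def conformal(self, inputs: torch.Tensor) -> torch.Tensor:
        probs = self.model(inputs).softmax(dim=-1)
        return (1.0 - probs) <= self.q_hat


# Split one labelled pool into disjoint train / calibration / test subsets.
torch.manual_seed(0)
pool = TensorDataset(torch.randn(1000, 5), torch.randint(0, 3, (1000,)))
train_set, cal_set, test_set = random_split(pool, [600, 300, 100])
model = nn.Linear(5, 3)  # placeholder for a classifier trained on train_set
predictor = ThrConformalSketch(model, alpha=0.1)
predictor.fit(DataLoader(cal_set, batch_size=64))
test_inputs = torch.stack([test_set[i][0] for i in range(len(test_set))])
sets = predictor.conformal(test_inputs)
```

Collecting scores per batch and taking a single quantile at the end keeps memory proportional to the number of calibration points rather than to their input size, which matters when the inputs are images.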