ConformalClsRAPS
- class torch_uncertainty.post_processing.ConformalClsRAPS(alpha, model=None, randomized=True, penalty=0.1, regularization_rank=1, ts_init_val=1.0, ts_lr=0.1, ts_max_iter=100, enable_ts=False, device=None)
Conformal classification with Regularised Adaptive Prediction Sets (RAPS; Angelopoulos, Bates, Jordan & Malik, 2021).
A regularised variant of ConformalClsAPS that penalises the inclusion of low-ranked (unlikely) classes to produce smaller prediction sets without sacrificing coverage. The non-conformity score adds a rank-based regulariser to the APS score:
\[s(\mathbf{x}, y) = \underbrace{\sum_{i=1}^{k} \hat{p}_{(i)} - U \cdot \hat{p}_{(k)}}_{\text{APS}} + \lambda \cdot (k - k_\text{reg})_{+},\]
where \(\hat{p}_{(i)}\) is the \(i\)-th largest predicted probability, \(k\) is the rank of class \(y\) when classes are sorted by decreasing predicted probability, \(U\) is the random term used for randomised tie-breaking (see randomized), \(\lambda\) is penalty, \(k_\text{reg}\) is regularization_rank, and \((\cdot)_+ = \max(\cdot, 0)\). Larger \(\lambda\) and smaller \(k_\text{reg}\) produce tighter sets at the cost of a coarser score.
- Parameters:
alpha (float) – Target mis-coverage level \(\alpha \in (0, 1)\).
model (Module | None) – Trained classification model. Defaults to None.
randomized (bool) – Whether to use randomised tie-breaking. Defaults to True.
penalty (float) – Regularisation weight \(\lambda\). Defaults to 0.1.
regularization_rank (int) – Rank threshold \(k_\text{reg}\) above which the penalty is applied. Defaults to 1.
ts_init_val (float) – Initial value for the temperature. Defaults to 1.0.
ts_lr (float) – Learning rate for the temperature scaling optimizer. Defaults to 0.1.
ts_max_iter (int) – Maximum number of iterations for the temperature scaling optimizer. Defaults to 100.
enable_ts (bool) – Whether to apply temperature scaling before computing the conformal scores. Defaults to False.
device (Union[Literal['cpu','cuda'],device,None]) – Device to use. Defaults to None.
Warning
This implementation only works in the multiclass setting. Raise an issue if binary support is needed.
Code inspired by TorchCP.
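As an illustration of the score formula above, here is a minimal sketch for a single sample in plain Python. It is not the library's implementation: the helper name raps_score and the rng argument are hypothetical, and it assumes that \(U\) is drawn uniformly from [0, 1) when randomized is true and set to 0 otherwise.

```python
import random


def raps_score(probs, label, penalty=0.1, regularization_rank=1,
               randomized=True, rng=random):
    """Illustrative RAPS non-conformity score for one sample (hypothetical helper)."""
    # Sort class indices by decreasing predicted probability.
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    k = order.index(label) + 1  # 1-based rank of the label
    # Sum of the k largest probabilities, i.e. up to and including the label.
    cumulative = sum(probs[c] for c in order[:k])
    # Assumption: U ~ Uniform[0, 1) when randomized, U = 0 otherwise.
    u = rng.random() if randomized else 0.0
    aps = cumulative - u * probs[label]
    # Rank-based regulariser: lambda * (k - k_reg)_+
    return aps + penalty * max(k - regularization_rank, 0)


probs = [0.6, 0.3, 0.1]
scores = [raps_score(probs, y, randomized=False) for y in range(len(probs))]
```

With the default penalty and regularization_rank, the top-ranked class receives no penalty, while the second and third classes are charged 0.1 and 0.2 on top of their APS scores, which is what pushes unlikely classes out of the prediction set.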
- conformal(inputs)
Compute the prediction set for each input.
- Return type:
Tensor
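To make the set construction concrete, the sketch below builds a prediction set for one sample from a calibrated threshold q_hat. It is a simplified illustration rather than the library's code: raps_prediction_set is a hypothetical helper, it uses the non-randomised score (\(U = 0\)), and it returns a list of class indices, whereas the format of the tensor returned by conformal is not specified here.

```python
def raps_prediction_set(probs, q_hat, penalty=0.1, regularization_rank=1):
    """Return the class indices whose (non-randomised) RAPS score is <= q_hat."""
    # Visit classes from most to least likely.
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    prediction_set = []
    cumulative = 0.0
    for rank, cls in enumerate(order, start=1):
        cumulative += probs[cls]
        # Assumption: U = 0, so the APS part is the plain cumulative mass.
        score = cumulative + penalty * max(rank - regularization_rank, 0)
        if score <= q_hat:
            prediction_set.append(cls)
    return prediction_set


probs = [0.1, 0.6, 0.3]
regularised = raps_prediction_set(probs, q_hat=1.05)
unregularised = raps_prediction_set(probs, q_hat=1.05, penalty=0.0)
```

In this toy case the regularised set keeps only the two most likely classes, while setting penalty to 0 (plain APS) lets all three classes through for the same threshold.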
- fit(dataloader)
Calibrate the RAPS score threshold q_hat on a calibration set.
- Return type:
None
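The calibration step can be sketched as the standard split-conformal quantile of the calibration scores. This is a generic illustration under assumptions, not the library's code: conformal_quantile is a hypothetical helper, and the exact quantile rule or interpolation used by fit may differ.

```python
import math


def conformal_quantile(scores, alpha):
    """Finite-sample corrected (1 - alpha) quantile of calibration scores."""
    n = len(scores)
    # Index of the order statistic: ceil((n + 1) * (1 - alpha)).
    rank = math.ceil((n + 1) * (1 - alpha))
    if rank > n:
        # Too few calibration points for this alpha: every class is included.
        return math.inf
    return sorted(scores)[rank - 1]


# Scores of the true labels on a toy calibration set.
calibration_scores = [0.9, 0.2, 1.3, 0.5, 1.0, 0.4, 0.7]
q_hat = conformal_quantile(calibration_scores, alpha=0.25)
```

Here n = 7 and alpha = 0.25, so the threshold is the 6th smallest score. A smaller alpha moves the threshold up and therefore enlarges the prediction sets.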
- model_forward(inputs)
Apply the model and return the scores.
- Return type:
Tensor