ConformalClsRAPS

class torch_uncertainty.post_processing.ConformalClsRAPS(alpha, model=None, randomized=True, penalty=0.1, regularization_rank=1, ts_init_val=1.0, ts_lr=0.1, ts_max_iter=100, enable_ts=False, device=None)

Conformal classification with Regularised Adaptive Prediction Sets (RAPS; Angelopoulos, Bates, Jordan & Malik, 2021).

A regularised variant of ConformalClsAPS that penalises the inclusion of low-ranked (unlikely) classes, producing smaller prediction sets without sacrificing coverage. The non-conformity score adds a rank-based regulariser to the APS score:

\[s(\mathbf{x}, y) = \underbrace{\sum_{i=1}^{k} \hat{p}_{(i)} - U \cdot \hat{p}_{(k)}}_{\text{APS}} + \lambda \cdot (k - k_\text{reg})_{+},\]

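where \(\hat{p}_{(i)}\) is the \(i\)-th largest predicted probability, \(k\) is the rank of class \(y\) in that ordering, \(U \sim \mathrm{Uniform}(0, 1)\) is the randomisation variable of APS (see randomized), \(\lambda\) is penalty, \(k_\text{reg}\) is regularization_rank, and \((\cdot)_+ = \max(\cdot, 0)\). Larger \(\lambda\) and smaller \(k_\text{reg}\) produce tighter sets at the cost of a coarser score.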
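As an illustration of the formula above, the following framework-free Python sketch computes the score for a single sample. The function name raps_score and its argument names are hypothetical and not part of torch_uncertainty; u stands for the randomisation variable \(U\) and is passed explicitly so that the result is deterministic.

```python
def raps_score(probs, label, u=0.0, penalty=0.1, regularization_rank=1):
    """Hypothetical helper: RAPS non-conformity score for one sample.

    probs: predicted class probabilities (list of floats summing to 1).
    label: index of the class being scored.
    u:     value of the randomisation variable U, assumed to lie in [0, 1].
    """
    # Sort class indices by decreasing predicted probability.
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    sorted_probs = [probs[c] for c in order]

    # 1-based rank k of the scored class in that ordering.
    k = order.index(label) + 1

    # APS part: cumulative mass up to rank k, minus U times the k-th mass.
    aps = sum(sorted_probs[:k]) - u * sorted_probs[k - 1]

    # Rank-based regulariser: lambda * (k - k_reg)_+.
    regulariser = penalty * max(k - regularization_rank, 0)
    return aps + regulariser


probs = [0.1, 0.6, 0.3]
top_class_score = raps_score(probs, label=1)     # rank 1, no penalty applied
second_class_score = raps_score(probs, label=2)  # rank 2, one penalty step
```

With the default regularization_rank of 1, every class below the top prediction pays at least one multiple of penalty, which is what pushes unlikely classes out of the prediction set.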

Parameters:
  • alpha (float) – Target mis-coverage level \(\alpha \in (0, 1)\).

  • model (Module | None) – Trained classification model. Defaults to None.

  • randomized (bool) – Whether to use randomised tie-breaking. Defaults to True.

  • penalty (float) – Regularisation weight \(\lambda\). Defaults to 0.1.

  • regularization_rank (int) – Rank threshold \(k_\text{reg}\) above which the penalty is applied. Defaults to 1.

  • ts_init_val (float) – Initial value for the temperature. Defaults to 1.0.

  • ts_lr (float) – Learning rate for the temperature scaling optimizer. Defaults to 0.1.

  • ts_max_iter (int) – Maximum number of iterations for the temperature scaling optimizer. Defaults to 100.

  • enable_ts (bool) – Whether to apply temperature scaling before computing the conformal scores. Defaults to False.

  • device (Union[Literal['cpu', 'cuda'], device, None]) – Device to use. Defaults to None.
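To make the role of the temperature parameters concrete, here is a minimal sketch of how a temperature rescales logits before the softmax. It only shows how a given temperature is applied; how the class fits the temperature (using ts_init_val, ts_lr and ts_max_iter) is not shown, and the function name softmax_with_temperature is hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Hypothetical helper: softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum for numerical stability before exponentiating.
    shift = max(scaled)
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, temperature=1.0)
smooth = softmax_with_temperature(logits, temperature=4.0)  # flatter distribution
```

A temperature above 1 flattens the probabilities and a temperature below 1 sharpens them. The class ordering is unchanged, but the cumulative sums that enter the conformal score do change.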

Warning

This implementation supports only the multiclass setting. Open an issue if you need binary support.

Reference:

Code inspired by TorchCP.

conformal(inputs)

Compute the prediction set for each input.

Return type:

Tensor
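The following sketch illustrates one way a prediction set can be built from a calibrated threshold: a class is kept when its RAPS score does not exceed q_hat. It is written in plain Python for a single sample. The function name raps_prediction_set is hypothetical, and the boolean-mask output is an assumption made for illustration; the encoding of the Tensor returned by conformal is not specified here.

```python
def raps_prediction_set(probs, q_hat, u=0.0, penalty=0.1, regularization_rank=1):
    """Hypothetical helper: boolean mask of the classes kept in the set."""
    # Visit classes from most to least likely, tracking the cumulative mass.
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    in_set = [False] * len(probs)
    cumulative = 0.0
    for k, c in enumerate(order, start=1):
        cumulative += probs[c]
        # Score the class as if it were the true label (rank k).
        score = cumulative - u * probs[c] + penalty * max(k - regularization_rank, 0)
        in_set[c] = score <= q_hat
    return in_set


probs = [0.1, 0.6, 0.3]
regularised = raps_prediction_set(probs, q_hat=0.95)                 # [False, True, False]
unregularised = raps_prediction_set(probs, q_hat=0.95, penalty=0.0)  # [False, True, True]
```

For the same threshold, the penalised score removes the second-ranked class that the unpenalised (APS-like) score would keep.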

fit(dataloader)

Calibrate the threshold q_hat of the RAPS score on a calibration set.

Return type:

None
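As a rough illustration of calibration, the sketch below applies the standard split-conformal rule: q_hat is the \(\lceil (n + 1)(1 - \alpha) \rceil\)-th smallest of the \(n\) calibration scores. Whether fit uses exactly this rule (for example, how it interpolates or treats small calibration sets) is an assumption, and the function name conformal_quantile is hypothetical.

```python
import math


def conformal_quantile(scores, alpha):
    """Hypothetical helper: finite-sample conformal threshold q_hat."""
    n = len(scores)
    # 1-based index of the order statistic used as the threshold.
    rank = math.ceil((n + 1) * (1 - alpha))
    if rank > n:
        # Too few calibration points for this alpha: no finite threshold
        # guarantees coverage, so every class would be included.
        return math.inf
    return sorted(scores)[rank - 1]


calibration_scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
q_hat = conformal_quantile(calibration_scores, alpha=0.2)  # 9th smallest score
```

The (n + 1) correction makes the threshold slightly more conservative than the empirical (1 - alpha) quantile, which is what yields the finite-sample coverage guarantee.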

model_forward(inputs)

Apply the model and return the scores.

Return type:

Tensor