Conformal

class torch_uncertainty.post_processing.Conformal(alpha, model, ts_init_val, ts_lr, ts_max_iter, enable_ts, device)

Abstract base class for split-conformal classification predictors.

Builds prediction sets \(\mathcal{C}(\mathbf{x}) \subseteq \{1, \dots, C\}\) with the marginal coverage guarantee

\[\mathbb{P}\!\left[ Y \in \mathcal{C}(X) \right] \geq 1 - \alpha,\]

provided that the calibration and test points are exchangeable. At fit time, non-conformity scores are computed on a held-out calibration set, and their empirical \((1 - \alpha)\)-quantile \(\hat{q}\) is stored in q_hat. At test time, the prediction set contains every class whose non-conformity score does not exceed \(\hat{q}\). See ConformalClsTHR, ConformalClsAPS, and ConformalClsRAPS for concrete scores.
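The calibrate-then-threshold procedure can be sketched without any framework. This is a minimal illustration, not the library's code. It assumes the model outputs softmax probabilities and uses the THR score \(1 - \hat{p}_y\). It also takes \(\hat{q}\) as the \(\lceil (n+1)(1-\alpha) \rceil\)-th smallest of the \(n\) calibration scores, which is the finite-sample correction commonly used in split conformal prediction. The function names `thr_scores`, `conformal_quantile`, and `prediction_set` are hypothetical.

```python
import math
from typing import List, Sequence


def thr_scores(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """THR non-conformity score of the true class: 1 - its softmax probability."""
    return [1.0 - p[y] for p, y in zip(probs, labels)]


def conformal_quantile(scores: Sequence[float], alpha: float) -> float:
    """Finite-sample corrected empirical quantile used in split conformal.

    Returns the ceil((n + 1) * (1 - alpha))-th smallest calibration score,
    or +inf when that rank exceeds n (every class is then included).
    """
    n = len(scores)
    rank = math.ceil((n + 1) * (1 - alpha))
    if rank > n:
        return math.inf
    return sorted(scores)[rank - 1]


def prediction_set(probs: Sequence[float], q_hat: float) -> List[int]:
    """All classes whose THR score 1 - p_k does not exceed q_hat."""
    return [k for k, p in enumerate(probs) if 1.0 - p <= q_hat]


# Toy calibration set: 7 points, 3 classes (values chosen to be exact in binary).
cal_probs = [
    [0.75, 0.125, 0.125], [0.125, 0.875, 0.0], [0.25, 0.25, 0.5],
    [0.625, 0.25, 0.125], [0.5, 0.375, 0.125], [0.125, 0.125, 0.75],
    [0.875, 0.0625, 0.0625],
]
cal_labels = [0, 1, 2, 0, 1, 2, 0]
q_hat = conformal_quantile(thr_scores(cal_probs, cal_labels), alpha=0.25)
print(q_hat)                                   # 0.5
print(prediction_set([0.5, 0.5, 0.0], q_hat))  # [0, 1]
```

With \(n = 7\) and \(\alpha = 0.25\), the rank is \(\lceil 8 \times 0.75 \rceil = 6\), so \(\hat{q}\) is the sixth smallest score. An ambiguous test input receives two classes, while a confident one would receive a single class.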

Parameters:
  • alpha (float) – Target mis-coverage level \(\alpha \in (0, 1)\). A smaller \(\alpha\) yields larger prediction sets.

  • model (Module | None) – Underlying classifier.

  • ts_init_val (float) – Initial temperature used when enable_ts is True.

  • ts_lr (float) – Learning rate for the temperature optimizer.

  • ts_max_iter (int) – Maximum number of iterations for the temperature optimizer.

  • enable_ts (bool) – If True, wraps the model in a TemperatureScaler fit on the calibration set before computing the conformal scores.

  • device (Union[Literal['cpu', 'cuda'], device, None]) – Device to run the post-processing on.
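When enable_ts is True, the calibration set is first used to fit a temperature. The sketch below illustrates what ts_init_val, ts_lr, and ts_max_iter control. It assumes the usual temperature-scaling objective, which is the mean negative log-likelihood of \(\mathrm{softmax}(\mathbf{z} / T)\) on the calibration set. Plain gradient descent on the scalar \(T\) stands in for whatever optimizer TemperatureScaler actually uses. `nll_and_grad` and `fit_temperature` are hypothetical names.

```python
import math
from typing import Sequence, Tuple


def nll_and_grad(logits: Sequence[Sequence[float]], labels: Sequence[int],
                 temperature: float) -> Tuple[float, float]:
    """Mean NLL of softmax(logits / T) and its derivative with respect to T."""
    nll, grad = 0.0, 0.0
    for z, y in zip(logits, labels):
        scaled = [v / temperature for v in z]
        m = max(scaled)  # subtract the max for a numerically stable log-sum-exp
        exps = [math.exp(v - m) for v in scaled]
        total = sum(exps)
        probs = [e / total for e in exps]
        nll += m + math.log(total) - scaled[y]
        # d/dT [logsumexp(z / T) - z_y / T] = (z_y - sum_k p_k z_k) / T**2
        expected_z = sum(p * v for p, v in zip(probs, z))
        grad += (z[y] - expected_z) / temperature ** 2
    n = len(labels)
    return nll / n, grad / n


def fit_temperature(logits, labels, init_val: float = 1.0, lr: float = 0.1,
                    max_iter: int = 100) -> float:
    """Gradient descent on T (a stand-in; the library may use another optimizer)."""
    t = init_val
    for _ in range(max_iter):
        _, grad = nll_and_grad(logits, labels, t)
        t = max(t - lr * grad, 1e-3)  # keep the temperature strictly positive
    return t


# Overconfident toy model: margin of 4 logits on class 0, but only 6/8 labels agree.
logits = [[4.0, 0.0, 0.0]] * 8
labels = [0, 0, 0, 0, 0, 0, 1, 2]
t = fit_temperature(logits, labels, init_val=1.0, lr=1.0, max_iter=200)
print(t)  # converges near 4 / ln(6) ≈ 2.23, where p_0 matches the 75% accuracy
```

Because dividing all logits by the same \(T > 0\) preserves their order, temperature scaling changes the confidence values fed to the conformal score but not the predicted class.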

Warning

This implementation only supports the multiclass setting. Open an issue if you need binary support.

abstract conformal(inputs)

Apply the conformal prediction rule to the inputs.

Return type:

Tensor
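Concrete subclasses differ mainly in the non-conformity score that the conformal rule thresholds. As an illustration of a score other than THR, the following sketch shows the non-randomized adaptive prediction sets (APS) score. That score is the total probability of every class at least as likely as the candidate class. This sketch reflects the general technique and is not a transcription of ConformalClsAPS, which may handle ties or randomization differently. The function names are hypothetical.

```python
from typing import List, Sequence


def aps_score(probs: Sequence[float], label: int) -> float:
    """Cumulative probability of every class at least as likely as `label`."""
    p_label = probs[label]
    return sum(p for p in probs if p >= p_label)


def aps_prediction_set(probs: Sequence[float], q_hat: float) -> List[int]:
    """Classes whose APS score does not exceed the calibrated threshold q_hat."""
    return [k for k in range(len(probs)) if aps_score(probs, k) <= q_hat]


probs = [0.5, 0.25, 0.125, 0.125]
print([aps_score(probs, k) for k in range(4)])  # [0.5, 0.75, 1.0, 1.0]
print(aps_prediction_set(probs, q_hat=0.8))     # [0, 1]
```

Here \(\hat{q}\) would be calibrated exactly as for THR, by taking the same quantile of `aps_score(p_i, y_i)` over the calibration set. Because a class's score depends on how the remaining probability mass is spread, flatter (less certain) predictions tend to produce larger sets.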

abstract fit(dataloader)

Fit the post-processing module on a calibration dataloader.

Parameters:

dataloader (DataLoader) – A dataloader yielding (inputs, targets) pairs from a held-out calibration set, disjoint from both the training and the test sets.

Return type:

None

model_forward(inputs)

Apply the underlying model to the inputs and return its output scores.

Return type:

Tensor
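The three methods can be mirrored in a framework-free sketch to show how they fit together. fit consumes (inputs, targets) batches and stores q_hat, model_forward runs the model, and conformal thresholds the resulting scores. Everything here is hypothetical:

- the class names;
- the plain-list "dataloader";
- a model that returns probabilities directly;
- the 0/1 membership mask returned by conformal, which stands in for whatever Tensor layout the library actually returns.

```python
import math
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

Probs = List[List[float]]
Batch = Tuple[Sequence[str], Sequence[int]]


class ConformalSketch(ABC):
    """Hypothetical, framework-free mirror of the Conformal interface."""

    def __init__(self, alpha: float, model: Callable[[Sequence[str]], Probs]) -> None:
        self.alpha = alpha
        self.model = model
        self.q_hat: Optional[float] = None

    def model_forward(self, inputs: Sequence[str]) -> Probs:
        # Assumption: the model already returns class probabilities.
        return self.model(inputs)

    @abstractmethod
    def fit(self, dataloader: Iterable[Batch]) -> None: ...

    @abstractmethod
    def conformal(self, inputs: Sequence[str]) -> List[List[int]]: ...


class ThresholdConformalSketch(ConformalSketch):
    """THR-style subclass: non-conformity score = 1 - probability of the class."""

    def fit(self, dataloader: Iterable[Batch]) -> None:
        scores: List[float] = []
        for inputs, targets in dataloader:
            probs = self.model_forward(inputs)
            scores.extend(1.0 - p[y] for p, y in zip(probs, targets))
        n = len(scores)
        rank = math.ceil((n + 1) * (1 - self.alpha))  # finite-sample correction
        self.q_hat = math.inf if rank > n else sorted(scores)[rank - 1]

    def conformal(self, inputs: Sequence[str]) -> List[List[int]]:
        if self.q_hat is None:
            raise RuntimeError("call fit() on a calibration dataloader first")
        # One row per input, one 0/1 entry per class (1 = class is in the set).
        return [[int(1.0 - p <= self.q_hat) for p in row]
                for row in self.model_forward(inputs)]


# Toy "model": a lookup table from input keys to probabilities.
table = {
    "a": [0.75, 0.125, 0.125], "b": [0.125, 0.875, 0.0], "c": [0.25, 0.25, 0.5],
    "d": [0.625, 0.25, 0.125], "e": [0.5, 0.375, 0.125], "f": [0.125, 0.125, 0.75],
    "g": [0.875, 0.0625, 0.0625],
}
predictor = ThresholdConformalSketch(alpha=0.25, model=lambda xs: [table[x] for x in xs])
calibration = [(["a", "b", "c"], [0, 1, 2]), (["d", "e"], [0, 1]), (["f", "g"], [2, 0])]
predictor.fit(calibration)
print(predictor.q_hat)                  # 0.5
print(predictor.conformal(["e", "c"]))  # [[1, 0, 0], [0, 0, 1]]
```

Keeping fit and conformal abstract while sharing model_forward in the base class lets each subclass change only the score and threshold logic. That mirrors how the concrete THR, APS, and RAPS predictors are described above.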