Source code for torch_uncertainty.metrics.segmentation.seg_fpr95

import torch
from torch import Tensor
from torchmetrics import Metric

from torch_uncertainty.metrics import FPR95


class SegmentationFPR95(Metric):
    is_differentiable = False
    higher_is_better = False
    full_state_update = False

    # Running sum of batch FPR95 values, each weighted by its batch size.
    fpr95: Tensor
    # Number of images accumulated so far.
    total: Tensor

    def __init__(self, pos_label: int, **kwargs) -> None:
        r"""Image-averaged FPR@95 TPR for dense binary segmentation tasks.

        On every call to :meth:`update`, a False Positive Rate at 95% True
        Positive Rate is computed (see
        :class:`~torch_uncertainty.metrics.classification.FPR95`) from the
        pixel scores and binary OOD labels of the whole batch, in a single
        call. This batch value is weighted by the number of images in the
        batch, and the metric is the weighted mean over the :math:`B` images
        of the test set. With batches of one image, this is exactly the
        image-wise average

        .. math::

            \text{FPR95} = \frac{1}{B} \sum_{b=1}^{B} \text{FPR95}_b,

        which is the convention used in the dense OOD-detection literature.
        With larger batches, one FPR95 is computed jointly from all pixels of
        the batch, so use a batch size of ``1`` to obtain the strict per-image
        average.

        Args:
            pos_label: The positive label in the segmentation OOD detection
                task (typically ``1`` for OOD pixels).
            kwargs: Additional keyword arguments, forwarded both to the
                :class:`~torchmetrics.Metric` base class and to the underlying
                :class:`~torch_uncertainty.metrics.classification.FPR95`
                metric.
        """
        super().__init__(**kwargs)
        # Helper metric that evaluates one batch at a time; its accumulated
        # state is cleared after every batch in ``update``.
        self.fpr95_metric = FPR95(pos_label, **kwargs)
        self.add_state("fpr95", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # pyrefly: ignore[bad-override]
        """Accumulate the FPR95 of a batch, weighted by its number of images.

        Args:
            preds: Pixel OOD scores; the first dimension indexes images.
            target: Pixel OOD labels aligned with ``preds``.
        """
        batch_size = preds.size(0)
        if batch_size == 0:
            # An empty batch adds no image to the average.
            return
        # Calling the sub-metric (``forward``) returns the FPR95 of this batch
        # alone, but it also merges the batch into the sub-metric's accumulated
        # state, which is never read again. Reset it right away so that it does
        # not end up holding every pixel of the whole test set.
        batch_fpr95 = self.fpr95_metric(preds, target)
        self.fpr95_metric.reset()
        self.fpr95 += batch_fpr95 * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the image-weighted mean FPR95, or NaN if no image was seen."""
        if self.total == 0:
            return torch.tensor(torch.nan, device=self.fpr95.device)
        return self.fpr95 / self.total