# Source code for torch_uncertainty.metrics.segmentation.seg_fpr95

import torch
from torch import Tensor
from torchmetrics import Metric

from torch_uncertainty.metrics import FPR95


class SegmentationFPR95(Metric):
    """Batch-size-weighted average of per-batch FPR95 for segmentation OOD detection."""

    is_differentiable = False
    higher_is_better = False
    full_state_update = False

    def __init__(self, pos_label: int, **kwargs) -> None:
        """FPR95 metric for segmentation tasks.

        FPR95 is computed independently on each batch passed to ``update``.
        ``compute`` returns the average of these per-batch values, weighted by
        the batch size (the size of the first dimension of ``preds``). This is
        in general NOT equal to the FPR95 obtained by pooling all samples of
        all batches together.

        Args:
            pos_label (int): The positive label in the segmentation OOD
                detection task.
            **kwargs: Additional keyword arguments, forwarded both to the
                ``torchmetrics.Metric`` base class and to the underlying
                ``FPR95`` metric.
        """
        super().__init__(**kwargs)
        self.fpr95_metric = FPR95(pos_label, **kwargs)
        # Running sum of per-batch FPR95 values, each multiplied by its batch size.
        self.add_state("fpr95", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running count of samples (first dimension of ``preds``) seen so far.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the FPR95 of one batch, weighted by its batch size.

        Args:
            preds (Tensor): Predictions for the batch; the first dimension is
                used as the batch size.
            target (Tensor): Targets for the batch (format as expected by
                ``FPR95`` -- not visible here, confirm against that metric).
        """
        batch_size = preds.size(0)
        # Calling the inner metric (torchmetrics ``forward``) returns the FPR95
        # of this batch only, but it also merges the batch into the inner
        # metric's own accumulated state. That state is never read here, and
        # resetting this outer metric does not reset it, so clear it now to
        # keep it from growing with every batch (presumably list-based states
        # in ``FPR95`` -- TODO confirm).
        fpr95 = self.fpr95_metric(preds, target)
        self.fpr95_metric.reset()
        self.fpr95 += fpr95 * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the batch-size-weighted mean FPR95 over all updates.

        Returns:
            Tensor: Scalar tensor. ``0.0`` (on the state's device) when no
            samples have been accumulated yet.
        """
        if self.total == 0:
            return torch.tensor(0.0, device=self.fpr95.device)
        return self.fpr95 / self.total