# Source code for torch_uncertainty.metrics.segmentation.seg_binary_average_precision
from typing import Any
import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import BinaryAveragePrecision
# [docs]
class SegmentationBinaryAveragePrecision(Metric):
    """Batch-size-weighted mean of per-batch binary Average Precision (AP).

    On every :meth:`update` call, the AP of that single batch is computed with
    ``torchmetrics.classification.BinaryAveragePrecision``. It is accumulated
    weighted by the batch size (``preds.size(0)``). :meth:`compute` returns the
    weighted mean of those per-batch AP values.

    Note that this is *not* the AP of all processed samples pooled together.
    Each batch is scored independently and the scores are averaged.

    Args:
        thresholds: Forwarded to ``BinaryAveragePrecision``.
        ignore_index: Forwarded to ``BinaryAveragePrecision``.
        validate_args: Forwarded to ``BinaryAveragePrecision``.
        **kwargs: Forwarded to both the base ``Metric`` constructor and the
            inner ``BinaryAveragePrecision``.
    """

    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    def __init__(
        self,
        thresholds: int | list[float] | Tensor | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Inner metric used only to score one batch at a time. Its own
        # accumulated state is discarded after every batch (see ``update``).
        self.aupr_metric = BinaryAveragePrecision(
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        # Running sum over batches of (batch AP * batch size).
        self.add_state("binary_aupr", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running sum of batch sizes (the weights of the mean above).
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Score one batch and add it, weighted by batch size, to the running sums.

        Args:
            preds: Predictions for the batch; dimension 0 is taken as the batch
                size used for weighting.
            target: Binary ground truth matching ``preds``.
        """
        batch_size = preds.size(0)
        # Calling the metric (forward) returns the AP of this batch alone...
        aupr = self.aupr_metric(preds, target)
        # ...but forward also merges the batch into the inner metric's global
        # state. That state is never read here. With thresholds=None it keeps
        # every pred/target seen so far, so clear it to keep memory bounded.
        self.aupr_metric.reset()
        self.binary_aupr += aupr * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the batch-size-weighted mean AP, or 0.0 if no batch was seen."""
        if self.total == 0:
            return torch.tensor(0.0, device=self.binary_aupr.device)
        return self.binary_aupr / self.total