Source code for torch_uncertainty.metrics.segmentation.mean_iou

from torch import Tensor
from torchmetrics.classification.stat_scores import MulticlassStatScores
from torchmetrics.utilities.compute import _safe_divide


class MeanIntersectionOverUnion(MulticlassStatScores):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs,
    ) -> None:
        r"""Computes the Mean Intersection over Union (mIoU) score.

        For a multi-class segmentation task with :math:`C` classes, the
        per-class Intersection over Union is

        .. math::

            \text{IoU}_c = \frac{\text{TP}_c}{\text{TP}_c + \text{FP}_c + \text{FN}_c},

        where :math:`\text{TP}_c`, :math:`\text{FP}_c` and :math:`\text{FN}_c`
        are the numbers of true-positive, false-positive and false-negative
        pixels for class :math:`c`, aggregated over all images. The mean IoU
        is the unweighted average

        .. math::

            \text{mIoU} = \frac{1}{C} \sum_{c=1}^{C} \text{IoU}_c.

        A class with :math:`\text{TP}_c + \text{FP}_c + \text{FN}_c = 0`
        appears in neither the (non-ignored) predictions nor the targets. Its
        IoU is undefined, so it is set to ``NaN`` and excluded from the
        average (``nanmean``). In that case :math:`C` above counts only the
        remaining classes. A class that is predicted but never present in the
        targets has an IoU of ``0`` and *does* lower the mean. If every class
        is undefined, the result is ``NaN``.

        Args:
            num_classes: Integer specifying the number of classes.
            top_k: Number of highest probability or logit score predictions
                considered to find the correct label. Only works when
                ``preds`` contain probabilities/logits. Defaults to ``1``.
            ignore_index: Specifies a target value that is ignored and does
                not contribute to the metric calculation. Defaults to
                ``None``.
            validate_args: Bool indicating if input arguments and tensors
                should be validated for correctness. Set to ``False`` for
                faster computations. Defaults to ``True``.
            **kwargs: Additional keyword arguments, see `Advanced metric
                settings
                <https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs>`_
                for more info.

        Shape:
            As input to ``forward`` and ``update`` the metric accepts the
            following input:

            - **preds** (:class:`~torch.Tensor`): An int tensor of shape
              ``(B, *)`` or float tensor of shape ``(B, C, *)``. If preds is
              a floating point tensor, ``torch.argmax`` is applied along the
              ``C`` dimension to automatically convert probabilities/logits
              into an int tensor.
            - **target** (:class:`~torch.Tensor`): An int tensor of shape
              ``(B, *)``.

            As output to ``forward`` and ``compute`` the metric returns the
            following output:

            - **mean_iou** (:class:`~torch.Tensor`): The computed Mean
              Intersection over Union (mIoU) score. A tensor containing a
              single float value.
        """
        # Pass everything by keyword so the call does not depend on the
        # positional order of the parent's signature.
        super().__init__(
            num_classes=num_classes,
            top_k=top_k,
            # Keep per-class statistics summed over the whole dataset;
            # ``compute`` below replaces the parent's averaging with its own
            # NaN-aware mean over classes.
            average="macro",
            multidim_average="global",
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )

    def compute(self) -> Tensor:
        """Compute the Mean Intersection over Union (mIoU) based on the accumulated state.

        Returns:
            Tensor: A scalar tensor holding the mean of the per-class IoUs.
            Classes whose IoU is undefined (``TP + FP + FN == 0``) are
            skipped. The result is ``NaN`` if every class is undefined.
        """
        # True negatives play no role in IoU.
        tp, fp, _, fn = self._final_state()
        # Undefined per-class IoUs become NaN so ``nanmean`` can skip them.
        per_class_iou = _safe_divide(tp, tp + fp + fn, zero_division=float("nan"))
        return per_class_iou.nanmean()