MeanIntersectionOverUnion
- class torch_uncertainty.metrics.classification.MeanIntersectionOverUnion(num_classes, top_k=1, ignore_index=None, validate_args=True, **kwargs)
Computes Mean Intersection over Union (IoU) score.
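For reference, the usual definition is given below. It is a standard formulation, not something taken from this implementation. Per-class IoU divides the true positives by the union of predicted and target positions for class c, counted over every position whose target is not ignore_index. The score is then the mean over the C = num_classes classes. This section does not say how classes with an empty union are handled. Such classes are absent from both preds and target, so their ratio is 0/0, and the implementation may resolve that differently from a plain average.

```latex
\mathrm{IoU}_c = \frac{\mathrm{TP}_c}{\mathrm{TP}_c + \mathrm{FP}_c + \mathrm{FN}_c},
\qquad
\mathrm{mIoU} = \frac{1}{C} \sum_{c=0}^{C-1} \mathrm{IoU}_c
```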
- Parameters:
num_classes (int) – Integer specifying the number of classes.
top_k (int, optional) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits. Defaults to 1.
ignore_index (int | None, optional) – Specifies a target value that is ignored and does not contribute to the metric calculation. Defaults to None.
validate_args (bool, optional) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations. Defaults to True.
**kwargs – Additional keyword arguments, see Advanced metric settings for more info.
- Shape:
As input to forward and update, the metric accepts the following input:
preds (Tensor): An int tensor of shape (B, ...) or a float tensor of shape (B, C, ...). If preds is a floating-point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (B, ...).
As output to forward and compute, the metric returns the following output:
mean_iou (Tensor): The computed Mean Intersection over Union (IoU) score. A tensor containing a single float value.
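The float-to-int conversion described under Shape can be illustrated with a minimal pure-Python sketch. It takes scores of shape (B, C, N) and picks the highest-scoring class at each position, which is the same idea as torch.argmax(preds, dim=1). Here the trailing ... dimensions are assumed to be flattened into N. The helper name argmax_over_classes is hypothetical and not part of torch_uncertainty. Nested lists stand in for tensors so the example runs without PyTorch.

```python
from typing import List


def argmax_over_classes(preds: List[List[List[float]]]) -> List[List[int]]:
    """Turn float scores of shape (B, C, N) into int labels of shape (B, N).

    Hypothetical helper: same idea as torch.argmax(preds, dim=1), with the
    trailing "..." dimensions assumed to be flattened into N.
    Ties resolve to the lowest class index in this sketch.
    """
    labels: List[List[int]] = []
    for sample in preds:  # one sample, shape (C, N)
        num_classes, num_positions = len(sample), len(sample[0])
        row = []
        for n in range(num_positions):
            # Pick the class c with the largest score at position n.
            best = max(range(num_classes), key=lambda c: sample[c][n])
            row.append(best)
        labels.append(row)
    return labels


# One sample, three classes (rows), two positions (columns).
preds = [[[0.1, 0.7],
          [0.8, 0.2],
          [0.1, 0.1]]]
print(argmax_over_classes(preds))  # [[1, 0]]
```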
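A small pure-Python stand-in can illustrate the split between update and compute. It assumes that update() sums per-class TP, FP and FN counts over all batches and that compute() derives the single score from those totals, as torchmetrics-style stat-score metrics typically do. The class StreamingMeanIoU is hypothetical and is not the library's implementation. It takes int labels that have already been converted and flattened to 1-D, and it drops positions whose target equals ignore_index. As a further assumption, it skips classes with an empty union when averaging.

```python
from typing import List, Optional, Sequence


class StreamingMeanIoU:
    """Hypothetical stand-in mimicking an update/compute metric.

    update() accumulates per-class TP/FP/FN counts across batches;
    compute() returns the mean of TP_c / (TP_c + FP_c + FN_c).
    Classes with a zero union are skipped (an assumption of this sketch).
    """

    def __init__(self, num_classes: int, ignore_index: Optional[int] = None) -> None:
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.tp = [0] * num_classes
        self.fp = [0] * num_classes
        self.fn = [0] * num_classes

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        # preds and target: flat int labels for one batch.
        for p, t in zip(preds, target):
            if self.ignore_index is not None and t == self.ignore_index:
                continue  # ignored positions contribute to no count
            if p == t:
                self.tp[t] += 1
            else:
                self.fp[p] += 1  # predicted p where target was not p
                self.fn[t] += 1  # missed the true class t

    def compute(self) -> float:
        ious: List[float] = []
        for c in range(self.num_classes):
            union = self.tp[c] + self.fp[c] + self.fn[c]
            if union > 0:
                ious.append(self.tp[c] / union)
        return sum(ious) / len(ious) if ious else 0.0


metric = StreamingMeanIoU(num_classes=3, ignore_index=255)
metric.update(preds=[0, 1, 1, 2], target=[0, 1, 2, 255])
metric.update(preds=[2, 2], target=[2, 0])
print(round(metric.compute(), 4))  # 0.4444
```

Because the counts are summed before dividing, the result generally differs from averaging per-batch mean IoU values. For the two batches above, this sketch gives 4/9. Averaging its per-batch scores would instead give 0.375.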