GroupingLoss
- class torch_uncertainty.metrics.classification.GroupingLoss(**kwargs)
Metric to estimate the Top-label Grouping Loss.
- Parameters:
kwargs – Additional keyword arguments, see Advanced metric settings.
- Inputs:
probs: \((B, C)\) or \((B, N, C)\)
target: \((B)\) or \((B, C)\)
features: \((B, F)\) or \((B, N, F)\)
where \(B\) is the batch size, \(C\) is the number of classes, \(N\) is the number of estimators and \(F\) is the number of features.
Warning
Make sure that the probabilities in probs are normalized to sum to one.
- Raises:
ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.
- Reference:
Perez-Lebel, Alexandre, Le Morvan, Marine and Varoquaux, Gaël. Beyond calibration: estimating the grouping loss of modern neural networks. In ICLR 2023.
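The normalization requirement in the warning above can be checked before any values are passed to the metric. The following is a minimal sketch in plain Python, not part of torch_uncertainty: the helpers `softmax` and `is_normalized` and the example scores are hypothetical names and data introduced only for illustration.

```python
import math


def softmax(logits):
    """Turn one row of raw scores into probabilities that sum to one."""
    shift = max(logits)  # subtract the row maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def is_normalized(rows, tol=1e-6):
    """Return True if every row is non-negative and sums to one within tol."""
    return all(
        min(row) >= 0.0 and abs(sum(row) - 1.0) <= tol
        for row in rows
    )


# Hypothetical raw scores with B=2 samples and C=3 classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
probs = [softmax(row) for row in logits]

print(is_normalized(logits))  # raw scores are not valid probabilities
print(is_normalized(probs))   # softmax output satisfies the requirement
```

With tensors, the same check would typically be a row-wise sum compared against one; the list-based version here only keeps the sketch dependency-free.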
- compute()
Compute the final Grouping Loss based on inputs passed to update.
- Returns:
The final value(s) for the Grouping Loss
- Return type:
torch.Tensor
- update(probs, target, features)
Accumulate the tensors for the estimation of the Grouping Loss.
- Parameters:
probs (Tensor) – A probability tensor of shape (batch, num_classes), (batch, num_estimators, num_classes), or (batch) for binary classification
target (Tensor) – A tensor of ground truth labels of shape (batch, num_classes) or (batch)
features (Tensor) – A tensor of features of shape (batch, num_estimators, num_features) or (batch, num_features)
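The shape combinations listed for update can be summarized as a small pre-flight check. This is a sketch under stated assumptions, not the library's own validation: `check_update_shapes` is a hypothetical helper that works on plain shape tuples, and the requirement that probs and features agree on num_estimators is inferred from the shared \(N\) in the documented notation.

```python
def check_update_shapes(probs, target, features):
    """Raise ValueError if shape tuples do not match the documented layouts.

    Hypothetical helper for illustration; each argument is a shape tuple.
    """
    # probs: (batch), (batch, num_classes) or (batch, num_estimators, num_classes)
    if len(probs) not in (1, 2, 3):
        raise ValueError(f"unexpected probs shape: {probs}")
    # target: (batch) or (batch, num_classes)
    if len(target) not in (1, 2):
        raise ValueError(f"unexpected target shape: {target}")
    # features: (batch, num_features) or (batch, num_estimators, num_features)
    if len(features) not in (2, 3):
        raise ValueError(f"unexpected features shape: {features}")

    # All three inputs share the batch dimension.
    batch = probs[0]
    if target[0] != batch or features[0] != batch:
        raise ValueError("probs, target and features disagree on batch size")

    # Assumption: with ensembles, probs and features share num_estimators.
    if len(probs) == 3 and len(features) == 3 and probs[1] != features[1]:
        raise ValueError("probs and features disagree on num_estimators")
    return True


# Single model: B=8, C=3, F=16.
print(check_update_shapes((8, 3), (8,), (8, 16)))
# Ensemble of N=4 estimators.
print(check_update_shapes((8, 4, 3), (8,), (8, 4, 16)))
```

In practice the tuples would come from the `.shape` attribute of each tensor before calling update.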