Source code for torch_uncertainty.models.wrappers.zero

import torch
from einops import rearrange
from torch import Tensor, nn
from torch.special import entr


class Zero(nn.Module):
    def __init__(
        self, core_model: nn.Module, num_tta: int, filter_views: float = 0.1, eps: float = 1e-8
    ) -> None:
        """Zero for test-time adaptation.

        Zero performs "0-temperature averaging" (i.e. majority voting) at evaluation.
        It starts by keeping the fraction :attr:`filter_views` of most confident
        (lowest-entropy) views, and returns the majority vote among them as a prediction.

        If used during training, the predictions will be those of the inner model
        passed as argument (:attr:`core_model`).

        Args:
            core_model (nn.Module): The inner model to train.
            num_tta (int): The number of views at evaluation time.
            filter_views (float): Fraction of the augmented views kept for the vote, i.e.
                ``1 - filter_views`` of the views are filtered out. Defaults to ``0.1``.
            eps (float): for computational stability. Defaults to ``1e-8``.

        Raises:
            ValueError: If the arguments are rejected by ``_zero_checks``.
        """
        super().__init__()
        _zero_checks(num_tta, filter_views, eps)
        self.core_model = core_model
        self.filter = filter_views
        # Number of views taking part in the initial vote (truncated towards zero).
        self.kept_views = int(filter_views * num_tta)
        self.num_tta = num_tta
        self.eps = eps

    def eval_forward(self, x: Tensor) -> Tensor:
        """Predict by majority vote over the most confident augmented views.

        Args:
            x (Tensor): Input passed to the inner model. The model output must have
                shape ``(batch_size * num_tta, num_classes)``, the ``num_tta`` views of
                a given sample being stored contiguously.

        Returns:
            Tensor: Tensor of shape ``(batch_size, num_classes)`` holding the logarithm
            of the vote counts divided by ``num_tta`` (plus ``eps``).
        """
        # predict and separate the views from the batch
        all_predictions = rearrange(self.core_model(x), "(b v) c -> b v c", v=self.num_tta)
        batch_size, _, num_classes = all_predictions.shape

        # The inner model returns logits (the softmax is applied downstream, see the
        # return statement), whereas `entr` expects probabilities: it evaluates to -inf
        # on negative inputs. Normalize before measuring the entropy of each view.
        probabilities = all_predictions.softmax(dim=-1)
        entropies = entr(probabilities).sum(dim=-1)

        # Get the index of the most confident predictions on the views
        conf_idx = torch.argsort(entropies, dim=-1)
        # The softmax is monotonic, so the votes can be read directly from the logits.
        votes = all_predictions.argmax(-1)

        # Count the votes
        predictions = torch.zeros((batch_size, num_classes), device=all_predictions.device)
        for img_id, img_votes in enumerate(votes):
            # Votes of this sample sorted from the most to the least confident view.
            ranked_votes = img_votes[conf_idx[img_id]]
            predictions[img_id] += torch.bincount(
                ranked_votes[: self.kept_views], minlength=num_classes
            )

            # If the maximum is shared among several classes, look at one additional
            # view at a time until the tie is broken or every view has been used.
            next_view = self.kept_views
            while (
                next_view < self.num_tta
                and (predictions[img_id] == predictions[img_id].max()).sum() > 1
            ):
                predictions[img_id, ranked_votes[next_view]] += 1
                next_view += 1

        predictions /= self.num_tta
        # We will apply the softmax in the routine, so let's apply its inverse here
        return (predictions + self.eps).log()

    def forward(self, x: Tensor) -> Tensor:
        """Return the inner model output in training mode, the Zero vote otherwise.

        Args:
            x (Tensor): Input batch.

        Returns:
            Tensor: ``core_model(x)`` when the module is in training mode, else the
            result of :meth:`eval_forward`.
        """
        if self.training:
            return self.core_model.forward(x)
        return self.eval_forward(x)
def _zero_checks(num_tta: int, filter_views: float, eps: float) -> None:
    """Validate the hyper-parameters of the Zero wrapper.

    The checks are evaluated in order and only the first failing one is reported.

    Args:
        num_tta (int): Number of views at evaluation time.
        filter_views (float): Fraction of views kept, expected in ]0, 1].
        eps (float): Stability constant, expected to be strictly positive.

    Raises:
        ValueError: If ``filter_views`` is outside ]0, 1], if ``num_tta`` is smaller
            than ``1 / filter_views``, or if ``eps`` is not strictly positive.
    """
    problem = None
    if filter_views > 1.0 or filter_views <= 0.0:
        problem = f"`filter_views` must be in the range ]0, 1]. Got {filter_views}."
    elif num_tta < 1 / filter_views:
        # Fewer views than 1/filter_views would leave less than one view to vote.
        problem = (
            f"`num_tta` should be greater than 1/filter_views to use Zero. Got {num_tta} < {1 / filter_views}."
        )
    elif eps <= 0:
        problem = f"`eps` should be strictly positive. Got {eps}."

    if problem is not None:
        raise ValueError(problem)