Shortcuts

Source code for torch_uncertainty.layers.masksembles

"""Modified from https://github.com/nikitadurasov/masksembles/."""

from typing import Any

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn.common_types import _size_2_t


def _generate_masks(m: int, n: int, s: float) -> np.ndarray:
    """Draw ``n`` random binary masks, each containing exactly ``m`` ones.

    Every mask picks ``m`` distinct positions out of a pool of ``int(m * s)``
    candidates. Afterwards, pool positions that no mask picked are discarded,
    so the number of columns of the result is random. Use ``generate_masks``
    or ``generation_wrapper`` when a fixed width is required.

    Args:
        m (int): Number of ones in each mask.
        n (int): Number of masks in the set.
        s (float): Scale parameter; sets the pool size ``int(m * s)`` and
            therefore how much the masks overlap.

    Returns:
        np.ndarray: Float matrix of 0/1 values with one row per mask and one
        column per position used by at least one mask.
    """
    generator = np.random.default_rng()
    pool_size = int(m * s)

    def _draw_one() -> np.ndarray:
        # switch on m distinct positions among the pool_size candidates
        vector = np.zeros([pool_size])
        chosen = generator.choice(range(pool_size), m, replace=False)
        vector[chosen] = 1
        return vector

    stacked = np.array([_draw_one() for _ in range(n)])
    # keep only the columns that at least one mask selected
    used_columns = stacked.any(axis=0)
    return stacked[:, used_columns]


def generate_masks(m: int, n: int, s: float) -> np.ndarray:
    """Draw binary masks whose width equals the expected width for (m, n, s).

    ``_generate_masks`` returns a random number of columns. This function
    redraws until the number of columns equals the expected number of
    distinct positions covered by the ``n`` masks.

    Args:
        m (int): Number of ones in each mask.
        n (int): Number of masks in the set.
        s (float): Scale parameter controlling the overlap of the masks.

    Returns:
        np.ndarray: Matrix of binary vectors of the expected width.
    """
    candidate = _generate_masks(m, n, s)
    # Each of the int(m * s) pool positions is hit by one mask with
    # probability 1 / s, so the expected number of covered positions is
    # m * s * (1 - (1 - 1 / s) ** n); see the Masksembles reference.
    target_width = int(m * s * (1 - (1 - 1 / s) ** n))
    while True:
        if candidate.shape[1] == target_width:
            return candidate
        candidate = _generate_masks(m, n, s)


def generation_wrapper(c: int, n: int, scale: float) -> np.ndarray:
    """Generate a set of binary masks with exactly ``c`` columns.

    The width of the masks returned by ``generate_masks`` depends on the
    scale it is given. This wrapper searches for a scale that yields exactly
    ``c`` columns. That makes the result usable in torch-like layers, where
    the number of features must be known beforehand.

    Args:
        c (int): Number of channels (columns) of the generated masks.
        n (int): Number of masks in the set.
        scale (float): Scale parameter controlling the overlap of the masks.

    Raises:
        ValueError: If :attr:`c` < 10.
        ValueError: If :attr:`scale` > 6.0.

    Returns:
        np.ndarray: Matrix of binary vectors with ``c`` columns.
    """
    if c < 10:
        raise ValueError(
            "Masksembles approach couldn't be used in such setups where "
            "number of channels is less than 10. Current value is "
            f"(channels={c}). "
            "Please increase number of features in your layer or remove this "
            "particular instance of Masksembles from your architecture."
        )

    if scale > 6.0:
        raise ValueError(
            "Masksembles approach couldn't be used in such setups where "
            "scale parameter is larger than 6. Current value is "
            f"(scale={scale})."
        )

    # Inverse of the expected-width formula used by ``generate_masks``:
    # number of ones per mask such that the expected width is about ``c``.
    active_features = int(int(c) / (scale * (1 - (1 - 1 / scale) ** n)))

    # Bisect on the generation scale ``s``. A width below ``c`` raises the
    # lower bound and a width above ``c`` lowers the upper bound. Generation
    # is random, so the loop keeps sampling until a set with exactly ``c``
    # columns is drawn.
    up = 4 * scale
    down = max(0.2 * scale, 1.0)
    s = (down + up) / 2
    while True:
        masks = generate_masks(active_features, n, s)
        width = masks.shape[-1]
        if width == c:
            return masks
        if width < c:
            down = s
        else:
            up = s
        s = (down + up) / 2


class Mask1d(nn.Module):
    """Masksembles mask layer for 2D ``(batch, channels)`` inputs.

    A fixed set of ``num_masks`` binary masks over ``channels`` features is
    drawn once at construction. In ``forward``, the batch is cut into
    ``num_masks`` consecutive, equally sized chunks, and chunk ``i`` is
    multiplied feature-wise by mask ``i``.

    Args:
        channels (int): Number of input features; width of each mask.
        num_masks (int): Number of masks (one per estimator).
        scale (float): Overlap parameter forwarded to ``generation_wrapper``.
        **factory_kwargs: Only the optional ``device`` key is used. It is the
            device on which the masks are allocated; other keys are ignored.
    """

    def __init__(self, channels: int, num_masks: int, scale: float, **factory_kwargs) -> None:
        super().__init__()
        self.num_masks = num_masks

        masks = torch.from_numpy(generation_wrapper(channels, num_masks, scale))
        # Move the tensor BEFORE wrapping it: calling ``.to`` on a Parameter
        # returns a plain Tensor when the device changes. That would silently
        # drop the masks from ``parameters()``/``state_dict()`` and from
        # ``Module.to``.
        self.masks = nn.Parameter(
            masks.to(device=factory_kwargs.get("device")), requires_grad=False
        )

    def forward(self, inputs: Tensor) -> Tensor:
        """Mask each consecutive chunk of the batch with its own mask.

        Args:
            inputs (Tensor): Input of shape ``(batch, channels)``; ``batch``
                must be a multiple of ``num_masks``, otherwise the reshape
                fails.

        Returns:
            Tensor: Masked tensor with the same shape and dtype as ``inputs``.
        """
        batch = inputs.shape[0]
        # (B, C) -> (num_masks, B // num_masks, C): slice i is chunk i.
        grouped = inputs.reshape(self.num_masks, batch // self.num_masks, inputs.shape[1])
        masked = grouped * self.masks.unsqueeze(1)
        # The masks may have a different dtype (float64 from numpy); cast back.
        return masked.reshape(inputs.shape).to(inputs.dtype)


class Mask2d(nn.Module):
    """Masksembles mask layer for 4D ``(batch, channels, height, width)`` inputs.

    A fixed set of ``num_masks`` binary masks over ``channels`` is drawn once
    at construction. In ``forward``, the batch is cut into ``num_masks``
    consecutive, equally sized chunks, and every spatial location of chunk
    ``i`` is multiplied channel-wise by mask ``i``.

    Args:
        channels (int): Number of input channels; width of each mask.
        num_masks (int): Number of masks (one per estimator).
        scale (float): Overlap parameter forwarded to ``generation_wrapper``.
        **factory_kwargs: Only the optional ``device`` key is used. It is the
            device on which the masks are allocated; other keys are ignored.
    """

    def __init__(self, channels: int, num_masks: int, scale: float, **factory_kwargs) -> None:
        super().__init__()
        self.num_masks = num_masks

        masks = torch.from_numpy(generation_wrapper(channels, num_masks, scale))
        # Move the tensor BEFORE wrapping it: calling ``.to`` on a Parameter
        # returns a plain Tensor when the device changes. That would silently
        # drop the masks from ``parameters()``/``state_dict()`` and from
        # ``Module.to``.
        self.masks = nn.Parameter(
            masks.to(device=factory_kwargs.get("device")), requires_grad=False
        )

    def forward(self, inputs: Tensor) -> Tensor:
        """Mask each consecutive chunk of the batch with its own channel mask.

        Args:
            inputs (Tensor): Input of shape ``(batch, channels, H, W)``;
                ``batch`` must be a multiple of ``num_masks``, otherwise the
                reshape fails.

        Returns:
            Tensor: Masked tensor with the same shape and dtype as ``inputs``.
        """
        batch = inputs.shape[0]
        # (B, C, H, W) -> (num_masks, B // num_masks, C, H, W): slice i is chunk i.
        grouped = inputs.reshape(self.num_masks, batch // self.num_masks, *inputs.shape[1:])
        # (num_masks, C) -> (num_masks, 1, C, 1, 1) to broadcast over chunk, H, W.
        masked = grouped * self.masks[:, None, :, None, None]
        # The masks may have a different dtype (float64 from numpy); cast back.
        return masked.reshape(inputs.shape).to(inputs.dtype)


[docs]class MaskedLinear(nn.Module): def __init__( self, in_features: int, out_features: int, num_estimators: int, scale: float, bias: bool = True, device: Any | None = None, dtype: Any | None = None, ) -> None: r"""Masksembles-style Linear layer. This layer computes fully-connected operation for a given number of estimators (:attr:`num_estimators`) with a given :attr:`scale`. Args: in_features (int): Number of input features of the linear layer. out_features (int): Number of channels produced by the linear layer. num_estimators (int): The number of estimators grouped in the layer. scale (float): The scale parameter for the masks. bias (bool, optional): It ``True``, adds a learnable bias to the output. Defaults to ``True``. groups (int, optional): Number of blocked connections from input channels to output channels. Defaults to ``1``. device (Any, optional): The desired device of returned tensor. Defaults to ``None``. dtype (Any, optional): The desired data type of returned tensor. Defaults to ``None``. Warning: Be sure to apply a repeat on the batch at the start of the training if you use `MaskedLinear`. Reference: `Masksembles for Uncertainty Estimation`, Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua. """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if scale is None: raise ValueError("You must specify the value of the arg. `scale`") if scale < 1: raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") self.mask = Mask1d(in_features, num_masks=num_estimators, scale=scale, **factory_kwargs) self.linear = nn.Linear( in_features=in_features, out_features=out_features, bias=bias, **factory_kwargs, ) def forward(self, inputs: Tensor) -> Tensor: return self.linear(self.mask(inputs))
[docs]class MaskedConv2d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, num_estimators: int, scale: float, stride: _size_2_t = 1, padding: str | _size_2_t = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, device: Any | None = None, dtype: Any | None = None, ) -> None: r"""Masksembles-style Conv2d layer. Args: in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. num_estimators (int): Number of estimators in the ensemble. scale (float): The scale parameter for the masks. stride (int or tuple, optional): Stride of the convolution. Defaults to ``1``. padding (int, tuple or str, optional): Padding added to all four sides of the input. Defaults to ``0``. dilation (int or tuple, optional): Spacing between kernel elements. Defaults to ``1``. groups (int, optional): Number of blocked connexions from input channels to output channels for each estimator. Defaults to ``1``. bias (bool, optional): If ``True``, adds a learnable bias to the output. Defaults to ``True``. device (Any, optional): The desired device of returned tensor. Defaults to ``None``. dtype (Any, optional): The desired data type of returned tensor. Defaults to ``None``. Warning: Be sure to apply a repeat on the batch at the start of the training if you use `MaskedConv2d`. Reference: `Masksembles for Uncertainty Estimation`, Nikita Durasov, Timur Bagautdinov, Pierre Baque, Pascal Fua. """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if scale is None: raise ValueError("You must specify the value of the arg. 
`scale`") if scale < 1: raise ValueError(f"Attribute `scale` should be >= 1, not {scale}.") self.mask = Mask2d(in_channels, num_masks=num_estimators, scale=scale, **factory_kwargs) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode="zeros", **factory_kwargs, ) def forward(self, inputs: Tensor) -> Tensor: return self.conv(self.mask(inputs))