Shortcuts

Source code for torch_uncertainty.layers.batch_ensemble

import math

import torch
from torch import Tensor, nn
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair


[docs]class BatchLinear(nn.Module): __constants__ = ["in_features", "out_features", "num_estimators"] in_features: int out_features: int num_estimators: int r_group: Tensor s_group: Tensor bias: Tensor | None def __init__( self, in_features: int, out_features: int, num_estimators: int, bias: bool = True, device=None, dtype=None, ) -> None: r"""BatchEnsemble-style Linear layer. Apply a linear transformation using BatchEnsemble method to the incoming data. .. math:: y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b} Args: in_features (int): Number of input features.. out_features (int): Number of output features. num_estimators (int): number of estimators in the ensemble, referred as :math:`M`. bias (bool, optional): if ``True``, adds a learnable bias to the output. Defaults to ``True``. device (Any, optional): device to use for the parameters and buffers of this module. Defaults to ``None``. dtype (Any, optional): data type to use for the parameters and buffers of this module. Defaults to ``None``. Reference: Introduced by the paper `BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning <https://arxiv.org/abs/2002.06715>`_, we present here an implementation of a Linear BatchEnsemble layer in `PyTorch <https://pytorch.org>`_ heavily inspired by its `official implementation <https://github.com/google/edward2>`_ in `TensorFlow <https://www.tensorflow.org>`_. Attributes: weight: the learnable weights (:math:`W`) of shape :math:`(H_{out}, H_{in})` shared between the estimators. The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{H_{in}}`. r_group: the learnable matrice of shape :math:`(M, H_{in})` where each row consist of the vector :math:`r_{i}` corresponding to the :math:`i^{th}` ensemble member. The values are initialized from :math:`\mathcal{N}(1.0, 0.5)`. s_group: the learnable matrice of shape :math:`(M, H_{out})` where each row consist of the vector :math:`s_{i}` corresponding to the :math:`i^{th}` ensemble member. The values are initialized from :math:`\mathcal{N}(1.0, 0.5)`. bias: the learnable bias (:math:`b`) of shape :math:`(M, H_{out})` where each row corresponds to the bias of the :math:`i^{th}` ensemble member. If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{1}{H_{in}}`. Shape: - Input: :math:`(N, H_{in})` where :math:`N` is the batch size and :math:`H_{in} = \text{in_features}`. - Output: :math:`(N, H_{out})` where :math:`H_{out} = \text{out_features}`. Warning: Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. Examples: >>> # With three estimators >>> m = LinearBE(20, 30, 3) >>> input = torch.randn(8, 20) >>> output = m(input) >>> print(output.size()) torch.Size([8, 30]) """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features self.out_features = out_features self.num_estimators = num_estimators self.linear = nn.Linear( in_features=in_features, out_features=out_features, bias=False, **factory_kwargs, ) self.r_group = nn.Parameter(torch.empty((num_estimators, in_features), **factory_kwargs)) self.s_group = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) if bias: self.bias = nn.Parameter(torch.empty((num_estimators, out_features), **factory_kwargs)) else: self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.linear.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) def forward(self, inputs: Tensor) -> Tensor: batch_size = inputs.size(0) examples_per_estimator = torch.tensor( batch_size // self.num_estimators, device=inputs.device ) extra = batch_size % self.num_estimators r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) if self.bias is not None: bias = torch.repeat_interleave( self.bias, examples_per_estimator, dim=0, ) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None return self.linear(inputs * r_group) * s_group + (bias if bias is not None else 0) def extra_repr(self) -> str: return ( f"in_features={ self.in_features}," f" out_features={self.out_features}," f" num_estimators={self.num_estimators}," f" bias={self.bias is not None}" )
[docs]class BatchConv2d(nn.Module): __constants__ = [ "stride", "padding", "dilation", "groups", "in_channels", "out_channels", "kernel_size", "num_estimators", ] in_channels: int out_channels: int kernel_size: tuple[int, ...] num_estimators: int stride: tuple[int, ...] padding: str | tuple[int, ...] dilation: tuple[int, ...] groups: int weight: Tensor r_group: Tensor s_group: Tensor bias: Tensor | None def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, num_estimators: int, stride: _size_2_t = 1, padding: str | _size_2_t = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, device=None, dtype=None, ) -> None: r"""BatchEnsemble-style Conv2d layer. Applies a 2d convolution over an input signal composed of several input planes using BatchEnsemble method to the incoming data. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, H_{in}, W_{in})` and output :math:`(N, C_{out}, H_{out}, W_{out})` can be precisely described as: .. math:: \text{out}(N_i, C_{\text{out}_j})=\ &\widehat{b}(N_i,C_{\text{out}_j}) +\widehat{s_{group}}(N_{i},C_{\text{out}_j}) \\ &\times \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) \times \widehat{r_{group}}(N_i, k)) Reference: Introduced by the paper `BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning <https://arxiv.org/abs/2002.06715>`_, we present here an implementation of a Conv2d BatchEnsemble layer in `PyTorch <https://pytorch.org>`_ heavily inspired by its `official implementation <https://github.com/google/edward2>`_ in `TensorFlow <https://www.tensorflow.org>`_. Args: in_channels (int): number of channels in the input images. out_channels (int): number of channels produced by the convolution. kernel_size (int or tuple): size of the convolving kernel. num_estimators (int): number of estimators in the ensemble referred as :math:`M` here. stride (int or tuple, optional): stride of the convolution. Defaults to ``1``. padding (int, tuple or str, optional): padding added to all four sides of the input. Defaults to ``0``. dilation (int or tuple, optional): spacing between kernel elements. Defaults to ``1``. groups (int, optional): number of blocked connections from input channels to output channels. Defaults to ``1``. bias (bool, optional): if ``True``, adds a learnable bias to the output. Defaults to ``True``. device (Any, optional): device to use for the parameters and buffers of this module. Defaults to ``None``. dtype (Any, optional): data type to use for the parameters and buffers of this module. Defaults to ``None``. Attributes: weight: the learnable weights of the module of shape :math:`(\text{out_channels}, \frac{\text{in_channels}} {\text{groups}},`:math:`\text{kernel_size[0]}, \text{kernel_size[1]})` shared between the estimators. The values of these weights are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}`. r_group: the learnable matrice of shape :math:`(M, C_{in})` where each row consist of the vector :math:`r_{i}` corresponding to the :math:`i^{th}` ensemble member. The values are initialized from :math:`\mathcal{N}(1.0, 0.5)`. s_group: the learnable matrice of shape :math:`(M, C_{out})` where each row consist of the vector :math:`s_{i}` corresponding to the :math:`i^{th}` ensemble member. The values are initialized from :math:`\mathcal{N}(1.0, 0.5)`. bias: the learnable bias (:math:`b`) of shape :math:`(M, C_{out})` where each row corresponds to the bias of the :math:`i^{th}` ensemble member. If :attr:`bias` is ``True``, the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where :math:`k=\frac{\text{groups}}{C_\text{in}*\prod_{i=0}^{1} \text{kernel_size}[i]}`. Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})`. - Output: :math:`(N, C_{out}, H_{out}, W_{out})`. .. math:: H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1} {\text{stride}[0]} + 1\right\rfloor .. math:: W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1} {\text{stride}[1]} + 1\right\rfloor Warning: Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`. Examples: >>> # With square kernels, four estimators and equal stride >>> m = Conv2dBE(3, 32, 3, 4, stride=1) >>> input = torch.randn(8, 3, 16, 16) >>> output = m(input) >>> print(output.size()) torch.Size([8, 32, 14, 14]) """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) self.num_estimators = num_estimators self.stride = _pair(stride) self.padding = padding if isinstance(padding, str) else _pair(padding) self.dilation = _pair(dilation) self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False, **factory_kwargs, ) self.r_group = nn.Parameter(torch.empty((num_estimators, in_channels), **factory_kwargs)) self.s_group = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) if bias: self.bias = nn.Parameter(torch.empty((num_estimators, out_channels), **factory_kwargs)) else: self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) if self.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) def forward(self, inputs: Tensor) -> Tensor: batch_size = inputs.size(0) examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators r_group = ( torch.repeat_interleave( self.r_group, torch.full( [self.num_estimators], examples_per_estimator, device=self.r_group.device, ), dim=0, ) .unsqueeze(-1) .unsqueeze(-1) ) r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) s_group = ( torch.repeat_interleave( self.s_group, torch.full( [self.num_estimators], examples_per_estimator, device=self.s_group.device, ), dim=0, ) .unsqueeze(-1) .unsqueeze(-1) ) s_group = torch.cat([s_group, s_group[:extra]], dim=0) # if self.bias is not None: bias = ( torch.repeat_interleave( self.bias, torch.full( [self.num_estimators], examples_per_estimator, device=self.bias.device, ), dim=0, ) .unsqueeze(-1) .unsqueeze(-1) ) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None return self.conv(inputs * r_group) * s_group + (bias if bias is not None else 0) def extra_repr(self) -> str: return ( f"in_channels={self.in_channels}," f" out_channels={self.out_channels}," f" kernel_size={self.kernel_size}," f" num_estimators={self.num_estimators}," f" stride={self.stride}" )