BatchConv2d

class torch_uncertainty.layers.BatchConv2d(in_channels, out_channels, kernel_size, num_estimators, stride=1, padding=0, dilation=1, groups=1, bias=True, device=None, dtype=None)

BatchEnsemble-style Conv2d layer.

Applies a 2D convolution over an input signal composed of several input planes, using the BatchEnsemble method.

In the simplest case, the output value of the layer with input size \((N, C_{in}, H_{in}, W_{in})\) and output \((N, C_{out}, H_{out}, W_{out})\) can be precisely described as:

\[\begin{split}\text{out}(N_i, C_{\text{out}_j})=\ &\widehat{b}(N_i,C_{\text{out}_j}) +\widehat{s_{group}}(N_{i},C_{\text{out}_j}) \\ &\times \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) \times \widehat{r_{group}}(N_i, k))\end{split}\]
Reference:

Introduced in the paper BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning. This PyTorch implementation of a BatchEnsemble Conv2d layer is heavily inspired by the official TensorFlow implementation.
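
Concretely, every ensemble member shares the convolution kernel and only owns the vectors \(r_i\), \(s_i\) and its bias, which act as channel-wise input and output scalings. Below is a minimal doctest-style sketch of the formula above, assuming a batch holding 2 examples for each of 4 estimators stacked contiguously along the batch axis; the tensor names are illustrative and are not the layer's internals.

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(8, 3, 16, 16)      # 4 estimators x 2 examples, stacked along the batch axis
>>> weight = torch.randn(32, 3, 3, 3)  # convolution kernel shared by all estimators
>>> r_group = torch.randn(4, 3)        # per-estimator input-channel scaling
>>> s_group = torch.randn(4, 32)       # per-estimator output-channel scaling
>>> bias = torch.randn(4, 32)          # per-estimator bias
>>> # Expand each per-estimator row so it covers that estimator's 2 examples.
>>> r = r_group.repeat_interleave(2, dim=0)
>>> s = s_group.repeat_interleave(2, dim=0)
>>> b = bias.repeat_interleave(2, dim=0)
>>> out = F.conv2d(x * r[:, :, None, None], weight)
>>> out = out * s[:, :, None, None] + b[:, :, None, None]
>>> out.shape
torch.Size([8, 32, 14, 14])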

Parameters:
  • in_channels (int) – number of channels in the input images.

  • out_channels (int) – number of channels produced by the convolution.

  • kernel_size (int or tuple) – size of the convolving kernel.

  • num_estimators (int) – number of estimators in the ensemble, referred to as \(M\) here.

  • stride (int or tuple, optional) – stride of the convolution. Defaults to 1.

  • padding (int, tuple or str, optional) – padding added to all four sides of the input. Defaults to 0.

  • dilation (int or tuple, optional) – spacing between kernel elements. Defaults to 1.

  • groups (int, optional) – number of blocked connections from input channels to output channels. Defaults to 1.

  • bias (bool, optional) – if True, adds a learnable bias to the output. Defaults to True.

  • device (Any, optional) – device to use for the parameters and buffers of this module. Defaults to None.

  • dtype (Any, optional) – data type to use for the parameters and buffers of this module. Defaults to None.

Variables:
  • weight (torch.Tensor) – the learnable weights of the module, of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}}, \text{kernel_size[0]}, \text{kernel_size[1]})\), shared between the estimators. The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\).

  • r_group (torch.Tensor) – the learnable matrix of shape \((M, C_{in})\) where each row consists of the vector \(r_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).

  • s_group (torch.Tensor) – the learnable matrix of shape \((M, C_{out})\) where each row consists of the vector \(s_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).

  • bias (torch.Tensor | None) – the learnable bias (\(b\)) of shape \((M, C_{out})\) where each row corresponds to the bias of the \(i^{th}\) ensemble member. If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k=\frac{\text{groups}}{C_\text{in}*\prod_{i=0}^{1} \text{kernel_size}[i]}\).
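
For orientation, a layer with 3 input channels, 32 output channels, a \(3 \times 3\) kernel, and 4 estimators would hold variables of the following shapes (assuming the attributes are exposed under the names documented above):

>>> from torch_uncertainty.layers import BatchConv2d
>>> m = BatchConv2d(3, 32, 3, num_estimators=4)
>>> m.weight.shape
torch.Size([32, 3, 3, 3])
>>> m.r_group.shape
torch.Size([4, 3])
>>> m.s_group.shape
torch.Size([4, 32])
>>> m.bias.shape
torch.Size([4, 32])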

Shape:
  • Input: \((N, C_{in}, H_{in}, W_{in})\).

  • Output: \((N, C_{out}, H_{out}, W_{out})\).

\[H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1} {\text{stride}[0]} + 1\right\rfloor\]
\[W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1} {\text{stride}[1]} + 1\right\rfloor\]
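
For instance, with the \(16 \times 16\) inputs, \(3 \times 3\) kernel, stride 1, and default padding and dilation used in the example below, each spatial dimension becomes:

\[H_{out} = W_{out} = \left\lfloor\frac{16 + 2 \times 0 - 1 \times (3 - 1) - 1}{1} + 1\right\rfloor = 14\]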

Warning

Ensure that batch_size is divisible by num_estimators when calling forward(). In a BatchEnsemble architecture, the input batch is typically repeated num_estimators times along the first axis. Incorrect batch size may lead to unexpected results.

To simplify batch handling, wrap your model with BatchEnsembleWrapper, which automatically repeats the batch before passing it through the network. See BatchEnsembleWrapper for details.
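
As a minimal sketch of this manual repetition (the wrapper performs an equivalent step automatically), consider a batch of 2 images fed to a layer with 4 estimators:

>>> import torch
>>> from torch_uncertainty.layers import BatchConv2d
>>> num_estimators = 4
>>> m = BatchConv2d(3, 32, 3, num_estimators)
>>> x = torch.randn(2, 3, 16, 16)
>>> x = x.repeat(num_estimators, 1, 1, 1)  # batch size becomes 8, divisible by 4
>>> m(x).shape
torch.Size([8, 32, 14, 14])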

Examples

>>> import torch
>>> from torch_uncertainty.layers import BatchConv2d
>>> # With square kernels, four estimators, and equal stride
>>> m = BatchConv2d(3, 32, 3, 4, stride=1)
>>> input = torch.randn(8, 3, 16, 16)
>>> output = m(input)
>>> print(output.size())
torch.Size([8, 32, 14, 14])
classmethod from_conv2d(conv2d, num_estimators)

Create a BatchEnsemble-style Conv2d layer from an existing Conv2d layer.

Parameters:
  • conv2d (nn.Conv2d) – The Conv2d layer to convert.

  • num_estimators (int) – Number of ensemble members.

Returns:

The converted BatchEnsemble-style Conv2d layer.

Return type:

BatchConv2d

Warning

All parameters of the original Conv2d layer will be discarded.

Example

>>> from torch import nn
>>> from torch_uncertainty.layers import BatchConv2d
>>> conv2d = nn.Conv2d(3, 32, kernel_size=3)
>>> be_conv2d = BatchConv2d.from_conv2d(conv2d, num_estimators=3)