BatchConv2d
- class torch_uncertainty.layers.BatchConv2d(in_channels, out_channels, kernel_size, num_estimators, stride=1, padding=0, dilation=1, groups=1, bias=True, device=None, dtype=None)
BatchEnsemble-style Conv2d layer.
Applies a 2D convolution over an input signal composed of several input planes, using the BatchEnsemble method.
In the simplest case, the output value of the layer with input size \((N, C_{in}, H_{in}, W_{in})\) and output \((N, C_{out}, H_{out}, W_{out})\) can be precisely described as:
\[\begin{split}\text{out}(N_i, C_{\text{out}_j})=\ &\widehat{b}(N_i,C_{\text{out}_j}) +\widehat{s_{group}}(N_{i},C_{\text{out}_j}) \\ &\times \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k)\star (\text{input}(N_i, k) \times \widehat{r_{group}}(N_i, k))\end{split}\]
where \(\star\) is the valid 2D cross-correlation operator.
- Reference:
This is a PyTorch implementation of the Conv2d BatchEnsemble layer introduced in the paper BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning, heavily inspired by the official TensorFlow implementation.
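To make the formula above concrete, here is a minimal pure-Python sketch (no PyTorch) that evaluates it directly on nested lists. It is not torch_uncertainty's implementation: it assumes stride=1, padding=0, dilation=1 and groups=1, and it assumes that the batch of \(N\) samples is split into \(M\) consecutive, equally sized chunks, chunk \(m\) using row \(m\) of r_group, s_group and bias (this is how the hatted terms are read here). The helpers `estimator_of` and `batch_conv2d` are hypothetical names.

```python
"""Pure-Python sketch of the BatchEnsemble Conv2d forward formula.

Assumptions (hypothetical, not torch_uncertainty's code): stride=1, padding=0,
dilation=1, groups=1, and a batch of N samples split into M consecutive,
equally sized chunks, chunk m being handled by ensemble member m.
"""


def estimator_of(i, batch_size, num_estimators):
    # Sample i belongs to chunk i // (N / M); this sketch requires M to divide N.
    assert batch_size % num_estimators == 0
    return i // (batch_size // num_estimators)


def batch_conv2d(inputs, weight, r_group, s_group, bias=None):
    """inputs: [N][C_in][H][W], weight: [C_out][C_in][kH][kW],
    r_group: [M][C_in], s_group: [M][C_out], bias: [M][C_out] or None."""
    n, c_in = len(inputs), len(inputs[0])
    h_in, w_in = len(inputs[0][0]), len(inputs[0][0][0])
    c_out, k_h, k_w = len(weight), len(weight[0][0]), len(weight[0][0][0])
    num_estimators = len(r_group)
    h_out, w_out = h_in - k_h + 1, w_in - k_w + 1  # "valid" output size
    out = []
    for i in range(n):
        m = estimator_of(i, n, num_estimators)
        sample = []
        for o in range(c_out):
            plane = []
            for y in range(h_out):
                row = []
                for x in range(w_out):
                    # Cross-correlate the shared weight with the r-scaled input.
                    acc = 0.0
                    for k in range(c_in):
                        for dy in range(k_h):
                            for dx in range(k_w):
                                acc += (weight[o][k][dy][dx]
                                        * inputs[i][k][y + dy][x + dx]
                                        * r_group[m][k])
                    # Scale by s and add the member-specific bias.
                    value = s_group[m][o] * acc
                    if bias is not None:
                        value += bias[m][o]
                    row.append(value)
                plane.append(row)
            sample.append(plane)
        out.append(sample)
    return out
```

Because r and s only rescale input and output channels, the same result is obtained by convolving with the per-member effective weight \(W \circ (s_m r_m^\top)\), broadcast over the kernel dimensions; the layer stores one shared \(W\) plus a few \(M\)-row matrices instead of \(M\) full kernels.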
- Parameters:
in_channels (int) – number of channels in the input images.
out_channels (int) – number of channels produced by the convolution.
kernel_size (int or tuple) – size of the convolving kernel.
num_estimators (int) – number of estimators in the ensemble, referred to as \(M\) here.
stride (int or tuple, optional) – stride of the convolution. Defaults to 1.
padding (int, tuple or str, optional) – padding added to all four sides of the input. Defaults to 0.
dilation (int or tuple, optional) – spacing between kernel elements. Defaults to 1.
groups (int, optional) – number of blocked connections from input channels to output channels. Defaults to 1.
bias (bool, optional) – if True, adds a learnable bias to the output. Defaults to True.
device (Any, optional) – device to use for the parameters and buffers of this module. Defaults to None.
dtype (Any, optional) – data type to use for the parameters and buffers of this module. Defaults to None.
- Variables:
weight (torch.Tensor) – the learnable weights of the module, of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}}, \text{kernel_size[0]}, \text{kernel_size[1]})\), shared between the estimators. The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\).
r_group (torch.Tensor) – the learnable matrix of shape \((M, C_{in})\) where each row consists of the vector \(r_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).
s_group (torch.Tensor) – the learnable matrix of shape \((M, C_{out})\) where each row consists of the vector \(s_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).
bias (torch.Tensor | None) – the learnable bias (\(b\)) of shape \((M, C_{out})\) where each row corresponds to the bias of the \(i^{th}\) ensemble member. If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k=\frac{\text{groups}}{C_\text{in}*\prod_{i=0}^{1} \text{kernel_size}[i]}\).
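Multiplying out the shapes listed under Variables shows where the memory saving comes from. The sketch below is only arithmetic on those shapes; the helper names are hypothetical, and the baseline of \(M\) independent Conv2d layers (each with its own weight and a \(C_{out}\) bias) is an assumed comparison point, not something the layer computes.

```python
import math


def _pair(v):
    # Accept an int or a (height, width) tuple, like the kernel_size parameter.
    return (v, v) if isinstance(v, int) else tuple(v)


def batch_conv2d_param_counts(in_channels, out_channels, kernel_size,
                              num_estimators, groups=1, bias=True):
    """Number of scalars in each tensor listed under Variables (hypothetical helper)."""
    k_h, k_w = _pair(kernel_size)
    return {
        "weight": out_channels * (in_channels // groups) * k_h * k_w,  # shared by all M members
        "r_group": num_estimators * in_channels,                       # (M, C_in)
        "s_group": num_estimators * out_channels,                      # (M, C_out)
        "bias": num_estimators * out_channels if bias else 0,          # (M, C_out)
    }


def independent_conv2d_param_count(in_channels, out_channels, kernel_size,
                                   num_estimators, groups=1, bias=True):
    """Assumed baseline: M separate Conv2d layers, each with its own weight and bias."""
    k_h, k_w = _pair(kernel_size)
    per_member = out_channels * (in_channels // groups) * k_h * k_w
    per_member += out_channels if bias else 0
    return num_estimators * per_member


def init_bound(in_channels, kernel_size, groups=1):
    """sqrt(k) with k = groups / (C_in * kernel_size[0] * kernel_size[1])."""
    k_h, k_w = _pair(kernel_size)
    return math.sqrt(groups / (in_channels * k_h * k_w))


counts = batch_conv2d_param_counts(3, 32, 3, 4)
print(sum(counts.values()))                          # 1132
print(independent_conv2d_param_count(3, 32, 3, 4))   # 3584
```

For the configuration used in the example below the Variables list (3 input channels, 32 output channels, a 3×3 kernel, 4 estimators), the shared weight dominates, so the per-member vectors add little on top of a single convolution.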
- Shape:
Input: \((N, C_{in}, H_{in}, W_{in})\).
Output: \((N, C_{out}, H_{out}, W_{out})\).
\[H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1} {\text{stride}[0]} + 1\right\rfloor\]
\[W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1} {\text{stride}[1]} + 1\right\rfloor\]
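The two output-size formulas can be checked with a small helper. This is a sketch under the assumption that padding is an int or a tuple; string padding values are not handled, and `conv2d_output_size` is a hypothetical name rather than a torch_uncertainty function.

```python
def _pair(v):
    # int -> (v, v); tuple -> tuple, mirroring the "int or tuple" parameters.
    return (v, v) if isinstance(v, int) else tuple(v)


def conv2d_output_size(h_in, w_in, kernel_size, stride=1, padding=0, dilation=1):
    """Evaluate the H_out / W_out formulas; padding must be an int or tuple here."""
    k, s, p, d = (_pair(v) for v in (kernel_size, stride, padding, dilation))
    # floor(a / s + 1) == a // s + 1 for integers, so integer division suffices.
    h_out = (h_in + 2 * p[0] - d[0] * (k[0] - 1) - 1) // s[0] + 1
    w_out = (w_in + 2 * p[1] - d[1] * (k[1] - 1) - 1) // s[1] + 1
    return h_out, w_out


print(conv2d_output_size(16, 16, 3))                       # (14, 14)
print(conv2d_output_size(16, 16, 3, stride=2, padding=1))  # (8, 8)
```

The first call reproduces the 16×16 → 14×14 spatial size of a 3×3 kernel with stride 1 and no padding.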
Warning
Make sure that num_estimators divides out_channels when calling forward().
Examples
>>> import torch
>>> from torch_uncertainty.layers import BatchConv2d
>>> # With square kernels, four estimators and equal stride
>>> m = BatchConv2d(3, 32, 3, 4, stride=1)
>>> input = torch.randn(8, 3, 16, 16)
>>> output = m(input)
>>> print(output.size())
torch.Size([8, 32, 14, 14])