
BatchLinear

class torch_uncertainty.layers.BatchLinear(in_features, out_features, num_estimators, bias=True, device=None, dtype=None)

BatchEnsemble-style Linear layer.

Applies a linear transformation to the incoming data using the BatchEnsemble method.

\[y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b}\]
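As an illustration of what the formula computes, the sketch below assumes (our reading, not stated above) that the hats denote r_group, s_group and the bias expanded along the batch dimension, so that each row of x uses the vectors of the estimator it is assigned to. The function `batch_ensemble_linear` and its `member` argument are hypothetical and not part of torch_uncertainty. The check at the end uses the algebraic identity (x ∘ r_i)W^T ∘ s_i = x (W ∘ s_i r_i^T)^T, i.e. each estimator effectively uses a rank-one rescaling of the shared weight.

```python
import torch


def batch_ensemble_linear(x, weight, r_group, s_group, bias, member):
    """Hypothetical functional form of the BatchEnsemble linear formula.

    x: (N, H_in); weight: (H_out, H_in); r_group: (M, H_in);
    s_group: (M, H_out); bias: (M, H_out) or None;
    member: (N,) long tensor with the estimator index of each row of x.
    """
    r_hat = r_group[member]  # (N, H_in): r_i of each row's estimator
    s_hat = s_group[member]  # (N, H_out): s_i of each row's estimator
    y = ((x * r_hat) @ weight.T) * s_hat  # (x ∘ r̂) W^T ∘ ŝ
    if bias is not None:
        y = y + bias[member]  # + b̂
    return y


torch.manual_seed(0)
N, H_in, H_out, M = 6, 4, 5, 3
x = torch.randn(N, H_in)
weight = torch.randn(H_out, H_in)
r_group = torch.randn(M, H_in)
s_group = torch.randn(M, H_out)
bias = torch.randn(M, H_out)
member = torch.arange(M).repeat_interleave(N // M)  # [0, 0, 1, 1, 2, 2]

y = batch_ensemble_linear(x, weight, r_group, s_group, bias, member)

# Reference: build the explicit per-estimator weight W_i = W ∘ (s_i r_i^T)
# and apply it row by row, adding the estimator's bias b_i.
y_ref = torch.stack([
    (weight * torch.outer(s_group[i], r_group[i])) @ x[n] + bias[i]
    for n, i in enumerate(member.tolist())
])
print(torch.allclose(y, y_ref, atol=1e-5))
```

Because each W_i is only a rank-one rescaling of the shared W, the layer never needs to materialize M weight matrices: the element-wise products with r̂ and ŝ perform the modulation.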
Parameters:
  • in_features (int) – Number of input features.

  • out_features (int) – Number of output features.

  • num_estimators (int) – Number of estimators in the ensemble, referred to as \(M\).

  • bias (bool, optional) – If True, adds a learnable bias to the output. Defaults to True.

  • device (Any, optional) – Device to use for the parameters and buffers of this module. Defaults to None.

  • dtype (Any, optional) – Data type to use for the parameters and buffers of this module. Defaults to None.

Reference:

Introduced in the paper BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning. This is a PyTorch implementation of a Linear BatchEnsemble layer, heavily inspired by the official TensorFlow implementation.

Variables:
  • weight – the learnable weights (\(W\)) of shape \((H_{out}, H_{in})\) shared between the estimators. The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{H_{in}}\).

  • r_group (torch.Tensor) – the learnable matrix of shape \((M, H_{in})\) where each row consists of the vector \(r_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).

  • s_group (torch.Tensor) – the learnable matrix of shape \((M, H_{out})\) where each row consists of the vector \(s_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).

  • bias (torch.Tensor | None) – the learnable bias (\(b\)) of shape \((M, H_{out})\) where each row corresponds to the bias of the \(i^{th}\) ensemble member. If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{H_{in}}\).
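From the shapes listed above, the number of learnable parameters can be counted directly. The helper names below are hypothetical; the comparison with \(M\) independent layers assumes each one has its own \((H_{out}, H_{in})\) weight and \(H_{out}\) bias, as an nn.Linear would.

```python
def batch_linear_param_count(in_features, out_features, num_estimators, bias=True):
    """Learnable parameters of one BatchEnsemble linear layer (shapes as documented)."""
    shared = out_features * in_features                     # weight W: (H_out, H_in)
    fast = num_estimators * (in_features + out_features)    # r_group + s_group
    b = num_estimators * out_features if bias else 0        # bias: (M, H_out)
    return shared + fast + b


def independent_linear_param_count(in_features, out_features, num_estimators, bias=True):
    """Parameters of M separate Linear-style layers, for comparison."""
    per_layer = out_features * in_features + (out_features if bias else 0)
    return num_estimators * per_layer


print(batch_linear_param_count(20, 30, 3))        # 600 + 150 + 90 = 840
print(independent_linear_param_count(20, 30, 3))  # 3 * (600 + 30) = 1890
```

The shared weight dominates the count, so the per-estimator overhead grows only linearly in \(H_{in} + H_{out}\) rather than in \(H_{in} H_{out}\).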

Shape:
  • Input: \((N, H_{in})\) where \(N\) is the batch size and \(H_{in} = \text{in_features}\).

  • Output: \((N, H_{out})\) where \(H_{out} = \text{out_features}\).
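Combining the variables, their initialization and the shapes above, a minimal re-implementation could look like the sketch below. `BatchLinearSketch` is hypothetical and is not the library class. It reads \(\mathcal{N}(1.0, 0.5)\) as mean 1.0 and standard deviation 0.5 (an assumption), and, purely to keep the sketch short, requires the batch size \(N\) to be a multiple of \(M\) so that consecutive blocks of \(N/M\) rows share an estimator; the library may handle batches differently.

```python
import math

import torch
from torch import nn


class BatchLinearSketch(nn.Module):
    """Hypothetical minimal BatchEnsemble linear layer (not the library class)."""

    def __init__(self, in_features, out_features, num_estimators, bias=True):
        super().__init__()
        self.num_estimators = num_estimators
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.r_group = nn.Parameter(torch.empty(num_estimators, in_features))
        self.s_group = nn.Parameter(torch.empty(num_estimators, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(num_estimators, out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        bound = math.sqrt(1.0 / self.weight.size(1))  # sqrt(k) with k = 1 / H_in
        nn.init.uniform_(self.weight, -bound, bound)
        # Assumption: N(1.0, 0.5) means mean 1.0 and standard deviation 0.5.
        nn.init.normal_(self.r_group, mean=1.0, std=0.5)
        nn.init.normal_(self.s_group, mean=1.0, std=0.5)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        n = x.size(0)
        # Simplification of this sketch: the batch must split evenly into M blocks.
        if n % self.num_estimators != 0:
            raise ValueError("batch size must be a multiple of num_estimators")
        per_member = n // self.num_estimators
        # Repeat each estimator's vectors so row n uses estimator n // per_member.
        r_hat = self.r_group.repeat_interleave(per_member, dim=0)  # (N, H_in)
        s_hat = self.s_group.repeat_interleave(per_member, dim=0)  # (N, H_out)
        y = ((x * r_hat) @ self.weight.T) * s_hat
        if self.bias is not None:
            y = y + self.bias.repeat_interleave(per_member, dim=0)
        return y


m = BatchLinearSketch(20, 30, 3)
out = m(torch.randn(6, 20))
print(out.shape)  # torch.Size([6, 30])
```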

Warning

Make sure that num_estimators divides out_features when calling forward().

Examples

>>> import torch
>>> from torch_uncertainty.layers import BatchLinear
>>> # With three estimators
>>> m = BatchLinear(20, 30, 3)
>>> input = torch.randn(8, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([8, 30])