BatchLinear

class torch_uncertainty.layers.BatchLinear(in_features, out_features, num_estimators, bias=True, device=None, dtype=None)

BatchEnsemble-style Linear layer.

Apply a linear transformation using the BatchEnsemble method to the incoming data.

\[y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b}\]
Parameters:
  • in_features (int) – Number of input features.

  • out_features (int) – Number of output features.

  • num_estimators (int) – Number of estimators in the ensemble (M).

  • bias (bool) – If True, adds a learnable bias to the output. Defaults to True.

  • device – Device to use for parameters and buffers. Defaults to None.

  • dtype – Data type to use for parameters and buffers. Defaults to None.

Reference:

This layer is a PyTorch implementation of the Linear BatchEnsemble layer introduced in the paper BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning. It is heavily inspired by the official TensorFlow implementation.

Variables:
  • weight – The learnable weights (\(W\)) of shape \((H_{out}, H_{in})\) shared between the estimators. The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{H_{in}}\).

  • r_group – Learnable matrix of shape \((M, H_{in})\) where each row is the vector \(r_{i}\) corresponding to the \(i^{th}\) ensemble member. Initialized from \(\mathcal{N}(1.0, 0.5)\).

  • s_group – Learnable matrix of shape \((M, H_{out})\) where each row is the vector \(s_{i}\) corresponding to the \(i^{th}\) ensemble member. Initialized from \(\mathcal{N}(1.0, 0.5)\).

  • bias – The learnable bias (\(b\)) of shape \((M, H_{out})\) where each row corresponds to the bias of the \(i^{th}\) ensemble member. If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{H_{in}}\).
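
The shapes listed above determine how many parameters the layer holds compared with M independent linear layers. The following is a minimal sketch of that arithmetic in plain Python; the helper names `batch_linear_param_count` and `independent_param_count` are hypothetical and are not part of torch_uncertainty.

```python
def batch_linear_param_count(in_features, out_features, num_estimators, bias=True):
    """Count parameters from the shapes in the Variables list (hypothetical helper)."""
    shared = out_features * in_features  # weight: (H_out, H_in), shared
    per_estimator = in_features + out_features  # one row of r_group and of s_group
    if bias:
        per_estimator += out_features  # one row of bias: (M, H_out)
    return shared + num_estimators * per_estimator


def independent_param_count(in_features, out_features, num_estimators, bias=True):
    """Count parameters of M separate dense linear layers (hypothetical helper)."""
    per_layer = out_features * in_features + (out_features if bias else 0)
    return num_estimators * per_layer


batch_ensemble_total = batch_linear_param_count(20, 30, 3)
independent_total = independent_param_count(20, 30, 3)
```

The shared weight is counted once, so the per-estimator cost grows with \(H_{in} + H_{out}\) rather than with \(H_{in} \times H_{out}\).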

Shape:
  • Input: \((N, H_{in})\) where \(N\) is the batch size and \(H_{in} = \text{in_features}\).

  • Output: \((N, H_{out})\) where \(H_{out} = \text{out_features}\).
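
The formula and shapes above can be sketched with plain tensors. This is an illustrative sketch, not the library's forward method. It assumes that consecutive blocks of batch rows belong to the same estimator and that the batch size is divisible by the number of estimators; all tensors are random stand-ins for the learned parameters.

```python
import torch

torch.manual_seed(0)

num_estimators, in_features, out_features = 3, 20, 30
batch_size = 9  # chosen to be divisible by num_estimators

# Random stand-ins with the shapes given in the Variables list.
weight = torch.randn(out_features, in_features)      # (H_out, H_in), shared
r_group = torch.randn(num_estimators, in_features)   # (M, H_in)
s_group = torch.randn(num_estimators, out_features)  # (M, H_out)
bias = torch.randn(num_estimators, out_features)     # (M, H_out)

x = torch.randn(batch_size, in_features)

# Assumption: rows 0..2 go to estimator 0, rows 3..5 to estimator 1, and so on.
per_estimator = batch_size // num_estimators
r_hat = r_group.repeat_interleave(per_estimator, dim=0)  # (N, H_in)
s_hat = s_group.repeat_interleave(per_estimator, dim=0)  # (N, H_out)
b_hat = bias.repeat_interleave(per_estimator, dim=0)     # (N, H_out)

# y = (x ∘ r_hat) W^T ∘ s_hat + b_hat
y = ((x * r_hat) @ weight.T) * s_hat + b_hat

# Equivalent view for estimator 0: a dense layer whose weight is the shared
# weight modulated elementwise by the rank-one matrix s_0 r_0^T.
w_0 = weight * torch.outer(s_group[0], r_group[0])
y_0 = x[:per_estimator] @ w_0.T + bias[0]
```

The last two lines show why the method is cheap: each estimator behaves like a dense layer with its own weight matrix, but only the vectors \(r_{i}\), \(s_{i}\) and the bias are stored per estimator.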

Warning

Ensure that the batch size is divisible by num_estimators when calling forward(), so that each estimator receives the same number of examples. In a BatchEnsemble architecture, the input is typically repeated num_estimators times along the batch dimension. A batch size that is not a multiple of num_estimators may lead to unexpected results.

To simplify batch handling, wrap your model with torch_uncertainty.wrappers.BatchEnsemble, which automatically repeats the batch before passing it through the network.
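
The manual repetition described in this warning can be sketched in plain PyTorch as follows. This is not the wrapper's implementation; it assumes that consecutive blocks of rows are assigned to the same estimator, and the `out` tensor is a random stand-in for a model output.

```python
import torch

num_estimators = 3
x = torch.randn(8, 20)  # 8 examples; 8 is not divisible by 3

# Tile the whole batch once per estimator along the batch dimension.
# Rows 0..7 form the first copy, rows 8..15 the second, rows 16..23 the third.
x_rep = x.repeat(num_estimators, 1)  # shape (24, 20), divisible by 3

# Stand-in for the output of a network applied to x_rep.
out = torch.randn(x_rep.size(0), 30)

# Regroup the predictions per estimator and average over the ensemble.
per_estimator = out.reshape(num_estimators, -1, out.size(-1))  # (3, 8, 30)
mean_prediction = per_estimator.mean(dim=0)  # (8, 30)
```

Repeating the batch this way guarantees divisibility by num_estimators regardless of the original batch size, since the repeated batch has num_estimators times as many rows.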

Examples

>>> import torch
>>> from torch_uncertainty.layers import BatchLinear
>>> # With three estimators
>>> m = BatchLinear(20, 30, 3)
>>> input = torch.randn(8, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([8, 30])

classmethod from_linear(linear, num_estimators)

Create a BatchEnsemble-style Linear layer from an existing Linear layer.

Parameters:
  • linear (Linear) – The Linear layer to convert.

  • num_estimators (int) – Number of ensemble members.

Returns:

The converted BatchEnsemble-style Linear layer.

Return type:

BatchLinear

Example

>>> from torch import nn
>>> from torch_uncertainty.layers import BatchLinear
>>> linear = nn.Linear(20, 30)
>>> be_linear = BatchLinear.from_linear(linear, num_estimators=3)