BatchLinear¶
- class torch_uncertainty.layers.BatchLinear(in_features, out_features, num_estimators, bias=True, device=None, dtype=None)[source]¶
BatchEnsemble-style Linear layer.
Apply a linear transformation using BatchEnsemble method to the incoming data.
\[y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b}\]
- Parameters:
in_features (int) – Number of input features.
out_features (int) – Number of output features.
num_estimators (int) – Number of estimators in the ensemble, referred to as \(M\).
bias (bool, optional) – if True, adds a learnable bias to the output. Defaults to True.
device (Any, optional) – device to use for the parameters and buffers of this module. Defaults to None.
dtype (Any, optional) – data type to use for the parameters and buffers of this module. Defaults to None.
- Reference:
Introduced by the paper BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning. This implementation of a Linear BatchEnsemble layer in PyTorch is heavily inspired by its official TensorFlow implementation.
- Variables:
weight – the learnable weights (\(W\)) of shape \((H_{out}, H_{in})\) shared between the estimators. The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{H_{in}}\).
r_group (torch.Tensor) – the learnable matrix of shape \((M, H_{in})\) where each row consists of the vector \(r_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).
s_group (torch.Tensor) – the learnable matrix of shape \((M, H_{out})\) where each row consists of the vector \(s_{i}\) corresponding to the \(i^{th}\) ensemble member. The values are initialized from \(\mathcal{N}(1.0, 0.5)\).
bias (torch.Tensor | None) – the learnable bias (\(b\)) of shape \((M, H_{out})\) where each row corresponds to the bias of the \(i^{th}\) ensemble member. If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{H_{in}}\).
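The point of the shared weight plus the per-member vectors is that each ensemble member's effective weight is a rank-1 modulation of \(W\): from the formula above, \(W_i = W \circ s_i r_i^{T}\). The following is a minimal sketch of that identity in plain PyTorch (illustrative only, not the library's code; all tensor names are local stand-ins for the variables above):

```python
import torch

# Hedged sketch: check that (x ∘ r_i) W^T ∘ s_i equals x W_i^T,
# where W_i = W ∘ outer(s_i, r_i) is the i-th member's effective weight.
torch.manual_seed(0)
H_in, H_out, M = 4, 5, 3
W = torch.randn(H_out, H_in)     # shared weight, shape (H_out, H_in)
r_group = torch.randn(M, H_in)   # per-member input scalers
s_group = torch.randn(M, H_out)  # per-member output scalers
x = torch.randn(1, H_in)

i = 1
left = (x * r_group[i]) @ W.T * s_group[i]           # BatchEnsemble form
W_i = W * torch.outer(s_group[i], r_group[i])        # rank-1 modulated weight
right = x @ W_i.T                                    # plain linear form
print(torch.allclose(left, right))  # True
```

This is why BatchEnsemble stores only \(M(H_{in} + H_{out})\) extra parameters per layer instead of \(M\) full weight matrices.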
- Shape:
Input: \((N, H_{in})\) where \(N\) is the batch size and \(H_{in} = \text{in_features}\).
Output: \((N, H_{out})\) where \(H_{out} = \text{out_features}\).
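To make the shapes concrete, here is a hedged sketch of the forward transform written with plain tensor ops, assuming the batch is laid out as contiguous groups of equal size, one group per estimator (the function name `batch_linear` and this grouping assumption are illustrative, not the library's implementation):

```python
import torch

def batch_linear(x, weight, r_group, s_group, bias=None, num_estimators=1):
    """Sketch of y = (x ∘ r) W^T ∘ s + b with per-estimator row scalers.

    Assumes x has shape (N, H_in) with N divisible by num_estimators,
    grouped contiguously by estimator.
    """
    reps = x.size(0) // num_estimators
    r = r_group.repeat_interleave(reps, dim=0)  # (N, H_in)
    s = s_group.repeat_interleave(reps, dim=0)  # (N, H_out)
    y = (x * r) @ weight.T * s
    if bias is not None:
        y = y + bias.repeat_interleave(reps, dim=0)
    return y

torch.manual_seed(0)
x = torch.randn(6, 20)   # 6 samples, 2 per estimator
W = torch.randn(30, 20)
r = torch.randn(3, 20)
s = torch.randn(3, 30)
b = torch.randn(3, 30)
out = batch_linear(x, W, r, s, b, num_estimators=3)
print(out.shape)  # torch.Size([6, 30])
```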
Warning
Make sure that num_estimators divides out_features when calling forward().
Examples
>>> # With three estimators
>>> m = BatchLinear(20, 30, 3)
>>> input = torch.randn(8, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([8, 30])