PackedLinear
- class torch_uncertainty.layers.PackedLinear(in_features, out_features, alpha, num_estimators, gamma=1, bias=True, first=False, last=False, implementation='legacy', rearrange=True, device=None, dtype=None)
Packed-Ensembles-style Linear layer.
This layer computes a fully-connected operation for a given number of estimators (num_estimators).
- Parameters:
in_features (int) – Number of input features of the linear layer.
out_features (int) – Number of output features produced by the linear layer.
alpha (float) – The width multiplier of the linear layer.
num_estimators (int) – The number of estimators grouped in the layer.
gamma (int, optional) – Multiplier on the number of groups in the layer (see the explanation note). Defaults to 1.
bias (bool, optional) – If True, adds a learnable bias to the output. Defaults to True.
first (bool, optional) – Whether this is the first layer of the network. Defaults to False.
last (bool, optional) – Whether this is the last layer of the network. Defaults to False.
implementation (str, optional) – The implementation to use. Defaults to "legacy".
rearrange (bool, optional) – Whether to rearrange the inputs and outputs for compatibility with the previous and following layers. Defaults to True.
device (torch.device, optional) – The device to use for the layer’s parameters. Defaults to None.
dtype (torch.dtype, optional) – The dtype to use for the layer’s parameters. Defaults to None.
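Conceptually, the layer splits one fully-connected operation into groups: each group applies its own dense weights to its own contiguous slice of the input features, which is equivalent to a block-diagonal weight matrix. The pure-Python sketch below illustrates only that grouped computation on plain lists. grouped_linear is a hypothetical helper, not part of torch_uncertainty, and it ignores alpha, first/last, and the rearrange step.

```python
from typing import List

Matrix = List[List[float]]


def grouped_linear(x: List[float], weights: List[Matrix], biases: List[List[float]]) -> List[float]:
    """Hypothetical sketch: apply one dense block per group to its own slice of x.

    weights[g] has shape (out_per_group, in_per_group). Concatenating the
    per-group outputs is equivalent to multiplying by a block-diagonal matrix.
    """
    num_groups = len(weights)
    if len(x) % num_groups != 0:
        raise ValueError("len(x) must be divisible by the number of groups")
    in_per_group = len(x) // num_groups
    out: List[float] = []
    for g in range(num_groups):
        # Group g only ever sees this slice of the input features.
        chunk = x[g * in_per_group:(g + 1) * in_per_group]
        for row, b in zip(weights[g], biases[g]):
            if len(row) != in_per_group:
                raise ValueError("weight row width must match the group's input slice")
            out.append(sum(w * xi for w, xi in zip(row, chunk)) + b)
    return out


# Two estimators, 4 input features (2 per estimator), 1 output feature each.
x = [1.0, 2.0, 3.0, 4.0]
weights = [[[1.0, 1.0]], [[1.0, -1.0]]]
biases = [[0.0], [10.0]]
print(grouped_linear(x, weights, biases))  # [3.0, 9.0]
```

Because the estimators share no weights, changing the first two inputs leaves the second estimator's output unchanged.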
- Explanation Note:
Increasing alpha increases the number of channels of the ensemble, which increases its representation capacity. Increasing gamma increases the number of groups in the network and therefore reduces the number of parameters.
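The trade-off between alpha and gamma can be made concrete with a rough parameter count. The sketch below assumes a hidden layer (neither first nor last) whose input and output feature counts are both widened by alpha, with a block-diagonal weight split into num_estimators × gamma groups. packed_linear_param_count is a hypothetical helper, and the exact widening rules of torch_uncertainty may differ.

```python
def packed_linear_param_count(in_features: int, out_features: int, alpha: float,
                              num_estimators: int, gamma: int = 1, bias: bool = True) -> int:
    """Hypothetical estimate for a hidden Packed linear layer.

    Assumes both feature counts are multiplied by alpha and that the weight is
    block-diagonal with num_estimators * gamma groups.
    """
    ext_in = int(in_features * alpha)
    ext_out = int(out_features * alpha)
    groups = num_estimators * gamma
    if ext_in % groups or ext_out % groups:
        raise ValueError("widened feature counts must be divisible by num_estimators * gamma")
    # Each group owns an (ext_out / groups) x (ext_in / groups) dense block.
    weights = groups * (ext_in // groups) * (ext_out // groups)
    return weights + (ext_out if bias else 0)


print(packed_linear_param_count(64, 64, alpha=2, num_estimators=4))           # 4224
print(packed_linear_param_count(64, 64, alpha=4, num_estimators=4))           # 16640
print(packed_linear_param_count(64, 64, alpha=2, num_estimators=4, gamma=2))  # 2176
```

Under these assumptions the weight count scales with alpha² / (num_estimators × gamma): doubling alpha roughly quadruples the weights, while doubling gamma halves them.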
Note
Each ensemble member only sees \(\frac{\text{in\_features}}{\text{num\_estimators}}\) features, so when using gamma you should make sure that in_features and out_features are both divisible by num_estimators \(\times\) gamma. If they are not, the numbers of input and output features are changed to comply with this constraint.
Note
The input should be of shape (batch_size, in_features, 1, 1). The rearrange operation that is often necessary is performed by default (rearrange=True).
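To make the divisibility constraint from the first note concrete, the sketch below checks feature counts against num_estimators × gamma and, when they are not divisible, rounds them up to the next valid multiple. The round-up rule and the helper names (round_up_to_multiple, adjust_features) are hypothetical assumptions for illustration; the adjustment that torch_uncertainty actually applies may differ.

```python
from typing import Tuple


def round_up_to_multiple(n: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple  # ceiling division using integers only


def adjust_features(in_features: int, out_features: int,
                    num_estimators: int, gamma: int = 1) -> Tuple[int, int]:
    """Hypothetical rule: round both counts up to a multiple of num_estimators * gamma."""
    groups = num_estimators * gamma
    return round_up_to_multiple(in_features, groups), round_up_to_multiple(out_features, groups)


in_adj, out_adj = adjust_features(10, 30, num_estimators=4, gamma=2)
print(in_adj, out_adj)  # 16 32
print(in_adj // 4)      # 4 input features seen by each of the 4 estimators
```

Counts that already satisfy the constraint (for example 64 and 64 with four estimators) are returned unchanged.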