PackedMultiheadAttention¶
- class torch_uncertainty.layers.PackedMultiheadAttention(embed_dim, num_heads, alpha, num_estimators, gamma=1, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, first=False, last=False, device=None, dtype=None)[source]¶
Packed-Ensembles-style MultiheadAttention layer.
- Parameters:
embed_dim (int) – Size of the embedding dimension.
num_heads (int) – Number of parallel attention heads.
alpha (float) – The width multiplier of the embedding dimension.
num_estimators (int) – The number of estimators packed in the layer.
gamma (int, optional) – Number of groups within each estimator. Defaults to 1.
dropout (float, optional) – Dropout probability on attn_output_weights. Defaults to 0.0 (no dropout).
bias (bool, optional) – If specified, adds bias to input / output projection layers. Defaults to True.
add_bias_kv (bool, optional) – If specified, adds bias to the key and value sequences at dim=0. Defaults to False.
add_zero_attn (bool, optional) – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Defaults to False.
kdim (int | None, optional) – Total number of features for keys. Defaults to None (uses kdim=embed_dim).
vdim (int | None, optional) – Total number of features for values. Defaults to None (uses vdim=embed_dim).
batch_first (bool, optional) – If True, then the input and output tensors are provided as (batch, seq, feature). Defaults to False (seq, batch, feature).
first (bool, optional) – Whether this is the first layer of the network. Defaults to False.
last (bool, optional) – Whether this is the last layer of the network. Defaults to False.
device (torch.device, optional) – The device to use for the layer’s parameters. Defaults to None.
dtype (torch.dtype, optional) – The dtype to use for the layer’s parameters. Defaults to None.
- Reference:
Attention Is All You Need: Original Multihead Attention formulation.
Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting: Packed-Ensembles-style Multihead Attention formulation.
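A minimal usage sketch follows. The configuration values (embed_dim=64, num_heads=4, alpha=1, num_estimators=2) are illustrative assumptions, not recommendations; alpha=1 is chosen so that the packed embedding width equals embed_dim and the call mirrors torch.nn.MultiheadAttention.

```python
import torch

from torch_uncertainty.layers import PackedMultiheadAttention

# Illustrative values only: alpha=1 keeps the packed width equal to embed_dim.
embed_dim, num_heads, alpha, num_estimators = 64, 4, 1, 2

mha = PackedMultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    alpha=alpha,
    num_estimators=num_estimators,
    batch_first=True,
)

# (B, L, E_q) because batch_first=True. In a Packed-Ensembles model the batch
# is usually repeated num_estimators times upstream of the network.
x = torch.rand(8, 16, embed_dim)

attn_output, attn_weights = mha(x, x, x)  # self-attention
# attn_output is expected to have shape (8, 16, 64);
# attn_weights is always None (need_weights=True is not supported yet).
```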
- forward(query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False)[source]¶
Computes attention outputs given query, key, and value tensors.
- Parameters:
query (Tensor) – Query embeddings of shape \((L, E_q)\) for unbatched input, \((L, B, E_q)\) when batch_first=False or \((B, L, E_q)\) when batch_first=True, where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the query embedding dimension embed_dim.
key (Tensor) – Key embeddings of shape \((S, E_k)\) for unbatched input, \((S, B, E_k)\) when batch_first=False or \((B, S, E_k)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size and \(E_k\) is the key embedding dimension kdim.
value (Tensor) – Value embeddings of shape \((S, E_v)\) for unbatched input, \((S, B, E_v)\) when batch_first=False or \((B, S, E_v)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size and \(E_v\) is the value embedding dimension vdim.
key_padding_mask (Tensor | None, optional) – If specified, a mask of shape \((B, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, the shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value. Defaults to None.
need_weights (bool, optional) – If specified, returns attn_output_weights in addition to attn_outputs. Set need_weights=False to use the optimized scaled_dot_product_attention and achieve the best performance for MHA. Defaults to False.
attn_mask (Tensor | None, optional) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape \((L, S)\) or \((B \times \text{num_heads}, L, S)\), where \(B\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend to. For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are provided, their types should match. Defaults to None.
average_attn_weights (bool, optional) – If True, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Defaults to True.
is_causal (bool, optional) – If specified, applies a causal mask as the attention mask. Defaults to False.
Warning
need_weights=True is not supported yet; consequently, average_attn_weights has no effect.
- Returns:
attn_output (Tensor): The output tensor of shape \((L, E_q)\), \((L, B, E_q)\) or \((B, L, E_q)\), where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the embedding dimension embed_dim.
attn_output_weights (None): Always None, as we do not support need_weights=True yet.
- Return type:
tuple[Tensor, None]
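The sketch below exercises key_padding_mask and attn_mask with boolean masks, following the parameter descriptions above. The configuration and tensor sizes are illustrative assumptions only, reusing the alpha=1 setup from the earlier example.

```python
import torch

from torch_uncertainty.layers import PackedMultiheadAttention

B, L, S, E = 4, 10, 12, 64  # illustrative sizes only
mha = PackedMultiheadAttention(
    embed_dim=E, num_heads=4, alpha=1, num_estimators=2, batch_first=True
)

query = torch.rand(B, L, E)
key = value = torch.rand(B, S, E)

# Boolean key padding mask of shape (B, S): True marks source positions to ignore.
key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # treat the last two key positions as padding

# Boolean attention mask of shape (L, S): True forbids attending to that position.
attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)

attn_output, _ = mha(
    query,
    key,
    value,
    key_padding_mask=key_padding_mask,
    attn_mask=attn_mask,
)
# attn_output is expected to have shape (B, L, E) == (4, 10, 64).
```

Both masks are boolean here because, as noted above, attn_mask and key_padding_mask must have matching types when provided together.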