PackedMultiheadAttention
- class torch_uncertainty.layers.PackedMultiheadAttention(embed_dim, num_heads, alpha, num_estimators, gamma=1, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, first=False, last=False, device=None, dtype=None)[source]
Packed-Ensembles-style MultiheadAttention layer.
- Parameters:
embed_dim (int) – Size of the embedding dimension.
num_heads (int) – Number of parallel attention heads.
alpha (float) – The width multiplier of the embedding dimension.
num_estimators (int) – The number of estimators packed in the layer.
gamma (int, optional) – Defaults to 1.
dropout (float, optional) – Dropout probability on attn_output_weights. Defaults to 0.0 (no dropout).
bias (bool, optional) – If specified, adds bias to input / output projection layers. Defaults to True.
add_bias_kv (bool, optional) – If specified, adds bias to the key and value sequences at dim=0. Defaults to False.
add_zero_attn (bool, optional) – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Defaults to False.
kdim (int | None, optional) – Total number of features for keys. Defaults to None (uses kdim=embed_dim).
vdim (int | None, optional) – Total number of features for values. Defaults to None (uses vdim=embed_dim).
batch_first (bool, optional) – If True, then the input and output tensors are provided as (batch, seq, feature). Defaults to False (seq, batch, feature).
first (bool, optional) – Whether this is the first layer of the network. Defaults to False.
last (bool, optional) – Whether this is the last layer of the network. Defaults to False.
device (torch.device, optional) – The device to use for the layer’s parameters. Defaults to None.
dtype (torch.dtype, optional) – The dtype to use for the layer’s parameters. Defaults to None.
- Reference:
Attention Is All You Need: Original Multihead Attention formulation.
Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting: Packed-Ensembles-style Multihead Attention formulation.
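A minimal construction sketch, assuming torch_uncertainty is installed and that PackedMultiheadAttention is importable from torch_uncertainty.layers as documented above; the hyperparameter values are illustrative only, not prescribed by this page:

```python
import torch
from torch_uncertainty.layers import PackedMultiheadAttention

# Illustrative hyperparameters (assumed, not prescribed by the docs above):
# alpha widens the embedding dimension, num_estimators sets the ensemble size.
embed_dim, num_heads = 64, 4
alpha, num_estimators = 2, 4

mha = PackedMultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    alpha=alpha,
    num_estimators=num_estimators,
    dropout=0.1,
    batch_first=True,  # inputs/outputs as (batch, seq, feature)
    first=True,        # marks this as the first packed layer of the network
)
```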
- forward(query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False)[source]
Computes attention outputs given query, key, and value tensors.
- Parameters:
query (Tensor) – Query embeddings of shape \((L, E_q)\) for unbatched input, \((L, B, E_q)\) when batch_first=False or \((B, L, E_q)\) when batch_first=True, where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the query embedding dimension embed_dim.
key (Tensor) – Key embeddings of shape \((S, E_k)\) for unbatched input, \((S, B, E_k)\) when batch_first=False or \((B, S, E_k)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size, and \(E_k\) is the key embedding dimension kdim.
value (Tensor) – Value embeddings of shape \((S, E_v)\) for unbatched input, \((S, B, E_v)\) when batch_first=False or \((B, S, E_v)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size, and \(E_v\) is the value embedding dimension vdim.
key_padding_mask (Tensor | None, optional) – If specified, a mask of shape \((B, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value. Defaults to None.
need_weights (bool, optional) – If specified, returns attn_output_weights in addition to attn_outputs. Set need_weights=False to use the optimized scaled_dot_product_attention and achieve the best performance for MHA. Defaults to False.
attn_mask (Tensor | None, optional) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape \((L, S)\) or \((B \times \text{num\_heads}, L, S)\), where \(B\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are provided, their types should match. Defaults to None.
average_attn_weights (bool, optional) – If True, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Defaults to True.
is_causal (bool, optional) – If specified, applies a causal mask as the attention mask. Defaults to False.
Warning
need_weights=True and therefore average_attn_weights are not supported yet and thus have no effect.
- Returns:
attn_output (Tensor): The output tensor of shape \((L, E_q)\), \((L, B, E_q)\) or \((B, L, E_q)\), where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the embedding dimension embed_dim.
attn_output_weights (None): Always None, as we do not support need_weights=True yet.
- Return type:
tuple[Tensor, None]
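A hedged usage sketch of forward(), continuing the construction example above (reusing mha and embed_dim); the tensor shapes follow the docstring for batch_first=True, and the boolean causal mask is only one way to use attn_mask:

```python
import torch

B, L = 8, 16                     # batch size and target sequence length (illustrative)
x = torch.rand(B, L, embed_dim)  # (B, L, E_q) as documented for batch_first=True

# 2D boolean mask of shape (L, S): True marks positions that must NOT be attended to,
# here an upper-triangular mask emulating causal self-attention.
causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

attn_output, attn_weights = mha(
    x, x, x,                # self-attention: query = key = value
    attn_mask=causal_mask,
    need_weights=False,     # need_weights=True is not supported yet
)
print(attn_output.shape)    # (B, L, E_q) per the Returns section above
print(attn_weights)         # None: attn_output_weights is always None
```

Self-attention with query = key = value is only one use of the layer; cross-attention with distinct key and value tensors of shape \((B, S, E_k)\) and \((B, S, E_v)\) follows the same call pattern.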