PackedMultiheadAttention

class torch_uncertainty.layers.PackedMultiheadAttention(embed_dim, num_heads, alpha, num_estimators, gamma=1, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, first=False, last=False, device=None, dtype=None)[source]

Packed-Ensembles-style MultiheadAttention layer.

Parameters:
  • embed_dim (int) – Size of the embedding dimension.

  • num_heads (int) – Number of parallel attention heads.

  • alpha (float) – The width multiplier of the embedding dimension.

  • num_estimators (int) – The number of estimators packed in the layer.

  • gamma (int, optional) – The number of groups within each estimator. Defaults to 1.

  • dropout (float, optional) – Dropout probability on attn_output_weights. Defaults to 0.0 (no dropout).

  • bias (bool, optional) – If specified, adds bias to input / output projection layers. Defaults to True.

  • add_bias_kv (bool, optional) – If specified, adds bias to the key and value sequences at dim=0. Defaults to False.

  • add_zero_attn (bool, optional) – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Defaults to False.

  • kdim (int | None, optional) – Total number of features for keys. Defaults to None (uses kdim=embed_dim).

  • vdim (int | None, optional) – Total number of features for values. Defaults to None (uses vdim=embed_dim).

  • batch_first (bool, optional) – If True, then the input and output tensors are provided as (batch, seq, feature). Defaults to False (seq, batch, feature).

  • first (bool, optional) – Whether this is the first layer of the network. Defaults to False.

  • last (bool, optional) – Whether this is the last layer of the network. Defaults to False.

  • device (torch.device, optional) – The device to use for the layer’s parameters. Defaults to None.

  • dtype (torch.dtype, optional) – The dtype to use for the layer’s parameters. Defaults to None.

Reference:
  Laurent et al. Packed-Ensembles for Efficient Uncertainty Estimation. ICLR 2023.
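
Example (an illustrative construction sketch; the hyperparameter values below are arbitrary and only exercise the arguments documented above):

    >>> from torch_uncertainty.layers import PackedMultiheadAttention
    >>> mha = PackedMultiheadAttention(
    ...     embed_dim=64,
    ...     num_heads=4,
    ...     alpha=2,
    ...     num_estimators=4,
    ...     dropout=0.1,
    ...     batch_first=True,
    ... )
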
forward(query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False)[source]

Computes attention outputs given query, key, and value tensors.

Parameters:
  • query (Tensor) – Query embeddings of shape \((L, E_q)\) for unbatched input, \((L, B, E_q)\) when batch_first=False or \((B, L, E_q)\) when batch_first=True, where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the query embedding dimension embed_dim.

  • key (Tensor) – Key embeddings of shape \((S, E_k)\) for unbatched input, \((S, B, E_k)\) when batch_first=False or \((B, S, E_k)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size, and \(E_k\) is the key embedding dimension kdim.

  • value (Tensor) – Value embeddings of shape \((S, E_v)\) for unbatched input, \((S, B, E_v)\) when batch_first=False or \((B, S, E_v)\) when batch_first=True, where \(S\) is the source sequence length, \(B\) is the batch size and \(E_v\) is the value embedding dimension vdim.

  • key_padding_mask (Tensor | None, optional) – If specified, a mask of shape \((B, S)\) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be \((S)\). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value. Defaults to None.

  • need_weights (bool, optional) – If specified, returns attn_output_weights in addition to attn_outputs. Set need_weights=False to use the optimized scaled_dot_product_attention and achieve the best performance for MHA. Defaults to False.

  • attn_mask (Tensor | None, optional) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape \((L,S)\) or \((B \times \text{num_heads}, L, S)\), where \(B\) is the batch size, \(L\) is the target sequence length, and \(S\) is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend to. For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are provided, their types should match. Defaults to None.

  • average_attn_weights (bool, optional) – If True, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Defaults to True.

  • is_causal (bool, optional) – If specified, applies a causal mask as the attention mask. Defaults to False.

Warning

need_weights=True is not supported yet; consequently, need_weights and average_attn_weights currently have no effect.

Returns:

  • attn_output (Tensor): The output tensor of shape \((L, E_q)\), \((L, B, E_q)\) or \((B, L, E_q)\) where \(L\) is the target sequence length, \(B\) is the batch size, and \(E_q\) is the embedding dimension embed_dim.

  • attn_output_weights (None): Always None, as need_weights=True is not supported yet.

Return type:

tuple[Tensor, None]
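
Example (an illustrative forward-pass sketch; the shapes follow the documented \((B, L, E_q)\) convention for batch_first=True and assume that the layer consumes and produces tensors of the base embed_dim; depending on alpha, num_estimators, first, and last, a concrete configuration may expect a rescaled embedding dimension instead):

    >>> import torch
    >>> from torch_uncertainty.layers import PackedMultiheadAttention
    >>> mha = PackedMultiheadAttention(
    ...     embed_dim=64, num_heads=4, alpha=2, num_estimators=4, batch_first=True
    ... )
    >>> x = torch.rand(8, 16, 64)  # (batch, seq, embed_dim) self-attention input
    >>> # Binary key_padding_mask of shape (B, S): True marks key positions to ignore.
    >>> key_padding_mask = torch.zeros(8, 16, dtype=torch.bool)
    >>> key_padding_mask[:, -2:] = True  # treat the last two tokens as padding
    >>> attn_output, attn_weights = mha(
    ...     x, x, x, key_padding_mask=key_padding_mask, need_weights=False
    ... )
    >>> # attn_output has the documented (B, L, E_q) output shape.
    >>> attn_weights is None  # attn_output_weights are not returned for now
    True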