attention_utils

class Linear3D(hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None)

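The page gives only the signature of Linear3D. This NumPy sketch shows one plausible reading of it. It assumes the layer applies a single dense projection from hidden_size to num_attention_heads * size_per_head features. It then splits the result into heads in the [B, H, T, D] layout that BigBirdSparseAttention.forward documents for its inputs. The function linear3d_sketch and its weight layout are hypothetical; this is not the library's code.

```python
import numpy as np

def linear3d_sketch(x, weight, bias, num_attention_heads, size_per_head):
    """Hypothetical 3-D projection: [B, T, hidden] -> [B, H, T, D].

    Assumes weight has shape [hidden, H * D] and bias has shape [H * D];
    this mirrors the Linear3D signature but is not its implementation.
    """
    B, T, _ = x.shape
    projected = x @ weight + bias                          # [B, T, H * D]
    projected = projected.reshape(B, T, num_attention_heads, size_per_head)
    return projected.transpose(0, 2, 1, 3)                 # [B, H, T, D]

rng = np.random.default_rng(0)
hidden_size, num_heads, size_per_head = 8, 2, 4
x = rng.standard_normal((3, 5, hidden_size))
w = rng.standard_normal((hidden_size, num_heads * size_per_head))
b = np.zeros(num_heads * size_per_head)
out = linear3d_sketch(x, w, b, num_heads, size_per_head)
print(out.shape)  # (3, 2, 5, 4)
```

Putting the head axis before the sequence axis lets a single batched matmul over [B, H] compute every head's T x T score matrix.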
class Attention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
Defines the computation performed at every call. Should be overridden by all subclasses.
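Parameters
query_matrix (Tensor) – query tensor of shape [B, H, T, D]
key_matrix (Tensor) – key tensor of shape [B, H, T, D]
value_matrix (Tensor) – value tensor of shape [B, H, T, D]
d_head – size of each attention head (D)
attn_mask – optional attention mask
rand_mask_idx – optional random-attention block indices of shape [H, T//bs, bs]
query_mask – optional bool mask of shape [B, 1, T, 1]
key_mask – optional bool mask of shape [B, 1, 1, T]
dropout – optional dropout probability
Here B is the batch size, H the number of heads, T the sequence length, D the size per head, and bs the block size.
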
class DefaultAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
Defines the computation performed at every call. Takes the same arguments as Attention.forward.

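This NumPy sketch illustrates the shared forward signature with dense scaled dot-product attention. It uses the shapes documented for BigBirdSparseAttention.forward: [B, H, T, D] tensors, a [B, 1, T, 1] query mask, and a [B, 1, 1, T] key mask. That DefaultAttention computes exactly this is an assumption based on its name. Several details are choices made for this sketch: the d_head ** -0.5 scaling, the -1e9 fill value, and the hypothetical name dense_attention_sketch. The sketch also leaves out attn_mask, rand_mask_idx, and dropout.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dense_attention_sketch(query_matrix, key_matrix, value_matrix, d_head,
                           query_mask=None, key_mask=None):
    """Hypothetical dense attention over [B, H, T, D] inputs (not the library code)."""
    scores = query_matrix @ key_matrix.transpose(0, 1, 3, 2) * d_head ** -0.5  # [B, H, T, T]
    if query_mask is not None and key_mask is not None:
        # [B, 1, T, 1] AND [B, 1, 1, T] broadcasts to a [B, 1, T, T] pair mask
        pair_mask = np.logical_and(query_mask, key_mask)
        scores = np.where(pair_mask, scores, -1e9)
    weights = softmax(scores, axis=-1)
    return weights @ value_matrix                                             # [B, H, T, D]

rng = np.random.default_rng(0)
B, H, T, D = 2, 2, 4, 3
q = rng.standard_normal((B, H, T, D))
k = rng.standard_normal((B, H, T, D))
v = rng.standard_normal((B, H, T, D))
valid = np.array([[True, True, True, False],
                  [True, True, False, False]])   # False marks padding
query_mask = valid[:, None, :, None]             # [B, 1, T, 1]
key_mask = valid[:, None, None, :]               # [B, 1, 1, T]
out = dense_attention_sketch(q, k, v, D, query_mask, key_mask)
print(out.shape)  # (2, 2, 4, 3)
```

The -1e9 fill gives masked keys zero weight after the softmax. Rows for fully padded queries still stay finite instead of turning into NaN.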
class BigBirdSparseAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
query_matrix: [B, H, T, D]
key_matrix: [B, H, T, D]
value_matrix: [B, H, T, D]
query_mask: [B, 1, T, 1] bool mask
key_mask: [B, 1, 1, T] bool mask
rand_mask_idx: [H, T//bs, bs]
Attention components: Global Attention, Random Attention, Window Attention.

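The three components named above can be pictured as a block-level mask. This NumPy sketch is an illustration under stated assumptions, not the library's blocked implementation. For each query block it marks three things:

- a window of window_size blocks centred on itself (assuming an odd window_size, such as the default 3);
- the first num_global_blocks blocks as global, so they attend to and are attended by every block;
- num_rand_blocks extra key blocks drawn at random with the given seed.

The helper names and the exact placement of global and random blocks are hypothetical.

```python
import numpy as np

def bigbird_block_mask_sketch(num_blocks, window_size=3, num_global_blocks=1,
                              num_rand_blocks=1, seed=None):
    """Hypothetical block-level BigBird pattern: True where query block (row)
    may attend to key block (column). Assumes an odd window_size."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    half = window_size // 2
    for i in range(num_blocks):
        # Window attention: blocks i - half .. i + half, clipped at the edges.
        lo, hi = max(0, i - half), min(num_blocks, i + half + 1)
        mask[i, lo:hi] = True
    # Global attention: the first num_global_blocks blocks see and are seen by all.
    mask[:num_global_blocks, :] = True
    mask[:, :num_global_blocks] = True
    for i in range(num_global_blocks, num_blocks):
        # Random attention: extra key blocks not already covered for row i.
        candidates = np.flatnonzero(~mask[i])
        n = min(num_rand_blocks, candidates.size)
        if n > 0:
            mask[i, rng.choice(candidates, size=n, replace=False)] = True
    return mask

def expand_to_tokens(block_mask, block_size):
    """Turn an [nb, nb] block mask into an [nb * bs, nb * bs] token mask."""
    return block_mask.repeat(block_size, axis=0).repeat(block_size, axis=1)

block_mask = bigbird_block_mask_sketch(num_blocks=8, window_size=3,
                                       num_global_blocks=1, num_rand_blocks=1, seed=0)
token_mask = expand_to_tokens(block_mask, block_size=2)
print(block_mask.astype(int))
print(token_mask.shape)              # (16, 16)
print(int(block_mask.sum(axis=1)[4]))  # 5 = 3 window + 1 global + 1 random
```

Materializing the full T x T mask costs O(T^2) memory, which is what a block-sparse implementation is meant to avoid. The sketch only shows which query/key block pairs are connected.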
class MultiHeadAttention(embed_dim, num_heads, dropout=0.0, kdim=None, vdim=None, weight_attr=None, bias_attr=None, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None, attention_type='bigbird')
forward(query, key, value, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, cache=None)
Defines the computation performed at every call.
Parameters
query (Tensor) – query input
key (Tensor) – key input
value (Tensor) – value input
attn_mask, rand_mask_idx, query_mask, key_mask – optional masks and random-attention block indices; same names as the arguments of Attention.forward
cache – optional cache, None by default
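For orientation, here is a NumPy sketch of the usual multi-head wrapper pattern. It projects query, key, and value into num_heads heads of size embed_dim // num_heads, runs attention per head, merges the heads, and applies an output projection. It assumes [B, T, embed_dim] inputs with kdim = vdim = embed_dim and plugs in plain dense attention. In the real class, attention_type presumably selects the per-head attention. This sketch omits cache, dropout, and the mask arguments. All function and parameter names below are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(x, num_heads):
    """[B, T, E] -> [B, H, T, E // H]"""
    B, T, E = x.shape
    return x.reshape(B, T, num_heads, E // num_heads).transpose(0, 2, 1, 3)

def merge_heads(x):
    """[B, H, T, D] -> [B, T, H * D]"""
    B, H, T, D = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * D)

def multi_head_attention_sketch(query, key, value, params, num_heads):
    """Hypothetical multi-head attention with dense per-head attention."""
    q = split_heads(query @ params["wq"], num_heads)
    k = split_heads(key @ params["wk"], num_heads)
    v = split_heads(value @ params["wv"], num_heads)
    d_head = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) * d_head ** -0.5   # [B, H, T, T]
    context = softmax(scores) @ v                            # [B, H, T, D]
    return merge_heads(context) @ params["wo"]               # [B, T, E]

rng = np.random.default_rng(0)
embed_dim, num_heads = 8, 2
params = {name: rng.standard_normal((embed_dim, embed_dim)) * 0.1
          for name in ("wq", "wk", "wv", "wo")}
x = rng.standard_normal((2, 6, embed_dim))
out = multi_head_attention_sketch(x, x, x, params, num_heads)
print(out.shape)  # (2, 6, 8)
```

Swapping the dense score computation for a block-sparse one is where a switch such as attention_type would plug in.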