attention_utils

class Linear3D(hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None)
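The page gives only the constructor of Linear3D. Going by its parameter names, it appears to be a dense projection whose output is split across num_attention_heads heads of size_per_head features each. The plain-Python sketch below illustrates that idea for a single token vector, under the assumptions that hidden_size == num_attention_heads * size_per_head and that the weight is a hidden_size x hidden_size matrix. The function project_3d is hypothetical and is not PaddleNLP's implementation.

```python
from typing import List

Matrix = List[List[float]]


def project_3d(x: List[float], weight: Matrix, bias: List[float],
               num_heads: int, size_per_head: int) -> Matrix:
    """Hypothetical sketch: y = x @ W + b, then split y into [num_heads, size_per_head].

    Assumes len(x) == hidden_size == num_heads * size_per_head and that
    weight is a hidden_size x hidden_size matrix (assumed layout).
    """
    hidden_size = len(x)
    if hidden_size != num_heads * size_per_head:
        raise ValueError("hidden_size must equal num_heads * size_per_head")
    # Ordinary dense projection for one token vector.
    y = [sum(x[i] * weight[i][j] for i in range(hidden_size)) + bias[j]
         for j in range(hidden_size)]
    # Reshape the flat output into one row per attention head.
    return [y[h * size_per_head:(h + 1) * size_per_head] for h in range(num_heads)]


if __name__ == "__main__":
    identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    print(project_3d([1.0, 2.0, 3.0, 4.0], identity, [0.0] * 4,
                     num_heads=2, size_per_head=2))
```

In a batched setting the same projection would be applied at every position, turning a [B, T, hidden_size] input into [B, T, num_attention_heads, size_per_head].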
class Attention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)

Defines the computation performed at every call. Should be overridden by all subclasses.
Parameters:
*inputs (tuple) -- unpacked tuple arguments
**kwargs (dict) -- unpacked dict arguments
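The docstring for Attention.forward only states that subclasses must override it, and DefaultAttention and BigBirdSparseAttention share its constructor and forward signatures. The sketch below illustrates that contract in plain Python: a base class that stores the shared hyperparameters and leaves forward unimplemented, plus a toy subclass that overrides it. AttentionBase and MeanAttention are hypothetical names, and raising NotImplementedError in the base is an assumed convention rather than documented behaviour.

```python
from typing import List

Matrix = List[List[float]]


class AttentionBase:
    """Hypothetical stand-in for the documented Attention base class."""

    def __init__(self, num_heads=1, block_size=1, window_size=3,
                 num_global_blocks=1, num_rand_blocks=1, seed=None):
        # Shared hyperparameters, stored so every subclass sees the same config.
        self.num_heads = num_heads
        self.block_size = block_size
        self.window_size = window_size
        self.num_global_blocks = num_global_blocks
        self.num_rand_blocks = num_rand_blocks
        self.seed = seed

    def forward(self, query_matrix: Matrix, key_matrix: Matrix,
                value_matrix: Matrix, d_head: int, **kwargs) -> Matrix:
        # The base class only defines the interface (assumed behaviour).
        # Optional arguments such as attn_mask are taken via **kwargs for brevity.
        raise NotImplementedError("subclasses must override forward()")


class MeanAttention(AttentionBase):
    """Toy subclass: every query position receives the mean of all value vectors."""

    def forward(self, query_matrix, key_matrix, value_matrix, d_head, **kwargs):
        n = len(value_matrix)
        mean = [sum(col) / n for col in zip(*value_matrix)]
        return [list(mean) for _ in query_matrix]
```

Because every subclass accepts the same constructor and forward arguments, calling code can swap one implementation for another without changing its call sites.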
class DefaultAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)

Defines the computation performed at every call. Should be overridden by all subclasses.
Parameters:
*inputs (tuple) -- unpacked tuple arguments
**kwargs (dict) -- unpacked dict arguments
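DefaultAttention takes the same arguments as the sparse variant, and its name suggests ordinary dense attention. Assuming it computes standard scaled dot-product attention, softmax(Q K^T / sqrt(d_head)) V, with masked key positions excluded, a single-head, single-batch sketch in plain Python could look like this. The function dense_attention, the per-key boolean mask convention, and the -1e9 fill value for masked scores are illustrative assumptions, not PaddleNLP code; attn_mask and dropout are omitted.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def dense_attention(q: Matrix, k: Matrix, v: Matrix, d_head: int,
                    key_mask: Optional[List[bool]] = None) -> Matrix:
    """Scaled dot-product attention for one head: softmax(q k^T / sqrt(d_head)) v.

    key_mask[j] == False excludes key position j (assumed convention).
    """
    scale = 1.0 / math.sqrt(d_head)
    out = []
    for qi in q:
        # Raw scores of this query against every key position.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if key_mask is not None:
            # Push masked positions to a large negative value before softmax.
            scores = [s if keep else -1e9 for s, keep in zip(scores, key_mask)]
        # Numerically stable softmax over key positions.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    eye = [[1.0, 0.0], [0.0, 1.0]]
    print(dense_attention([[1.0, 0.0]], eye, eye, d_head=2, key_mask=[True, False]))
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged while avoiding overflow for large scores.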
class BigBirdSparseAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)

query_matrix: [B, H, T, D]
key_matrix: [B, H, T, D]
value_matrix: [B, H, T, D]
query_mask: [B, 1, T, 1] bool mask
key_mask: [B, 1, 1, T] bool mask
rand_mask_idx: [H, T//bs, bs]

Here B is the batch size, H the number of heads, T the sequence length, D the per-head dimension, and bs the block size (block_size).

The attention combines three parts: global attention, random attention, and window attention.
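The docstring names three components, global, random, and window attention, applied over blocks of bs = block_size tokens. The sketch below builds a block-level boolean layout under common BigBird conventions, all of which are assumptions here: the first num_global_blocks blocks attend to, and are attended by, every block; each block attends to the window_size blocks centred on itself; and each non-global query block gets num_rand_blocks extra random key blocks drawn with seed. The function bigbird_block_layout is hypothetical, and PaddleNLP's actual layout and the exact contents of rand_mask_idx may differ.

```python
import random
from typing import List, Optional


def bigbird_block_layout(num_blocks: int, window_size: int = 3,
                         num_global_blocks: int = 1, num_rand_blocks: int = 1,
                         seed: Optional[int] = None) -> List[List[bool]]:
    """Return layout[i][j] == True if query block i may attend to key block j.

    Hypothetical BigBird-style sketch; global blocks are assumed to be the
    first num_global_blocks blocks.
    """
    rng = random.Random(seed)
    half = window_size // 2
    layout = [[False] * num_blocks for _ in range(num_blocks)]
    for i in range(num_blocks):
        for j in range(num_blocks):
            # Global attention: global rows see everything, global columns are seen by all.
            is_global = i < num_global_blocks or j < num_global_blocks
            # Window attention: neighbouring blocks within half a window.
            in_window = abs(i - j) <= half
            layout[i][j] = is_global or in_window
    # Random attention: extra key blocks for each non-global query block.
    for i in range(num_global_blocks, num_blocks):
        candidates = [j for j in range(num_blocks) if not layout[i][j]]
        for j in rng.sample(candidates, min(num_rand_blocks, len(candidates))):
            layout[i][j] = True
    return layout


if __name__ == "__main__":
    for row in bigbird_block_layout(6, seed=0):
        print("".join("x" if keep else "." for keep in row))
```

To expand the layout to token level, tokens t and s may interact when layout[t // block_size][s // block_size] is True. This sketch shares one layout across all heads, whereas the leading H dimension of rand_mask_idx suggests the documented class draws random blocks per head.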