attention_utils
- class Linear3D(hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None)[source]
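Judging by its parameters, Linear3D projects hidden states and splits the result into per-head slices. A minimal NumPy sketch of that idea (the function name, sizes, and weight layout here are illustrative assumptions, not the library's implementation):

```python
import numpy as np

def linear_3d(x, weight, bias, num_heads, size_per_head):
    """Project [B, T, hidden_size] activations, then split the output
    into per-head slices of shape [B, num_heads, T, size_per_head]."""
    B, T, _ = x.shape
    out = x @ weight + bias                  # [B, T, num_heads * size_per_head]
    out = out.reshape(B, T, num_heads, size_per_head)
    return out.transpose(0, 2, 1, 3)         # [B, H, T, D]

# Hypothetical sizes for illustration.
B, T, hidden, H, D = 2, 8, 16, 4, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((B, T, hidden))
w = rng.standard_normal((hidden, H * D))
b = np.zeros(H * D)
heads = linear_3d(x, w, b, H, D)
print(heads.shape)  # (2, 4, 8, 4)
```

The [B, H, T, D] layout matches the shapes documented for the attention forward methods below.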
- class Attention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]
- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
*inputs (tuple) -- unpacked tuple arguments
**kwargs (dict) -- unpacked dict arguments
- class DefaultAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]
- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
*inputs (tuple) -- unpacked tuple arguments
**kwargs (dict) -- unpacked dict arguments
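DefaultAttention presumably computes ordinary dense scaled dot-product attention over the [B, H, T, D] tensors. A self-contained NumPy sketch of that computation (a plausible reading of the signature, not the library's code; `attn_mask` is assumed to be a boolean mask where True means "may attend"):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, d_head, attn_mask=None):
    """Dense attention: softmax(Q K^T / sqrt(d_head)) V over [B, H, T, D]."""
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [B, H, T, T]
    if attn_mask is not None:
        scores = np.where(attn_mask, scores, -1e9)          # block masked keys
    scores -= scores.max(axis=-1, keepdims=True)            # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                      # [B, H, T, D]

# Tiny example with hypothetical sizes.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 8, 4)) for _ in range(3))
out = scaled_dot_product_attention(q, k, v, d_head=4)
print(out.shape)  # (2, 4, 8, 4)
```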
- class BigBirdSparseAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]
- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]
Tensor shapes:
- query_matrix: [B, H, T, D]
- key_matrix: [B, H, T, D]
- value_matrix: [B, H, T, D]
- query_mask: [B, 1, T, 1] boolean mask
- key_mask: [B, 1, 1, T] boolean mask
- rand_mask_idx: [H, T//bs, bs]
The output combines three sparse attention patterns: global attention, window attention, and random attention.
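The global/window/random decomposition can be illustrated at the block level: the sequence is split into blocks of `block_size` tokens, and each query block attends only to a sparse set of key blocks. The sketch below builds such a block-level pattern with NumPy. It is a simplified illustration of the BigBird idea under stated assumptions (one shared random pattern instead of the per-head `rand_mask_idx` the real forward consumes, and global blocks taken from the front of the sequence):

```python
import numpy as np

def bigbird_block_mask(num_blocks, window_size=3, num_global_blocks=1,
                       num_rand_blocks=1, seed=0):
    """Block-level pattern: entry [i, j] is True iff query block i may
    attend to key block j, combining window, random, and global parts."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    half = window_size // 2
    for i in range(num_blocks):
        # Window attention: each block sees its neighbouring blocks.
        lo, hi = max(0, i - half), min(num_blocks, i + half + 1)
        mask[i, lo:hi] = True
        # Random attention: a few randomly chosen key blocks.
        mask[i, rng.choice(num_blocks, size=num_rand_blocks, replace=False)] = True
    # Global attention: leading blocks attend everywhere and are seen by all.
    mask[:num_global_blocks, :] = True
    mask[:, :num_global_blocks] = True
    return mask

m = bigbird_block_mask(8)
print(m.shape)  # (8, 8)
```

Applying such a mask keeps per-query cost proportional to the number of selected blocks rather than the full sequence length, which is the point of the sparse variant over DefaultAttention.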