attention_utils
- class Linear3D(hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None)
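Linear3D's arguments suggest it projects each hidden vector of size hidden_size directly into num_attention_heads vectors of size size_per_head, in one step. The following is a minimal plain-Python sketch of such a projection. It assumes a weight of shape [hidden_size, num_attention_heads, size_per_head] and a bias of shape [num_attention_heads, size_per_head]. Both shapes are assumptions, not read from the source, and `linear3d` is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]


def linear3d(x: List[float], weight: List[Matrix], bias: Matrix) -> Matrix:
    """Project one hidden vector x (length hidden_size) to [num_heads][size_per_head].

    Assumed (hypothetical) shapes, for illustration only:
      weight: [hidden_size][num_heads][size_per_head]
      bias:   [num_heads][size_per_head]
    """
    hidden_size = len(x)
    num_heads = len(bias)
    size_per_head = len(bias[0])
    # Start from the bias, then accumulate x[d] * weight[d][h][k].
    out = [[bias[h][k] for k in range(size_per_head)] for h in range(num_heads)]
    for d in range(hidden_size):
        for h in range(num_heads):
            for k in range(size_per_head):
                out[h][k] += x[d] * weight[d][h][k]
    return out


# hidden_size = 2, num_heads = 2, size_per_head = 1
w = [[[1.0], [0.0]],   # contribution of x[0] to head 0 and head 1
     [[0.0], [1.0]]]   # contribution of x[1] to head 0 and head 1
b = [[0.5], [-0.5]]
print(linear3d([2.0, 3.0], w, b))  # [[2.5], [2.5]]
```

Producing the per-head layout directly avoids a separate reshape after an ordinary 2D linear layer.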
- class Attention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
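- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
query_matrix (Tensor) – query tensor of shape [B, H, T, D]
key_matrix (Tensor) – key tensor of shape [B, H, T, D]
value_matrix (Tensor) – value tensor of shape [B, H, T, D]
d_head (int) – size of each attention head (D)
attn_mask (Tensor, optional) – attention mask. Defaults to None.
rand_mask_idx (Tensor, optional) – indices of the random attention blocks, of shape [H, T//bs, bs]. Defaults to None.
query_mask (Tensor, optional) – bool mask of shape [B, 1, T, 1]. Defaults to None.
key_mask (Tensor, optional) – bool mask of shape [B, 1, 1, T]. Defaults to None.
dropout (optional) – dropout applied to the attention weights. Defaults to None.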
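The forward signature takes a per-query query_mask and a per-key key_mask. The BigBirdSparseAttention entry gives their shapes as [B, 1, T, 1] and [B, 1, 1, T], so broadcasting one against the other produces a [B, 1, T, T] pairwise mask. Below is a minimal sketch for a single batch element. It assumes True marks a valid (non-padding) token, which the source does not state, and `pairwise_mask` is a hypothetical helper.

```python
from typing import List


def pairwise_mask(query_mask: List[bool], key_mask: List[bool]) -> List[List[bool]]:
    """Combine a per-query mask (the [T, 1] slice) with a per-key mask (the [1, T] slice).

    Mimics broadcasting [B, 1, T, 1] against [B, 1, 1, T] for one batch element:
    entry [i][j] is True only when query i and key j are both valid.
    """
    return [[q and k for k in key_mask] for q in query_mask]


# Sequence of length 4 whose last position is padding.
valid = [True, True, True, False]
for row in pairwise_mask(valid, valid):
    print(row)
```

Keeping the two masks separate until they are used avoids storing a full T x T mask per example.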
- class DefaultAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
Implements Attention.forward for the default attention type, as opposed to BigBirdSparseAttention. Takes the same parameters as Attention.forward.
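The name suggests that DefaultAttention computes ordinary dense scaled dot-product attention, with every query attending to every unmasked key. That is an assumption, not something this page confirms. A minimal single-head sketch in plain Python follows. The helper `dense_attention` is hypothetical, and it treats key_mask as a list where True means "attend".

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def dense_attention(q: Matrix, k: Matrix, v: Matrix, d_head: int,
                    key_mask: Optional[List[bool]] = None) -> Matrix:
    """Single-head scaled dot-product attention over [T, D] matrices.

    key_mask (hypothetical convention): True = attend, False = padding.
    """
    scale = 1.0 / math.sqrt(d_head)
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if key_mask is not None:
            # Push masked keys to a large negative score so softmax ignores them.
            scores = [s if ok else -1e9 for s, ok in zip(scores, key_mask)]
        weights = softmax(scores)
        # Weighted sum of the value rows, one output column at a time.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(dense_attention(q, k, v, d_head=2, key_mask=[True, False]))  # [[1.0, 2.0], [1.0, 2.0]]
```

Because the second key is masked, both queries put all their weight on the first value row.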
- class BigBirdSparseAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)
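- forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)
Computes block-sparse attention by combining three components: global attention, random attention, and window attention. Shapes use B = batch size, H = number of heads, T = sequence length, D = size per head, and bs = block size.
- Parameters:
query_matrix – [B, H, T, D]
key_matrix – [B, H, T, D]
value_matrix – [B, H, T, D]
query_mask – [B, 1, T, 1] bool mask
key_mask – [B, 1, 1, T] bool mask
rand_mask_idx – [H, T//bs, bs]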
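The three components can be pictured at block level: split the sequence into T // block_size blocks and record which key blocks each query block attends to. The following is an illustrative sketch that mirrors the constructor's window_size, num_global_blocks, num_rand_blocks, and seed arguments. Its rules are assumptions, not taken from attention_utils: the first num_global_blocks blocks are global, the window spans window_size // 2 neighbours on each side, and random blocks are drawn uniformly from the blocks not already chosen. `bigbird_block_pattern` is a hypothetical helper.

```python
import random
from typing import List, Optional, Set


def bigbird_block_pattern(num_blocks: int, window_size: int = 3,
                          num_global_blocks: int = 1, num_rand_blocks: int = 1,
                          seed: Optional[int] = None) -> List[Set[int]]:
    """Return, for each query block, the set of key blocks it attends to.

    Illustrative assumptions (not taken from attention_utils):
      * the first num_global_blocks blocks are global: they attend to every
        block, and every block attends to them;
      * window attention covers window_size // 2 neighbours on each side;
      * random blocks are drawn uniformly from blocks not already chosen.
    """
    rng = random.Random(seed)  # a fixed seed makes the random blocks reproducible
    half = window_size // 2
    global_blocks = set(range(num_global_blocks))
    pattern = []
    for i in range(num_blocks):
        if i in global_blocks:
            pattern.append(set(range(num_blocks)))  # global rows see everything
            continue
        chosen = set(global_blocks)
        # Sliding window around the query block, clipped to the sequence.
        chosen.update(range(max(0, i - half), min(num_blocks, i + half + 1)))
        # Random blocks among those not yet selected.
        candidates = sorted(set(range(num_blocks)) - chosen)
        chosen.update(rng.sample(candidates, min(num_rand_blocks, len(candidates))))
        pattern.append(chosen)
    return pattern


# T = 32 tokens with block_size = 4 gives T // block_size = 8 blocks.
for i, blocks in enumerate(bigbird_block_pattern(8, seed=0)):
    print(i, sorted(blocks))
```

Each non-global query block touches a constant number of key blocks, so the cost grows linearly with the number of blocks rather than quadratically.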
- class MultiHeadAttention(embed_dim, num_heads, dropout=0.0, kdim=None, vdim=None, weight_attr=None, bias_attr=None, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None, attention_type='bigbird')
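- forward(query, key, value, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, cache=None)
Computes multi-head attention using the attention implementation selected by attention_type.
- Parameters:
query (Tensor) – query input
key (Tensor) – key input
value (Tensor) – value input
attn_mask (Tensor, optional) – attention mask. Defaults to None.
rand_mask_idx (Tensor, optional) – indices of the random attention blocks. Defaults to None.
query_mask (Tensor, optional) – bool query mask. Defaults to None.
key_mask (Tensor, optional) – bool key mask. Defaults to None.
cache (optional) – Defaults to None.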
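MultiHeadAttention takes embed_dim and num_heads. Conventionally, each head works on its own embed_dim // num_heads features, and the head outputs are concatenated back afterwards. The sketch below assumes that conventional contiguous split, which this page does not confirm. It uses plain Python lists in place of tensors, and `split_heads` and `merge_heads` are hypothetical helpers.

```python
from typing import List

Matrix = List[List[float]]


def split_heads(x: Matrix, num_heads: int) -> List[Matrix]:
    """[T, embed_dim] -> [num_heads, T, head_dim], one contiguous slice per head."""
    embed_dim = len(x[0])
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    return [[row[h * head_dim:(h + 1) * head_dim] for row in x]
            for h in range(num_heads)]


def merge_heads(heads: List[Matrix]) -> Matrix:
    """[num_heads, T, head_dim] -> [T, embed_dim], concatenating heads per position."""
    num_tokens = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(num_tokens)]


x = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]           # T = 2, embed_dim = 4
heads = split_heads(x, num_heads=2)   # head_dim = 2
print(heads[1])                       # [[3.0, 4.0], [7.0, 8.0]]
assert merge_heads(heads) == x        # the round trip restores the input
```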