attention_utils#

class Linear3D(hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None)[source]#
forward(input)[source]#

Defines the computation performed at every call. Should be overridden by all subclasses.

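Parameters (generic placeholders; the argument actually accepted is the one in the forward signature above, i.e. input):
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments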

class Attention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]#
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]#

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters (generic placeholders; the arguments actually accepted are those in the forward signature above):
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments

class DefaultAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]#
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]#

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters (generic placeholders; the arguments actually accepted are those in the forward signature above):
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments

class BigBirdSparseAttention(num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None)[source]#
forward(query_matrix, key_matrix, value_matrix, d_head, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, dropout=None)[source]#

Argument shapes:

  • query_matrix: [B, H, T, D]

  • key_matrix: [B, H, T, D]

  • value_matrix: [B, H, T, D]

  • query_mask: [B, 1, T, 1] bool mask

  • key_mask: [B, 1, 1, T] bool mask

  • rand_mask_idx: [H, T//bs, bs]

Attention components: Global Attention, Random Attention, Window Attention.

class MultiHeadAttention(embed_dim, num_heads, dropout=0.0, kdim=None, vdim=None, weight_attr=None, bias_attr=None, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None, attention_type='bigbird')[source]#
class Cache(k, v)#
k#

Alias for field number 0

v#

Alias for field number 1

class StaticCache(k, v)#
k#

Alias for field number 0

v#

Alias for field number 1

forward(query, key, value, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, cache=None)[source]#

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters (generic placeholders; the arguments actually accepted are those in the forward signature above):
  • *inputs (tuple) – unpacked tuple arguments

  • **kwargs (dict) – unpacked dict arguments