# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import Layer


class Registry(object):
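    """Simple name-to-class registry used to look up attention implementations."""
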
def __init__(self):
self.cls_dict = {}

    def register(self, name):
def add_item(name, cls):
self.cls_dict[name] = cls
return cls
return lambda cls: add_item(name, cls)


AttentionRegistry = Registry()
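

# Illustrative sketch (hypothetical class, not part of this module): a new
# attention variant registers itself by name and then becomes selectable via
# ``MultiHeadAttention(..., attention_type="my_attention")``:
#
#     @AttentionRegistry.register("my_attention")
#     class MyAttention(Attention):
#         def forward(self, query_matrix, key_matrix, value_matrix, d_head, **kwargs):
#             ...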


def create_bigbird_rand_mask_idx(
num_layers, query_length, key_length, num_heads, block_size, window_size, num_global_blocks, num_rand_blocks, seed
):
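    """
    For every query block of every head, randomly sample ``num_rand_blocks``
    key blocks that lie outside the sliding window and the global blocks
    (those are attended to anyway). Returns an int array whose rows are
    ``[head_idx, key_block_idx]`` pairs, ready for ``paddle.gather_nd``.
    Note: ``num_layers`` and ``seed`` are accepted for signature compatibility
    with ``create_bigbird_rand_mask_idx_list`` but are not used here; sampling
    relies on NumPy's global random state.
    """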
# TODO(zsj): need to simplify
num_key_blocks = key_length // block_size
num_query_blocks = query_length // block_size
num_window_blocks = window_size // 2
all_key_blocks_idx = np.arange(0, num_key_blocks, dtype=np.int32)
    rand_mask_idx = [[] for _ in range(num_heads)]
for query_block_idx in range(num_query_blocks):
left_key_block_idx = max(0, query_block_idx - num_window_blocks)
right_key_block_idx = min(query_block_idx + num_window_blocks, num_key_blocks - 1)
        illegal_blocks_idx = list(range(left_key_block_idx, right_key_block_idx + 1))
        illegal_blocks_idx.extend(range(num_global_blocks))
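        # Recompute the window bounds without clamping: when the window crosses
        # a sequence edge the band wraps around (see _get_band_matrix), so the
        # wrapped-in blocks must be excluded from random selection as well.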
left_key_block_idx = query_block_idx - num_window_blocks
right_key_block_idx = query_block_idx + num_window_blocks
if num_global_blocks > left_key_block_idx:
num_fill_blocks = num_global_blocks - left_key_block_idx
            illegal_blocks_idx.extend(range(num_key_blocks - num_fill_blocks, num_key_blocks))
if right_key_block_idx >= num_key_blocks:
num_fill_blocks = right_key_block_idx - num_key_blocks + 1
            illegal_blocks_idx.extend(range(num_global_blocks, num_global_blocks + num_fill_blocks))
illegal_blocks_idx = set(illegal_blocks_idx)
for i in range(num_heads):
legal_blocks_idx = []
perm_block = np.random.permutation(all_key_blocks_idx)
for j in perm_block:
if j not in illegal_blocks_idx:
legal_blocks_idx.append(j)
if len(legal_blocks_idx) == num_rand_blocks:
break
rand_mask_idx[i].append(legal_blocks_idx)
rand_mask_idx = np.stack(rand_mask_idx, axis=0)
rand_mask_idx = rand_mask_idx[:, num_global_blocks:] - num_global_blocks // 2
    # Flatten into [head_idx, key_block_idx] pairs for paddle.gather_nd.
H = rand_mask_idx.shape[0]
L = rand_mask_idx.shape[1]
R = rand_mask_idx.shape[2]
rand_mask_idx = rand_mask_idx.reshape([-1, 1])
head_idx = np.arange(H).reshape([-1, 1])
head_idx = np.pad(head_idx, ([0, 0], [0, L * R - 1]), mode="edge").reshape([-1, 1])
rand_mask_idx_list = np.concatenate([head_idx, rand_mask_idx], axis=1)
return rand_mask_idx_list


def create_bigbird_rand_mask_idx_list(
num_layers, query_length, key_length, num_heads, block_size, window_size, num_global_blocks, num_rand_blocks, seed
):
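    """
    Sample an independent random-block pattern for each of the ``num_layers``
    layers and stack them into a single array.
    """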
rand_mask_idx_list = [
create_bigbird_rand_mask_idx(
num_layers,
query_length,
key_length,
num_heads,
block_size,
window_size,
num_global_blocks,
num_rand_blocks,
seed,
)
for i in range(num_layers)
]
rand_mask_idx_list = np.stack(rand_mask_idx_list)
return rand_mask_idx_list
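

# Minimal usage sketch with hypothetical hyper-parameters (illustration only):
#
#     rand_mask_idx_list = create_bigbird_rand_mask_idx_list(
#         num_layers=12, query_length=512, key_length=512, num_heads=12,
#         block_size=16, window_size=3, num_global_blocks=2,
#         num_rand_blocks=3, seed=None)
#
# ``rand_mask_idx_list[i]`` is then passed to layer ``i`` as the
# ``rand_mask_idx`` argument of the attention forward pass.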


def _convert_param_attr_to_list(param_attr, n):
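    """
    Normalize ``param_attr`` into a list of ``n`` attributes: a list/tuple is
    converted element-wise (and must have length ``n``), a bool switches the
    default attribute on or off for all ``n`` slots, and a single attribute is
    deep-copied ``n`` times with an index suffix appended to its name.
    """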
if isinstance(param_attr, (list, tuple)):
assert len(param_attr) == n, "length of param_attr should be %d when it is a list/tuple" % n
param_attrs = []
for attr in param_attr:
if isinstance(attr, bool):
if attr:
param_attrs.append(ParamAttr._to_attr(None))
else:
param_attrs.append(False)
else:
param_attrs.append(ParamAttr._to_attr(attr))
elif isinstance(param_attr, bool):
param_attrs = []
if param_attr:
param_attrs = [ParamAttr._to_attr(None) for i in range(n)]
else:
param_attrs = [False] * n
else:
param_attrs = []
attr = ParamAttr._to_attr(param_attr)
for i in range(n):
attr_i = copy.deepcopy(attr)
if attr.name:
attr_i.name = attr_i.name + "_" + str(i)
param_attrs.append(attr_i)
return param_attrs


class Linear3D(Layer):
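    """
    Linear projection fused with the head split: a single
    [hidden_size, hidden_size] matmul whose result is reshaped and transposed
    to [batch, num_heads, seq_len, size_per_head] (the einsum
    ``abc,cde->adbe`` with the weight viewed as
    [hidden_size, num_heads, size_per_head]).
    """
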
def __init__(self, hidden_size, num_attention_heads, size_per_head, weight_attr=None, bias_attr=None):
super(Linear3D, self).__init__()
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.weight = self.create_parameter(
shape=[hidden_size, hidden_size], attr=self._weight_attr, dtype=self._dtype, is_bias=False
)
self.bias = self.create_parameter(shape=[hidden_size], attr=self._bias_attr, dtype=self._dtype, is_bias=True)
self.size_per_head = size_per_head
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size

    def forward(self, input):
        # Fused projection + head split, equivalent to
        # einsum("abc,cde->adbe", input, weight.reshape([D, H, size_per_head])):
        # returns [B, H, T, size_per_head].
B, T, D = input.shape
H = self.num_attention_heads
result = paddle.matmul(input, self.weight)
reshape_b = paddle.reshape(self.bias, [1, 1, D])
result += reshape_b
result = paddle.reshape(result, [B, T, H, -1])
result = paddle.transpose(result, [0, 2, 1, 3])
return result


class Attention(Layer):
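    """
    Base class for attention implementations. Subclasses register themselves
    in ``AttentionRegistry`` and must implement ``forward``.
    """
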
def __init__(self, num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None):
super().__init__()

    def forward(
self,
query_matrix,
key_matrix,
value_matrix,
d_head,
attn_mask=None,
rand_mask_idx=None,
query_mask=None,
key_mask=None,
dropout=None,
):
raise NotImplementedError


@AttentionRegistry.register("default_attention")
class DefaultAttention(Attention):
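    """Standard dense scaled dot-product attention over the full [T, T] score matrix."""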

    def forward(
self,
query_matrix,
key_matrix,
value_matrix,
d_head,
attn_mask=None,
rand_mask_idx=None,
query_mask=None,
key_mask=None,
dropout=None,
):
        # Scaled dot-product attention; masked positions receive a large
        # negative additive penalty so they vanish under softmax.
        product = paddle.matmul(x=query_matrix, y=key_matrix, transpose_y=True)
        product = product * (d_head**-0.5)
        product += (1 - paddle.matmul(query_mask, key_mask)) * -1e6
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if dropout:
weights = F.dropout(weights, dropout, training=self.training, mode="upscale_in_train")
out = paddle.matmul(weights, value_matrix)
return out
[docs]@AttentionRegistry.register("bigbird")
class BigBirdSparseAttention(Attention):
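    """
    Block-sparse attention in the style of BigBird (Zaheer et al., 2020):
    every query block attends to the global blocks, a sliding window of
    neighboring blocks, and a few randomly chosen blocks, reducing the memory
    cost relative to full [T, T] attention.
    """
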
def __init__(self, num_heads=1, block_size=1, window_size=3, num_global_blocks=1, num_rand_blocks=1, seed=None):
super(BigBirdSparseAttention, self).__init__(
num_heads, block_size, window_size, num_global_blocks, num_rand_blocks, seed
)
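        # Store every constructor argument (num_heads, block_size, window_size,
        # num_global_blocks, num_rand_blocks, seed) as an attribute of the same name.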
for k, v in locals().items():
if k != "self":
setattr(self, k, v)
self.num_global_blocks_back = num_global_blocks // 2
self.num_global_blocks_front = (
num_global_blocks // 2 if num_global_blocks % 2 == 0 else num_global_blocks // 2 + 1
)

    def _get_band_mask(self, blocked_query_mask, blocked_key_mask, batch_size, sequence_length):
        """
        Build the band part of the second-stage mask: the global and window
        key blocks of every non-global query block.
        Returns [B, H, L-G, bs, (G+W)*bs].
        """
GB = self.num_global_blocks_back
GF = self.num_global_blocks_front
G = self.num_global_blocks
W = self.window_size
bs = self.block_size
T = sequence_length
L = T // bs # blocked length
B = batch_size
H = self.num_heads
        # The band covers the G global and W window key blocks (random blocks
        # are masked separately in _get_rand_mask).
        # blocked_query_mask: [B, L, bs]
        # blocked_key_mask: [B, L, bs]
        # [B, L-G, bs, 1] * [B, L-G, 1, G*bs] -> [B, L-G, bs, G*bs]
temp_query_mask = paddle.reshape(blocked_query_mask[:, GF:-GB], [B, L - G, bs, 1])
temp_key_mask_front = paddle.reshape(blocked_key_mask[:, :GF], [B, 1, 1, GF * bs])
global_block_mask_front = paddle.einsum("blqd,bmdk->blqk", temp_query_mask, temp_key_mask_front)
temp_key_mask_back = paddle.reshape(blocked_key_mask[:, -GB:], [B, 1, 1, GB * bs])
global_block_mask_back = paddle.einsum("blqd,bmdk->blqk", temp_query_mask, temp_key_mask_back)
# create window block mask
key_mask_list = []
for query_block_id in range(GF, GF + W // 2):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
zero_key_mask = paddle.zeros_like(blocked_key_mask[:, -(W - (right_block_id + 1 - G)) : -GB])
temp_key_mask = paddle.concat([blocked_key_mask[:, GF : (right_block_id + 1)], zero_key_mask], axis=1)
temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
key_mask_list.append(temp_key_mask)
roll_key_mask1 = paddle.concat(key_mask_list, axis=1)
roll_key_mask1 = paddle.reshape(roll_key_mask1, [0, 0, W * bs])
key_mask_list = []
band_length = L - G - W // 2 * 2
for query_block_id in range(GF + W // 2, GF + W // 2 + W):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
key_mask_list.append(blocked_key_mask[:, left_block_id : left_block_id + band_length])
window_key_mask = paddle.concat(key_mask_list, axis=2)
window_key_mask = paddle.reshape(window_key_mask, [0, 0, W * bs])
key_mask_list = []
for query_block_id in range((L - GB) - W // 2, L - GB):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
zero_key_mask = paddle.zeros_like(blocked_key_mask[:, GF : GF + W - (L - left_block_id - GB)])
temp_key_mask = paddle.concat([zero_key_mask, blocked_key_mask[:, left_block_id:-GB]], axis=1)
temp_key_mask = paddle.unsqueeze(temp_key_mask, 1)
key_mask_list.append(temp_key_mask)
roll_key_mask2 = paddle.concat(key_mask_list, axis=1)
roll_key_mask2 = paddle.reshape(roll_key_mask2, [0, 0, W * bs])
window_key_mask = paddle.concat([roll_key_mask1, window_key_mask, roll_key_mask2], axis=1)
window_key_mask = paddle.unsqueeze(window_key_mask, axis=2)
# [B, L-G, bs, 1] * [B, L-G, 1, W*bs] -> [B, L-G, bs, W*bs]
window_block_mask = paddle.einsum("blkd,bldq->blkq", temp_query_mask, window_key_mask)
band_mask = paddle.concat([global_block_mask_front, window_block_mask, global_block_mask_back], axis=3)
band_mask = paddle.unsqueeze(band_mask, 1) # for head
band_mask = paddle.expand(band_mask, [B, H, L - G, bs, -1])
return band_mask

    def _get_band_matrix(self, blocked_matrix, B, T):
        """
        Gather the global and window key/value blocks for every non-global
        query block. Returns [B, H, L-G, (G+W)*bs, -1].
        """
# blocked_matrix: [B, H, L, bs, -1]
GB = self.num_global_blocks_back
GF = self.num_global_blocks_front
G = self.num_global_blocks
W = self.window_size
bs = self.block_size
L = T // bs # blocked length
H = self.num_heads
# get roll matrix
blocked_list = []
for query_block_id in range(GF, GF + W // 2):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
temp_blocked_matrix_list = [
blocked_matrix[:, :, 0 : (right_block_id + 1)],
blocked_matrix[:, :, -(G + W - right_block_id - 1) :],
]
temp_blocked_matrix = paddle.concat(temp_blocked_matrix_list, axis=2)
temp_blocked_matrix = paddle.unsqueeze(temp_blocked_matrix, axis=2)
blocked_list.append(temp_blocked_matrix)
# get window matrix
band_length = L - G - W // 2 * 2
band_matrix_list = []
for query_block_id in range(GF + W // 2, GF + W // 2 + W):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
band_matrix_list.append(
paddle.unsqueeze(blocked_matrix[:, :, left_block_id : left_block_id + band_length], axis=3)
)
band_matrix = paddle.concat(band_matrix_list, axis=3)
global_blocked_front_matrix = paddle.unsqueeze(blocked_matrix[:, :, :GF], axis=2)
global_blocked_front_matrix = paddle.expand(global_blocked_front_matrix, [B, H, band_length, GF, bs, -1])
global_blocked_back_matrix = paddle.unsqueeze(blocked_matrix[:, :, -GB:], axis=2)
global_blocked_back_matrix = paddle.expand(global_blocked_back_matrix, [B, H, band_length, GB, bs, -1])
band_matrix = paddle.concat([global_blocked_front_matrix, band_matrix, global_blocked_back_matrix], axis=3)
blocked_list.append(band_matrix)
for query_block_id in range(L - GB - W // 2, L - GB):
left_block_id = query_block_id - W // 2
right_block_id = query_block_id + W // 2
temp_blocked_matrix_list = [
blocked_matrix[:, :, 0 : G + W - (L - left_block_id)],
blocked_matrix[:, :, left_block_id:],
]
temp_blocked_matrix = paddle.concat(temp_blocked_matrix_list, axis=2)
temp_blocked_matrix = paddle.unsqueeze(temp_blocked_matrix, axis=2)
blocked_list.append(temp_blocked_matrix)
band_matrix = paddle.concat(blocked_list, axis=2)
band_matrix = paddle.reshape(band_matrix, [B, H, L - G, (G + W) * bs, -1])
return band_matrix

    def _get_rand_mask(self, blocked_query_mask, blocked_key_mask, rand_mask_idx, batch_size, sequence_length):
        """
        Build the mask for the randomly attended key blocks.
        Returns [B, H, L-G, bs, R*bs].
        """
        # rand_mask_idx: [H * (L-G) * R, 2] rows of [head_idx, key_block_idx]
        # blocked_query_mask: [B, L, bs]
        # blocked_key_mask: [B, L, bs]
bs = self.block_size
B = batch_size
L = sequence_length // bs
H = self.num_heads
GB = self.num_global_blocks_back
GF = self.num_global_blocks_front
R = self.num_rand_blocks
temp_block_key_mask = paddle.unsqueeze(blocked_key_mask, 1)
temp_block_key_mask = paddle.expand(temp_block_key_mask, [B, H, L, -1])
temp_block_key_mask_list = [paddle.gather_nd(temp_block_key_mask[b], rand_mask_idx) for b in range(B)]
temp_block_key_mask = paddle.concat(temp_block_key_mask_list, 0)
temp_block_key_mask = paddle.reshape(
temp_block_key_mask, [B, temp_block_key_mask.shape[0] // B // (L - GF - GB) // R, L - GF - GB, -1]
)
rand_mask = paddle.einsum("blq,bhlk->bhlqk", blocked_query_mask[:, GF:-GB], temp_block_key_mask)
return rand_mask

    def _gather_random_key_value(self, blocked_matrix, rand_mask_idx, B, T):
        """
        Gather the random key/value blocks for every head.
        Returns [B, H, L-G, R*bs, -1].
        """
        # blocked_matrix: [B, H, L, bs, -1]
        # rand_mask_idx: [H * (L-G) * R, 2] rows of [head_idx, key_block_idx]
G = self.num_global_blocks
H = self.num_heads
bs = self.block_size
L = T // bs
R = self.num_rand_blocks
gathered_matrix = paddle.concat(
[paddle.gather_nd(blocked_matrix[b, :], rand_mask_idx) for b in range(B)], axis=0
)
gathered_matrix = paddle.reshape(gathered_matrix, [B, H, L - G, R * bs, -1])
return gathered_matrix

    def _get_global_out(self, query_matrix, key_matrix, value_matrix, key_mask, d_head, dropout, is_front=True):
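        """
        Dense attention for the global query blocks: the first GF (or, when
        ``is_front`` is False, the last GB) query blocks attend to every key
        position. Returns [B, H, GF*bs, D] (or [B, H, GB*bs, D]). Note: the
        ``dropout`` argument is currently accepted but not applied here.
        """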
GB = self.num_global_blocks_back
GF = self.num_global_blocks_front
if is_front:
global_query_matrix = query_matrix[:, :, 0 : GF * self.block_size]
else:
global_query_matrix = query_matrix[:, :, -GB * self.block_size :]
global_product = paddle.matmul(global_query_matrix, key_matrix, transpose_y=True)
global_product = global_product * (d_head**-0.5)
global_product += (1 - key_mask) * -1e6
global_weights = F.softmax(global_product)
# [B, H, GF*bs, T] * [B, H, T, D] -> [B, H, GF*bs, D]
global_product = paddle.matmul(global_weights, value_matrix)
return global_product

    def _get_splited_matrix(self, matrix):
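        """
        Split ``matrix`` along the block axis into the first ``window_size // 2``
        blocks, the middle blocks, and the last ``window_size // 2`` blocks.
        """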
W = self.window_size // 2
return matrix[:, :, 0:W], matrix[:, :, W:-W], matrix[:, :, -W:]

    def forward(
self,
query_matrix,
key_matrix,
value_matrix,
d_head,
attn_mask=None,
rand_mask_idx=None,
query_mask=None,
key_mask=None,
dropout=None,
):
"""
query_matrix: [B, H, T, D]
key_matrix: [B, H, T, D]
value_matrix: [B, H, T, D]
query_mask: [B, 1, T, 1] bool mask
key_mask: [B, 1, 1, T] bool mask
rand_mask_idx: [H, T//bs, bs]
Global Attention
Random Attention
Window Attention
"""
B = query_matrix.shape[0] # batch_size
H = self.num_heads
T = query_matrix.shape[2] # sequence_length
G = self.num_global_blocks
GB = self.num_global_blocks_back
GF = self.num_global_blocks_front
R = self.num_rand_blocks
bs = self.block_size
L = T // bs # blocked length
blocked_query_matrix = paddle.reshape(query_matrix, [B, H, L, bs, -1])
blocked_key_matrix = paddle.reshape(key_matrix, [B, H, L, bs, -1])
blocked_value_matrix = paddle.reshape(value_matrix, [B, H, L, bs, -1])
blocked_query_mask = paddle.reshape(query_mask, [B, L, bs])
blocked_key_mask = paddle.reshape(key_mask, [B, L, bs])
# 1. global_front_product
global_front_out = self._get_global_out(query_matrix, key_matrix, value_matrix, key_mask, d_head, dropout)
# 2. global_back_product
global_back_out = self._get_global_out(
query_matrix, key_matrix, value_matrix, key_mask, d_head, dropout, False
)
        # 3. second_product: band (global + window) plus random key blocks
        # [B, H, L-G, bs, (G+W)*bs]
band_mask = self._get_band_mask(blocked_query_mask, blocked_key_mask, B, T)
# [B, H, L-G, bs, R*bs]
rand_mask = self._get_rand_mask(blocked_query_mask, blocked_key_mask, rand_mask_idx, B, T)
# [B, H, L-G, bs, (G+W+R)*bs]
second_mask = paddle.concat([band_mask, rand_mask], axis=4)
# [B, H, L-G, R * bs, -1]
random_keys = self._gather_random_key_value(blocked_key_matrix, rand_mask_idx, B, T)
random_values = self._gather_random_key_value(blocked_value_matrix, rand_mask_idx, B, T)
band_keys_matrix = self._get_band_matrix(blocked_key_matrix, B, T)
band_value_matrix = self._get_band_matrix(blocked_value_matrix, B, T)
# [B, H, L - G, bs, -1]
second_query_matrix = blocked_query_matrix[:, :, GF:-GB]
# [B, H, L - G, (G+W+R)*bs, -1]
second_key_matrix = paddle.concat([band_keys_matrix, random_keys], axis=3)
# [B, H, L - G, (G+W+R)*bs, -1]
second_value_matrix = paddle.concat([band_value_matrix, random_values], axis=3)
second_top_value_matrix, second_middle_value_matrix, second_bottom_value_matrix = self._get_splited_matrix(
second_value_matrix
)
second_product = paddle.einsum("bhlqd,bhlkd->bhlqk", second_query_matrix, second_key_matrix)
second_product = second_product * (d_head**-0.5)
second_product += (1 - second_mask) * -1e6
second_weights = F.softmax(second_product)
second_top_weights, second_middle_weights, second_bottom_weights = self._get_splited_matrix(second_weights)
second_top_out = paddle.einsum("bhlqk,bhlkd->bhlqd", second_top_weights, second_top_value_matrix)
second_middle_out = paddle.einsum(
"bhlqk,bhlkd->bhlqd",
second_middle_weights[:, :, :, :, GF * bs : -(GB + R) * bs],
second_middle_value_matrix[:, :, :, GF * bs : -(GB + R) * bs],
)
# add global block attention
second_middle_out += paddle.einsum(
"bhlqk,bhkd->bhlqd", second_middle_weights[:, :, :, :, : GF * bs], blocked_value_matrix[:, :, 0]
)
second_middle_out += paddle.einsum(
"bhlqk,bhkd->bhlqd",
second_middle_weights[:, :, :, :, -(GB + R) * bs : -R * bs],
blocked_value_matrix[:, :, -GB],
)
# add random block attention
second_middle_out += paddle.einsum(
"...qk,...kd->...qd", second_middle_weights[:, :, :, :, -R * bs :], random_values[:, :, GF:-GB]
)
second_bottom_out = paddle.einsum("bhlqk,bhlkd->bhlqd", second_bottom_weights, second_bottom_value_matrix)
second_out = paddle.concat([second_top_out, second_middle_out, second_bottom_out], axis=2)
second_out = paddle.reshape(second_out, [B, H, (L - G) * bs, -1])
# [B, H, T, D]
out = paddle.concat([global_front_out, second_out, global_back_out], axis=2)
out = out * query_mask
return out


class MultiHeadAttention(Layer):
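    """
    Multi-head attention whose score computation is delegated to a pluggable
    implementation looked up in ``AttentionRegistry`` by ``attention_type``
    ("default_attention" or "bigbird"). Also supports the usual
    encoder-decoder and incremental-decoding caches.
    """
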
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

    def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
kdim=None,
vdim=None,
weight_attr=None,
bias_attr=None,
block_size=1,
window_size=3,
num_global_blocks=1,
num_rand_blocks=1,
seed=None,
attention_type="bigbird",
):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.q_proj = Linear3D(embed_dim, num_heads, self.head_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = Linear3D(embed_dim, num_heads, self.head_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = Linear3D(embed_dim, num_heads, self.head_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.attn_impl = AttentionRegistry.cls_dict[attention_type](
num_heads, block_size, window_size, num_global_blocks, num_rand_blocks, seed
)

    def _prepare_qkv(self, query, key, value, cache=None):
q = self.q_proj(query)
if isinstance(cache, self.StaticCache):
            # Encoder-decoder attention at inference time with precomputed static k/v.
k, v = cache.k, cache.v
else:
k, v = self.compute_kv(key, value)
if isinstance(cache, self.Cache):
# for decoder self-attention in inference
k = paddle.concat([cache.k, k], axis=2)
v = paddle.concat([cache.v, v], axis=2)
cache = self.Cache(k, v)
return (q, k, v) if cache is None else (q, k, v, cache)

    def compute_kv(self, key, value):
k = self.k_proj(key)
v = self.v_proj(value)
return k, v

    def gen_cache(self, key, value=None, type=Cache):
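        """
        Produce an initial cache for incremental decoding: ``StaticCache``
        holds k/v precomputed from ``key``/``value`` (encoder-decoder
        attention), while ``Cache`` starts from an empty or user-provided k/v
        pair that grows one step at a time.
        """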
if type == MultiHeadAttention.StaticCache: # static_kv
k, v = self.compute_kv(key, value)
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = paddle.full(shape=[-1, self.num_heads, 0, self.head_dim], fill_value=0, dtype=key.dtype)
v = paddle.full(shape=[-1, self.num_heads, 0, self.head_dim], fill_value=0, dtype=key.dtype)
return self.Cache(k, v)
else:
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)

    def forward(
self, query, key, value, attn_mask=None, rand_mask_idx=None, query_mask=None, key_mask=None, cache=None
):
key = query if key is None else key
value = query if value is None else value
        # Compute q, k, v.
if cache is None:
q, k, v = self._prepare_qkv(query, key, value, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, cache)
out = self.attn_impl(q, k, v, self.head_dim, attn_mask, rand_mask_idx, query_mask, key_mask, self.dropout)
# combine heads
out = paddle.transpose(out, perm=[0, 2, 1, 3])
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
outs = [out]
if cache is not None:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
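

# End-to-end usage sketch (hypothetical shapes and hyper-parameters; the
# variables ``query``, ``query_mask`` and ``key_mask`` are assumed to be
# float 0/1 paddle tensors; illustration only):
#
#     mha = MultiHeadAttention(
#         embed_dim=768, num_heads=12, block_size=16, window_size=3,
#         num_global_blocks=2, num_rand_blocks=3, attention_type="bigbird")
#     rand_mask_idx_list = create_bigbird_rand_mask_idx_list(
#         num_layers=1, query_length=512, key_length=512, num_heads=12,
#         block_size=16, window_size=3, num_global_blocks=2,
#         num_rand_blocks=3, seed=None)
#     # query: [B, 512, 768], query_mask: [B, 1, 512, 1], key_mask: [B, 1, 1, 512]
#     out = mha(query, None, None,
#               rand_mask_idx=paddle.to_tensor(rand_mask_idx_list[0]),
#               query_mask=query_mask, key_mask=key_mask)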