# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.ops import transfer_param
from paddlenlp.ops.ext_utils import LOADED_EXT, load
from paddlenlp.transformers import (
PositionalEmbedding,
WordEmbedding,
position_encoding_init,
)
from paddlenlp.utils.log import logger
from .decoding import run_custom
def infer_transformer_decoder(
from_tensor,
memory_tensor,
mem_seq_len,
self_ln_weight,
self_ln_bias,
self_q_weight,
self_q_bias,
self_k_weight,
self_k_bias,
self_v_weight,
self_v_bias,
self_out_weight,
self_out_bias,
cross_ln_weight,
cross_ln_bias,
cross_q_weight,
cross_q_bias,
cross_k_weight,
cross_k_bias,
cross_v_weight,
cross_v_bias,
cross_out_weight,
cross_out_bias,
ffn_ln_weight,
ffn_ln_bias,
ffn_inter_weight,
ffn_inter_bias,
ffn_out_weight,
ffn_out_bias,
old_self_cache_key,
old_self_cache_value,
old_mem_cache,
step,
n_head,
size_per_head,
memory_hidden_dim,
is_fuse_qkv=False,
):
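    """
    Run the FasterTransformer `fusion_decoder` custom op for a single decoder
    layer at one decoding step. The per-layer weights and the previous
    self-attention / cross-attention (memory) caches are passed as op inputs,
    and the op returns the layer output together with the updated caches.
    """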
inputs_names = [
"FromTensor",
"MemoryTensor",
"MemSeqLen",
"SelfLayernormWeight",
"SelfLayernormBias",
"SelfQueryWeight",
"SelfQueryBias",
"SelfKeyWeight",
"SelfKeyBias",
"SelfValueWeight",
"SelfValueBias",
"SelfOutWeight",
"SelfOutBias",
"CrossLayernormWeight",
"CrossLayernormBias",
"CrossQueryWeight",
"CrossQueryBias",
"CrossKeyWeight",
"CrossKeyBias",
"CrossValueWeight",
"CrossValueBias",
"CrossOutWeight",
"CrossOutBias",
"FFNLayernormWeight",
"FFNLayernormBias",
"FFNInterWeight",
"FFNInterBias",
"FFNOutWeight",
"FFNOutBias",
"OldSelfCacheKey",
"OldSelfCacheValue",
]
inputs_var = [
from_tensor,
memory_tensor,
mem_seq_len,
self_ln_weight,
self_ln_bias,
self_q_weight,
self_q_bias,
self_k_weight,
self_k_bias,
self_v_weight,
self_v_bias,
self_out_weight,
self_out_bias,
cross_ln_weight,
cross_ln_bias,
cross_q_weight,
cross_q_bias,
cross_k_weight,
cross_k_bias,
cross_v_weight,
cross_v_bias,
cross_out_weight,
cross_out_bias,
ffn_ln_weight,
ffn_ln_bias,
ffn_inter_weight,
ffn_inter_bias,
ffn_out_weight,
ffn_out_bias,
old_self_cache_key,
old_self_cache_value,
old_mem_cache,
]
attrs_names = ["step", "n_head", "size_per_head", "memory_hidden_dim", "is_fuse_qkv"]
attrs_val = [step, n_head, size_per_head, memory_hidden_dim, is_fuse_qkv]
outputs_names = ["DecoderOutput", "NewSelfCacheKey", "NewSelfCacheValue", "NewMemCache"]
outputs_dtype = [memory_tensor.dtype] * len(outputs_names)
return run_custom("fusion_decoder", inputs_names, inputs_var, attrs_names, attrs_val, outputs_names, outputs_dtype)
def get_op_cache_config(use_batch_major_op_cache, size_per_head, is_fp16):
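    """
    Decide whether the batch-major self-attention cache layout of the custom
    op can be used: `size_per_head` must be divisible by 8 for fp16 (4 for
    fp32), otherwise the flag is turned off. Returns the adjusted flag and the
    inner packing width `x` (1 when the batch-major layout is not used).

    Example (illustrative values):
        get_op_cache_config(True, 64, True)   # -> (True, 8)
        get_op_cache_config(True, 60, True)   # -> (False, 1)
    """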
x = 8 if is_fp16 else 4
    use_batch_major_op_cache = use_batch_major_op_cache is True and size_per_head % x == 0
x = x if use_batch_major_op_cache else 1
return use_batch_major_op_cache, x
class InferTransformerDecoder(nn.Layer):
"""
FasterTransformer decoder block.
Args:
decoder (`TransformerDecoder`):
Transformer decoder block.
        n_head (`int`):
            The number of heads used in multi-head attention.
        size_per_head (`int`):
            The dimension of each attention head.
        decoder_lib (`str`, optional):
            The path to the compiled decoder lib. Defaults to None.
        use_fp16_decoder (`bool`, optional):
            Whether to use fp16 for the decoder. Defaults to False.
        use_batch_major_op_cache (`bool`, optional):
            Whether to use the batch-major layout for the custom op's
            self-attention cache. Defaults to False.
"""
def __init__(
self, decoder, n_head, size_per_head, decoder_lib=None, use_fp16_decoder=False, use_batch_major_op_cache=False
):
if decoder_lib is not None and os.path.isfile(decoder_lib):
            # Maybe it has already been loaded by `ext_utils.load`.
if "FastGeneration" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(decoder_lib)
LOADED_EXT["FastGeneration"] = ops
else:
if decoder_lib is not None:
logger.warning("The specified decoder_lib does not exist, and it will be built automatically.")
load("FastGeneration", verbose=True)
super(InferTransformerDecoder, self).__init__()
self.n_head = n_head
self.size_per_head = size_per_head
self.use_batch_major_op_cache = use_batch_major_op_cache
if use_fp16_decoder:
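            # Cast every decoder-layer parameter to fp16 in place so that the
            # weights match the fp16 custom op kernels.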
for idx, mod in enumerate(decoder.layers):
mod.norm1.weight = transfer_param(mod.norm1.weight)
mod.norm1.bias = transfer_param(mod.norm1.bias, is_bias=True)
mod.self_attn.q_proj.weight = transfer_param(mod.self_attn.q_proj.weight)
mod.self_attn.q_proj.bias = transfer_param(mod.self_attn.q_proj.bias, is_bias=True)
mod.self_attn.k_proj.weight = transfer_param(mod.self_attn.k_proj.weight)
mod.self_attn.k_proj.bias = transfer_param(mod.self_attn.k_proj.bias, is_bias=True)
mod.self_attn.v_proj.weight = transfer_param(mod.self_attn.v_proj.weight)
mod.self_attn.v_proj.bias = transfer_param(mod.self_attn.v_proj.bias, is_bias=True)
mod.self_attn.out_proj.weight = transfer_param(mod.self_attn.out_proj.weight)
mod.self_attn.out_proj.bias = transfer_param(mod.self_attn.out_proj.bias, is_bias=True)
mod.norm2.weight = transfer_param(mod.norm2.weight)
mod.norm2.bias = transfer_param(mod.norm2.bias, is_bias=True)
mod.cross_attn.q_proj.weight = transfer_param(mod.cross_attn.q_proj.weight)
mod.cross_attn.q_proj.bias = transfer_param(mod.cross_attn.q_proj.bias, is_bias=True)
mod.cross_attn.k_proj.weight = transfer_param(mod.cross_attn.k_proj.weight)
mod.cross_attn.k_proj.bias = transfer_param(mod.cross_attn.k_proj.bias, is_bias=True)
mod.cross_attn.v_proj.weight = transfer_param(mod.cross_attn.v_proj.weight)
mod.cross_attn.v_proj.bias = transfer_param(mod.cross_attn.v_proj.bias, is_bias=True)
mod.cross_attn.out_proj.weight = transfer_param(mod.cross_attn.out_proj.weight)
mod.cross_attn.out_proj.bias = transfer_param(mod.cross_attn.out_proj.bias, is_bias=True)
mod.norm3.weight = transfer_param(mod.norm3.weight)
mod.norm3.bias = transfer_param(mod.norm3.bias, is_bias=True)
mod.linear1.weight = transfer_param(mod.linear1.weight)
mod.linear1.bias = transfer_param(mod.linear1.bias, is_bias=True)
mod.linear2.weight = transfer_param(mod.linear2.weight)
mod.linear2.bias = transfer_param(mod.linear2.bias, is_bias=True)
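        # Gather each layer's parameters in the exact order expected by the
        # `fusion_decoder` op inputs (see `infer_transformer_decoder`).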
self.weights = []
for idx, mod in enumerate(decoder.layers):
layer_weight = []
layer_weight.append(mod.norm1.weight)
layer_weight.append(mod.norm1.bias)
layer_weight.append(mod.self_attn.q_proj.weight)
layer_weight.append(mod.self_attn.q_proj.bias)
layer_weight.append(mod.self_attn.k_proj.weight)
layer_weight.append(mod.self_attn.k_proj.bias)
layer_weight.append(mod.self_attn.v_proj.weight)
layer_weight.append(mod.self_attn.v_proj.bias)
layer_weight.append(mod.self_attn.out_proj.weight)
layer_weight.append(mod.self_attn.out_proj.bias)
layer_weight.append(mod.norm2.weight)
layer_weight.append(mod.norm2.bias)
layer_weight.append(mod.cross_attn.q_proj.weight)
layer_weight.append(mod.cross_attn.q_proj.bias)
layer_weight.append(mod.cross_attn.k_proj.weight)
layer_weight.append(mod.cross_attn.k_proj.bias)
layer_weight.append(mod.cross_attn.v_proj.weight)
layer_weight.append(mod.cross_attn.v_proj.bias)
layer_weight.append(mod.cross_attn.out_proj.weight)
layer_weight.append(mod.cross_attn.out_proj.bias)
layer_weight.append(mod.norm3.weight)
layer_weight.append(mod.norm3.bias)
layer_weight.append(mod.linear1.weight)
layer_weight.append(mod.linear1.bias)
layer_weight.append(mod.linear2.weight)
layer_weight.append(mod.linear2.bias)
self.weights.append(layer_weight)
def forward(
self,
from_tensor,
memory_tensor,
mem_seq_len,
self_cache_key,
self_cache_value,
mem_cache,
step,
memory_hidden_dim,
is_fuse_qkv,
):
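        """
        Run all decoder layers for one decoding step. When the batch-major op
        cache is not used, the self-attention caches are first extended by one
        zero-filled time step; each layer then reads and writes its own cache
        slice, and the per-layer caches are stacked again before returning.
        """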
decoder_output = from_tensor
self_caches_key = []
self_caches_value = []
mem_caches = []
if not self.use_batch_major_op_cache:
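            # Append one zero-initialized time step to each layer's cache for
            # the token being decoded at this step.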
self_cache_key = paddle.concat(
[
self_cache_key,
paddle.zeros(
shape=[len(self.weights), 1, paddle.shape(memory_tensor)[0], self.n_head * self.size_per_head],
dtype=self_cache_key.dtype,
),
],
axis=1,
)
self_cache_value = paddle.concat(
[
self_cache_value,
paddle.zeros(
shape=[len(self.weights), 1, paddle.shape(memory_tensor)[0], self.n_head * self.size_per_head],
dtype=self_cache_value.dtype,
),
],
axis=1,
)
for idx in range(len(self.weights)):
weight = self.weights[idx]
decoder_output, new_self_cache_key, new_self_cache_value, new_mem_cache = infer_transformer_decoder(
from_tensor=decoder_output,
memory_tensor=memory_tensor,
mem_seq_len=mem_seq_len,
self_ln_weight=weight[0],
self_ln_bias=weight[1],
self_q_weight=weight[2],
self_q_bias=weight[3],
self_k_weight=weight[4],
self_k_bias=weight[5],
self_v_weight=weight[6],
self_v_bias=weight[7],
self_out_weight=weight[8],
self_out_bias=weight[9],
cross_ln_weight=weight[10],
cross_ln_bias=weight[11],
cross_q_weight=weight[12],
cross_q_bias=weight[13],
cross_k_weight=weight[14],
cross_k_bias=weight[15],
cross_v_weight=weight[16],
cross_v_bias=weight[17],
cross_out_weight=weight[18],
cross_out_bias=weight[19],
ffn_ln_weight=weight[20],
ffn_ln_bias=weight[21],
ffn_inter_weight=weight[22],
ffn_inter_bias=weight[23],
ffn_out_weight=weight[24],
ffn_out_bias=weight[25],
old_self_cache_key=self_cache_key[idx],
old_self_cache_value=self_cache_value[idx],
old_mem_cache=mem_cache[idx],
step=step,
n_head=self.n_head,
size_per_head=self.size_per_head,
memory_hidden_dim=memory_hidden_dim,
is_fuse_qkv=is_fuse_qkv,
)
self_caches_key.append(new_self_cache_key)
self_caches_value.append(new_self_cache_value)
mem_caches.append(new_mem_cache)
self_cache_key = paddle.stack(self_caches_key, axis=0)
self_cache_value = paddle.stack(self_caches_value, axis=0)
mem_cache = paddle.stack(mem_caches, axis=0)
return decoder_output, self_cache_key, self_cache_value, mem_cache
class FasterDecoder(nn.Layer):
"""
FasterTransformer decoder for auto-regressive generation.
Args:
src_vocab_size (`int`):
The size of source vocabulary.
trg_vocab_size (`int`):
The size of target vocabulary.
max_length (`int`):
The maximum length of input sequences.
num_encoder_layers (`int`):
The number of sub-layers to be stacked in the encoder.
num_decoder_layers (`int`):
The number of sub-layers to be stacked in the decoder.
        n_head (`int`):
            The number of heads used in multi-head attention.
d_model (`int`):
The dimension for word embeddings, which is also the last dimension of
the input and output of multi-head attention, position-wise feed-forward
networks, encoder and decoder.
d_inner_hid (`int`):
Size of the hidden layer in position-wise feed-forward networks.
        dropout (`float`):
            The dropout rate, used in pre-process, activation and attention.
weight_sharing (`bool`):
Whether to use weight sharing.
        bos_id (`int`, optional):
            The start token id, which is also used as the padding id. Defaults to 0.
eos_id (`int`, optional):
The end token id. Defaults to 1.
        max_out_len (`int`, optional):
The maximum output length. Defaults to 256.
        decoder_lib (`str`, optional):
            The path to the compiled decoder lib. Defaults to None.
        use_fp16_decoder (`bool`, optional):
            Whether to use fp16 for the decoder. Defaults to False.
        use_batch_major_op_cache (`bool`, optional):
            Whether to use the batch-major layout for the custom op's
            self-attention cache. Defaults to False.
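
    Example:
        .. code-block::

            # A minimal usage sketch with illustrative sizes. Running it needs
            # a GPU environment in which the FastGeneration custom op can be
            # built or loaded.
            import paddle
            from paddlenlp.ops.fast_transformer.transformer.decoder import FasterDecoder

            model = FasterDecoder(
                src_vocab_size=30000,
                trg_vocab_size=30000,
                max_length=256,
                num_encoder_layers=6,
                num_decoder_layers=6,
                n_head=8,
                d_model=512,
                d_inner_hid=2048,
                dropout=0.1,
                weight_sharing=True,
                max_out_len=32,
            )
            model.eval()
            # `model.load(...)` can be used to load trained parameters first.
            src_word = paddle.randint(1, 30000, shape=[2, 10], dtype="int64")
            finished_seq, finished_scores = model(src_word)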
"""
def __init__(
self,
src_vocab_size,
trg_vocab_size,
max_length,
num_encoder_layers,
num_decoder_layers,
n_head,
d_model,
d_inner_hid,
dropout,
weight_sharing,
bos_id=0,
eos_id=1,
max_out_len=256,
decoder_lib=None,
use_fp16_decoder=False,
use_batch_major_op_cache=False,
):
super().__init__()
self.trg_vocab_size = trg_vocab_size
self.n_head = n_head
self.emb_dim = d_model
self.bos_id = bos_id
self.eos_id = eos_id
self.dropout = dropout
self.max_out_len = max_out_len
self.max_length = max_length
self.use_fp16_decoder = use_fp16_decoder
self.num_decoder_layers = num_decoder_layers
self.d_model = d_model
self.size_per_head = d_model // n_head
self.use_batch_major_op_cache, self.x = get_op_cache_config(
use_batch_major_op_cache, self.size_per_head, use_fp16_decoder
)
self.src_word_embedding = WordEmbedding(vocab_size=src_vocab_size, emb_dim=d_model, bos_id=self.bos_id)
self.src_pos_embedding = PositionalEmbedding(emb_dim=d_model, max_length=max_length)
if weight_sharing:
            assert (
                src_vocab_size == trg_vocab_size
            ), "Vocabularies in source and target should be the same for weight sharing."
self.trg_word_embedding = self.src_word_embedding
self.trg_pos_embedding = self.src_pos_embedding
else:
self.trg_word_embedding = WordEmbedding(vocab_size=trg_vocab_size, emb_dim=d_model, bos_id=self.bos_id)
self.trg_pos_embedding = PositionalEmbedding(emb_dim=d_model, max_length=max_length)
self.transformer = paddle.nn.Transformer(
d_model=d_model,
nhead=n_head,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=d_inner_hid,
dropout=dropout,
activation="relu",
normalize_before=True,
)
self.decoder = InferTransformerDecoder(
decoder=self.transformer.decoder,
n_head=n_head,
size_per_head=self.size_per_head,
decoder_lib=decoder_lib,
use_fp16_decoder=use_fp16_decoder,
use_batch_major_op_cache=self.use_batch_major_op_cache,
)
if weight_sharing:
self.linear = lambda x: paddle.matmul(
x=x, y=self.trg_word_embedding.word_embedding.weight, transpose_y=True
)
else:
self.linear = nn.Linear(in_features=d_model, out_features=trg_vocab_size, bias_attr=False)
def forward(self, src_word):
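        """
        Greedy decoding: encode `src_word` once, then run the fused decoder
        step by step, keeping the top-1 token at every step, until all
        sequences emit the end token or `max_out_len` is reached.
        """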
src_max_len = paddle.shape(src_word)[-1]
mem_seq_lens = paddle.sum(
paddle.cast(src_word != self.bos_id, dtype="int32"), axis=-1, keepdim=True, dtype="int32"
)
src_slf_attn_bias = (
paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
)
src_slf_attn_bias.stop_gradient = True
src_pos = paddle.cast(src_word != self.bos_id, dtype="int64") * paddle.arange(start=0, end=src_max_len)
src_emb = self.src_word_embedding(src_word)
src_pos_emb = self.src_pos_embedding(src_pos)
src_emb = src_emb + src_pos_emb
enc_input = F.dropout(src_emb, p=self.dropout, training=self.training) if self.dropout else src_emb
enc_output = self.transformer.encoder(enc_input, src_mask=src_slf_attn_bias)
batch_size, _, memory_hidden_dim = enc_output.shape
end_token_tensor = paddle.full(shape=[batch_size, 1], fill_value=self.eos_id, dtype="int64")
predict_ids = []
log_probs = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="float32")
trg_word = paddle.full(shape=[batch_size, 1], fill_value=self.bos_id, dtype="int64")
if self.use_fp16_decoder:
enc_output = paddle.cast(enc_output, "float16")
# Init cache
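        # The default layout keeps the self-attention caches as
        # [num_layers, time, batch_size, d_model]; the batch-major layout
        # packs them into the shape expected by the custom op kernels. The
        # memory cache stores the per-layer cross-attention keys/values
        # projected from the encoder output.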
if not self.use_batch_major_op_cache:
self_cache_key = paddle.zeros(
shape=[self.num_decoder_layers, 0, batch_size, self.d_model], dtype=enc_output.dtype
)
self_cache_value = paddle.zeros(
shape=[self.num_decoder_layers, 0, batch_size, self.d_model], dtype=enc_output.dtype
)
else:
self_cache_key = paddle.zeros(
shape=[
self.num_decoder_layers,
batch_size,
self.n_head,
self.size_per_head // self.x,
self.max_out_len,
self.x,
],
dtype=enc_output.dtype,
)
self_cache_value = paddle.zeros(
shape=[self.num_decoder_layers, batch_size, self.n_head, self.max_out_len, self.size_per_head],
dtype=enc_output.dtype,
)
mem_cache = paddle.zeros(
shape=[self.num_decoder_layers, 2, batch_size, src_max_len, self.d_model], dtype=enc_output.dtype
)
for i in range(self.max_out_len):
trg_pos = paddle.full(shape=trg_word.shape, fill_value=i, dtype="int64")
trg_emb = self.trg_word_embedding(trg_word)
trg_pos_emb = self.trg_pos_embedding(trg_pos)
trg_emb = trg_emb + trg_pos_emb
dec_input = F.dropout(trg_emb, p=self.dropout, training=self.training) if self.dropout else trg_emb
# TODO(gongenlei): do cast in op
if self.use_fp16_decoder:
dec_input = paddle.cast(dec_input, "float16")
dec_output, self_cache_key, self_cache_value, mem_cache = self.decoder(
from_tensor=dec_input,
memory_tensor=enc_output,
mem_seq_len=mem_seq_lens,
self_cache_key=self_cache_key,
self_cache_value=self_cache_value,
mem_cache=mem_cache,
step=i,
memory_hidden_dim=memory_hidden_dim,
is_fuse_qkv=False,
)
if self.use_fp16_decoder:
dec_output = paddle.cast(dec_output, "float32")
dec_output = paddle.reshape(dec_output, shape=[-1, dec_output.shape[-1]])
logits = self.linear(dec_output)
step_log_probs = paddle.log(F.softmax(logits, axis=-1))
log_probs = paddle.add(x=step_log_probs, y=log_probs)
scores = log_probs
topk_scores, topk_indices = paddle.topk(x=scores, k=1)
finished = paddle.equal(topk_indices, end_token_tensor)
trg_word = topk_indices
log_probs = topk_scores
predict_ids.append(topk_indices)
# TODO(gongenlei): support static graph
if paddle.all(finished).numpy():
break
predict_ids = paddle.stack(predict_ids, axis=0)
finished_seq = paddle.transpose(predict_ids, [1, 2, 0])
finished_scores = topk_scores
return finished_seq, finished_scores
def load(self, init_from_params):
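        """
        Load trained parameters from `init_from_params`: zero the padding row
        of the target word embedding, re-build the position encoding tables
        with `max_length`, and cast the decoder-layer weights to fp16 when
        `use_fp16_decoder` is enabled.
        """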
# Load the trained model
assert init_from_params, "Please set init_from_params to load the infer model."
model_dict = paddle.load(init_from_params, return_numpy=True)
# To set weight[padding_idx] to 0.
model_dict["trg_word_embedding.word_embedding.weight"][self.bos_id] = [0] * self.d_model
# To avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict["encoder.pos_encoder.weight"] = position_encoding_init(self.max_length, self.d_model)
model_dict["decoder.pos_encoder.weight"] = position_encoding_init(self.max_length, self.d_model)
if self.use_fp16_decoder:
for item in self.state_dict():
if "decoder.layers" in item:
model_dict[item] = np.float16(model_dict[item])
self.load_dict(model_dict)