Source code for paddlenlp.ops.fast_transformer.transformer.decoder

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlenlp.ops import transfer_param
from paddlenlp.ops.ext_utils import LOADED_EXT, load
from paddlenlp.transformers import (
    PositionalEmbedding,
    WordEmbedding,
    position_encoding_init,
)
from paddlenlp.utils.log import logger

from .decoding import run_custom


def infer_transformer_decoder(
    from_tensor,
    memory_tensor,
    mem_seq_len,
    self_ln_weight,
    self_ln_bias,
    self_q_weight,
    self_q_bias,
    self_k_weight,
    self_k_bias,
    self_v_weight,
    self_v_bias,
    self_out_weight,
    self_out_bias,
    cross_ln_weight,
    cross_ln_bias,
    cross_q_weight,
    cross_q_bias,
    cross_k_weight,
    cross_k_bias,
    cross_v_weight,
    cross_v_bias,
    cross_out_weight,
    cross_out_bias,
    ffn_ln_weight,
    ffn_ln_bias,
    ffn_inter_weight,
    ffn_inter_bias,
    ffn_out_weight,
    ffn_out_bias,
    old_self_cache_key,
    old_self_cache_value,
    old_mem_cache,
    step,
    n_head,
    size_per_head,
    memory_hidden_dim,
    is_fuse_qkv=False,
):
    inputs_names = [
        "FromTensor",
        "MemoryTensor",
        "MemSeqLen",
        "SelfLayernormWeight",
        "SelfLayernormBias",
        "SelfQueryWeight",
        "SelfQueryBias",
        "SelfKeyWeight",
        "SelfKeyBias",
        "SelfValueWeight",
        "SelfValueBias",
        "SelfOutWeight",
        "SelfOutBias",
        "CrossLayernormWeight",
        "CrossLayernormBias",
        "CrossQueryWeight",
        "CrossQueryBias",
        "CrossKeyWeight",
        "CrossKeyBias",
        "CrossValueWeight",
        "CrossValueBias",
        "CrossOutWeight",
        "CrossOutBias",
        "FFNLayernormWeight",
        "FFNLayernormBias",
        "FFNInterWeight",
        "FFNInterBias",
        "FFNOutWeight",
        "FFNOutBias",
        "OldSelfCacheKey",
        "OldSelfCacheValue",
    ]

    inputs_var = [
        from_tensor,
        memory_tensor,
        mem_seq_len,
        self_ln_weight,
        self_ln_bias,
        self_q_weight,
        self_q_bias,
        self_k_weight,
        self_k_bias,
        self_v_weight,
        self_v_bias,
        self_out_weight,
        self_out_bias,
        cross_ln_weight,
        cross_ln_bias,
        cross_q_weight,
        cross_q_bias,
        cross_k_weight,
        cross_k_bias,
        cross_v_weight,
        cross_v_bias,
        cross_out_weight,
        cross_out_bias,
        ffn_ln_weight,
        ffn_ln_bias,
        ffn_inter_weight,
        ffn_inter_bias,
        ffn_out_weight,
        ffn_out_bias,
        old_self_cache_key,
        old_self_cache_value,
        old_mem_cache,
    ]

    attrs_names = ["step", "n_head", "size_per_head", "memory_hidden_dim", "is_fuse_qkv"]

    attrs_val = [step, n_head, size_per_head, memory_hidden_dim, is_fuse_qkv]

    outputs_names = ["DecoderOutput", "NewSelfCacheKey", "NewSelfCacheValue", "NewMemCache"]

    outputs_dtype = [memory_tensor.dtype] * len(outputs_names)

    return run_custom("fusion_decoder", inputs_names, inputs_var, attrs_names, attrs_val, outputs_names, outputs_dtype)
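
# The helper above wraps the custom `fusion_decoder` op: it performs a single decoding
# step for one decoder layer and returns, in order, the layer output, the updated
# self-attention key/value caches and the updated cross-attention memory cache, all
# carrying the same dtype as `memory_tensor`.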


def get_op_cache_config(use_batch_major_op_cache, size_per_head, is_fp16):
    x = 8 if is_fp16 else 4
    use_batch_major_op_cache = use_batch_major_op_cache is True and size_per_head % x == 0
    x = x if use_batch_major_op_cache else 1
    return use_batch_major_op_cache, x
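
# Worked examples (illustrative only): ``x`` is the innermost chunk size of the
# batch-major key cache allocated in ``FasterDecoder.forward`` -- 8 halves for fp16 or
# 4 floats for fp32, i.e. 16 bytes either way -- so the batch-major layout is only kept
# when ``size_per_head`` divides evenly into such chunks; otherwise the time-major
# layout with ``x = 1`` is used.
#
#   get_op_cache_config(True, 64, True)   # -> (True, 8)
#   get_op_cache_config(True, 80, False)  # -> (True, 4)
#   get_op_cache_config(True, 100, True)  # -> (False, 1)
#   get_op_cache_config(False, 64, True)  # -> (False, 1)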


class InferTransformerDecoder(nn.Layer):
    """
    FasterTransformer decoder block.

    Args:
        decoder (`TransformerDecoder`):
            Transformer decoder block.
        n_head (`int`):
            The number of heads used in multi-head attention.
        size_per_head (`int`):
            The size per head used in multi-head attention.
        decoder_lib (`str`):
            The path to the decoder lib. Defaults to None.
        use_fp16_decoder (`bool`):
            Whether to use fp16 for the decoder. Defaults to False.
        use_batch_major_op_cache (`bool`):
            Whether to use the batch-major op cache layout. Defaults to False.
    """

    def __init__(
        self, decoder, n_head, size_per_head, decoder_lib=None, use_fp16_decoder=False, use_batch_major_op_cache=False
    ):
        if decoder_lib is not None and os.path.isfile(decoder_lib):
            # Maybe it has been loaded by `ext_utils.load`
            if "FastGeneration" not in LOADED_EXT.keys():
                ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(decoder_lib)
                LOADED_EXT["FastGeneration"] = ops
        else:
            if decoder_lib is not None:
                logger.warning("The specified decoder_lib does not exist, and it will be built automatically.")
            load("FastGeneration", verbose=True)

        super(InferTransformerDecoder, self).__init__()
        self.n_head = n_head
        self.size_per_head = size_per_head
        self.use_batch_major_op_cache = use_batch_major_op_cache

        if use_fp16_decoder:
            for idx, mod in enumerate(decoder.layers):
                mod.norm1.weight = transfer_param(mod.norm1.weight)
                mod.norm1.bias = transfer_param(mod.norm1.bias, is_bias=True)
                mod.self_attn.q_proj.weight = transfer_param(mod.self_attn.q_proj.weight)
                mod.self_attn.q_proj.bias = transfer_param(mod.self_attn.q_proj.bias, is_bias=True)
                mod.self_attn.k_proj.weight = transfer_param(mod.self_attn.k_proj.weight)
                mod.self_attn.k_proj.bias = transfer_param(mod.self_attn.k_proj.bias, is_bias=True)
                mod.self_attn.v_proj.weight = transfer_param(mod.self_attn.v_proj.weight)
                mod.self_attn.v_proj.bias = transfer_param(mod.self_attn.v_proj.bias, is_bias=True)
                mod.self_attn.out_proj.weight = transfer_param(mod.self_attn.out_proj.weight)
                mod.self_attn.out_proj.bias = transfer_param(mod.self_attn.out_proj.bias, is_bias=True)
                mod.norm2.weight = transfer_param(mod.norm2.weight)
                mod.norm2.bias = transfer_param(mod.norm2.bias, is_bias=True)
                mod.cross_attn.q_proj.weight = transfer_param(mod.cross_attn.q_proj.weight)
                mod.cross_attn.q_proj.bias = transfer_param(mod.cross_attn.q_proj.bias, is_bias=True)
                mod.cross_attn.k_proj.weight = transfer_param(mod.cross_attn.k_proj.weight)
                mod.cross_attn.k_proj.bias = transfer_param(mod.cross_attn.k_proj.bias, is_bias=True)
                mod.cross_attn.v_proj.weight = transfer_param(mod.cross_attn.v_proj.weight)
                mod.cross_attn.v_proj.bias = transfer_param(mod.cross_attn.v_proj.bias, is_bias=True)
                mod.cross_attn.out_proj.weight = transfer_param(mod.cross_attn.out_proj.weight)
                mod.cross_attn.out_proj.bias = transfer_param(mod.cross_attn.out_proj.bias, is_bias=True)
                mod.norm3.weight = transfer_param(mod.norm3.weight)
                mod.norm3.bias = transfer_param(mod.norm3.bias, is_bias=True)
                mod.linear1.weight = transfer_param(mod.linear1.weight)
                mod.linear1.bias = transfer_param(mod.linear1.bias, is_bias=True)
                mod.linear2.weight = transfer_param(mod.linear2.weight)
                mod.linear2.bias = transfer_param(mod.linear2.bias, is_bias=True)

        self.weights = []
        for idx, mod in enumerate(decoder.layers):
            layer_weight = []
            layer_weight.append(mod.norm1.weight)
            layer_weight.append(mod.norm1.bias)
            layer_weight.append(mod.self_attn.q_proj.weight)
            layer_weight.append(mod.self_attn.q_proj.bias)
            layer_weight.append(mod.self_attn.k_proj.weight)
            layer_weight.append(mod.self_attn.k_proj.bias)
            layer_weight.append(mod.self_attn.v_proj.weight)
            layer_weight.append(mod.self_attn.v_proj.bias)
            layer_weight.append(mod.self_attn.out_proj.weight)
            layer_weight.append(mod.self_attn.out_proj.bias)
            layer_weight.append(mod.norm2.weight)
            layer_weight.append(mod.norm2.bias)
            layer_weight.append(mod.cross_attn.q_proj.weight)
            layer_weight.append(mod.cross_attn.q_proj.bias)
            layer_weight.append(mod.cross_attn.k_proj.weight)
            layer_weight.append(mod.cross_attn.k_proj.bias)
            layer_weight.append(mod.cross_attn.v_proj.weight)
            layer_weight.append(mod.cross_attn.v_proj.bias)
            layer_weight.append(mod.cross_attn.out_proj.weight)
            layer_weight.append(mod.cross_attn.out_proj.bias)
            layer_weight.append(mod.norm3.weight)
            layer_weight.append(mod.norm3.bias)
            layer_weight.append(mod.linear1.weight)
            layer_weight.append(mod.linear1.bias)
            layer_weight.append(mod.linear2.weight)
            layer_weight.append(mod.linear2.bias)
            self.weights.append(layer_weight)
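
    # Each entry of ``self.weights`` holds the 26 parameters of one decoder layer in the
    # fixed order expected by ``infer_transformer_decoder``: self-attention layer norm,
    # Q/K/V/out projections, cross-attention layer norm and projections, then the FFN
    # layer norm and the two feed-forward linears (each weight followed by its bias).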

    def forward(
        self,
        from_tensor,
        memory_tensor,
        mem_seq_len,
        self_cache_key,
        self_cache_value,
        mem_cache,
        step,
        memory_hidden_dim,
        is_fuse_qkv,
    ):
        decoder_output = from_tensor
        self_caches_key = []
        self_caches_value = []
        mem_caches = []
        if not self.use_batch_major_op_cache:
            self_cache_key = paddle.concat(
                [
                    self_cache_key,
                    paddle.zeros(
                        shape=[len(self.weights), 1, paddle.shape(memory_tensor)[0], self.n_head * self.size_per_head],
                        dtype=self_cache_key.dtype,
                    ),
                ],
                axis=1,
            )
            self_cache_value = paddle.concat(
                [
                    self_cache_value,
                    paddle.zeros(
                        shape=[len(self.weights), 1, paddle.shape(memory_tensor)[0], self.n_head * self.size_per_head],
                        dtype=self_cache_value.dtype,
                    ),
                ],
                axis=1,
            )

        for idx in range(len(self.weights)):
            weight = self.weights[idx]
            decoder_output, new_self_cache_key, new_self_cache_value, new_mem_cache = infer_transformer_decoder(
                from_tensor=decoder_output,
                memory_tensor=memory_tensor,
                mem_seq_len=mem_seq_len,
                self_ln_weight=weight[0],
                self_ln_bias=weight[1],
                self_q_weight=weight[2],
                self_q_bias=weight[3],
                self_k_weight=weight[4],
                self_k_bias=weight[5],
                self_v_weight=weight[6],
                self_v_bias=weight[7],
                self_out_weight=weight[8],
                self_out_bias=weight[9],
                cross_ln_weight=weight[10],
                cross_ln_bias=weight[11],
                cross_q_weight=weight[12],
                cross_q_bias=weight[13],
                cross_k_weight=weight[14],
                cross_k_bias=weight[15],
                cross_v_weight=weight[16],
                cross_v_bias=weight[17],
                cross_out_weight=weight[18],
                cross_out_bias=weight[19],
                ffn_ln_weight=weight[20],
                ffn_ln_bias=weight[21],
                ffn_inter_weight=weight[22],
                ffn_inter_bias=weight[23],
                ffn_out_weight=weight[24],
                ffn_out_bias=weight[25],
                old_self_cache_key=self_cache_key[idx],
                old_self_cache_value=self_cache_value[idx],
                old_mem_cache=mem_cache[idx],
                step=step,
                n_head=self.n_head,
                size_per_head=self.size_per_head,
                memory_hidden_dim=memory_hidden_dim,
                is_fuse_qkv=is_fuse_qkv,
            )
            self_caches_key.append(new_self_cache_key)
            self_caches_value.append(new_self_cache_value)
            mem_caches.append(new_mem_cache)

        self_cache_key = paddle.stack(self_caches_key, axis=0)
        self_cache_value = paddle.stack(self_caches_value, axis=0)
        mem_cache = paddle.stack(mem_caches, axis=0)

        return decoder_output, self_cache_key, self_cache_value, mem_cache
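

# In ``InferTransformerDecoder.forward`` the time-major self-attention caches grow by one
# zero-initialized step per call before the fused op is invoked; the op then runs once per
# decoder layer, feeding each layer's output into the next, and the per-layer caches it
# returns are stacked back into single tensors along axis 0.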


class FasterDecoder(nn.Layer):
    """
    FasterTransformer decoder for auto-regressive generation.

    Args:
        src_vocab_size (`int`):
            The size of source vocabulary.
        trg_vocab_size (`int`):
            The size of target vocabulary.
        max_length (`int`):
            The maximum length of input sequences.
        num_encoder_layers (`int`):
            The number of sub-layers to be stacked in the encoder.
        num_decoder_layers (`int`):
            The number of sub-layers to be stacked in the decoder.
        n_head (`int`):
            The number of heads used in multi-head attention.
        d_model (`int`):
            The dimension for word embeddings, which is also the last dimension of
            the input and output of multi-head attention, position-wise feed-forward
            networks, encoder and decoder.
        d_inner_hid (`int`):
            Size of the hidden layer in position-wise feed-forward networks.
        dropout (`float`):
            Dropout rates. Used for pre-process, activation and inside attention.
        weight_sharing (`bool`):
            Whether to use weight sharing.
        bos_id (`int`, optional):
            The start token id, also used as the padding id. Defaults to 0.
        eos_id (`int`, optional):
            The end token id. Defaults to 1.
        max_out_len (`int`, optional):
            The maximum output length. Defaults to 256.
        decoder_lib (`str`):
            The path to the decoder lib. Defaults to None.
        use_fp16_decoder (`bool`):
            Whether to use fp16 for the decoder. Defaults to False.
        use_batch_major_op_cache (`bool`):
            Whether to use the batch-major op cache layout. Defaults to False.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        max_length,
        num_encoder_layers,
        num_decoder_layers,
        n_head,
        d_model,
        d_inner_hid,
        dropout,
        weight_sharing,
        bos_id=0,
        eos_id=1,
        max_out_len=256,
        decoder_lib=None,
        use_fp16_decoder=False,
        use_batch_major_op_cache=False,
    ):
        super().__init__()
        self.trg_vocab_size = trg_vocab_size
        self.n_head = n_head
        self.emb_dim = d_model
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.dropout = dropout
        self.max_out_len = max_out_len
        self.max_length = max_length
        self.use_fp16_decoder = use_fp16_decoder
        self.num_decoder_layers = num_decoder_layers
        self.d_model = d_model
        self.size_per_head = d_model // n_head
        self.use_batch_major_op_cache, self.x = get_op_cache_config(
            use_batch_major_op_cache, self.size_per_head, use_fp16_decoder
        )

        self.src_word_embedding = WordEmbedding(vocab_size=src_vocab_size, emb_dim=d_model, bos_id=self.bos_id)
        self.src_pos_embedding = PositionalEmbedding(emb_dim=d_model, max_length=max_length)
        if weight_sharing:
            assert (
                src_vocab_size == trg_vocab_size
            ), "Vocabularies in source and target should be same for weight sharing."
            self.trg_word_embedding = self.src_word_embedding
            self.trg_pos_embedding = self.src_pos_embedding
        else:
            self.trg_word_embedding = WordEmbedding(vocab_size=trg_vocab_size, emb_dim=d_model, bos_id=self.bos_id)
            self.trg_pos_embedding = PositionalEmbedding(emb_dim=d_model, max_length=max_length)

        self.transformer = paddle.nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_inner_hid,
            dropout=dropout,
            activation="relu",
            normalize_before=True,
        )

        self.decoder = InferTransformerDecoder(
            decoder=self.transformer.decoder,
            n_head=n_head,
            size_per_head=self.size_per_head,
            decoder_lib=decoder_lib,
            use_fp16_decoder=use_fp16_decoder,
            use_batch_major_op_cache=self.use_batch_major_op_cache,
        )

        if weight_sharing:
            self.linear = lambda x: paddle.matmul(
                x=x, y=self.trg_word_embedding.word_embedding.weight, transpose_y=True
            )
        else:
            self.linear = nn.Linear(in_features=d_model, out_features=trg_vocab_size, bias_attr=False)

    def forward(self, src_word):
        src_max_len = paddle.shape(src_word)[-1]
        mem_seq_lens = paddle.sum(
            paddle.cast(src_word != self.bos_id, dtype="int32"), axis=-1, keepdim=True, dtype="int32"
        )
        src_slf_attn_bias = (
            paddle.cast(src_word == self.bos_id, dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9
        )
        src_slf_attn_bias.stop_gradient = True
        src_pos = paddle.cast(src_word != self.bos_id, dtype="int64") * paddle.arange(start=0, end=src_max_len)

        src_emb = self.src_word_embedding(src_word)
        src_pos_emb = self.src_pos_embedding(src_pos)
        src_emb = src_emb + src_pos_emb
        enc_input = F.dropout(src_emb, p=self.dropout, training=self.training) if self.dropout else src_emb
        enc_output = self.transformer.encoder(enc_input, src_mask=src_slf_attn_bias)

        batch_size, _, memory_hidden_dim = enc_output.shape
        end_token_tensor = paddle.full(shape=[batch_size, 1], fill_value=self.eos_id, dtype="int64")

        predict_ids = []
        log_probs = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="float32")
        trg_word = paddle.full(shape=[batch_size, 1], fill_value=self.bos_id, dtype="int64")

        if self.use_fp16_decoder:
            enc_output = paddle.cast(enc_output, "float16")

        # Init cache
        if not self.use_batch_major_op_cache:
            self_cache_key = paddle.zeros(
                shape=[self.num_decoder_layers, 0, batch_size, self.d_model], dtype=enc_output.dtype
            )
            self_cache_value = paddle.zeros(
                shape=[self.num_decoder_layers, 0, batch_size, self.d_model], dtype=enc_output.dtype
            )
        else:
            self_cache_key = paddle.zeros(
                shape=[
                    self.num_decoder_layers,
                    batch_size,
                    self.n_head,
                    self.size_per_head // self.x,
                    self.max_out_len,
                    self.x,
                ],
                dtype=enc_output.dtype,
            )
            self_cache_value = paddle.zeros(
                shape=[self.num_decoder_layers, batch_size, self.n_head, self.max_out_len, self.size_per_head],
                dtype=enc_output.dtype,
            )
        mem_cache = paddle.zeros(
            shape=[self.num_decoder_layers, 2, batch_size, src_max_len, self.d_model], dtype=enc_output.dtype
        )

        for i in range(self.max_out_len):
            trg_pos = paddle.full(shape=trg_word.shape, fill_value=i, dtype="int64")
            trg_emb = self.trg_word_embedding(trg_word)
            trg_pos_emb = self.trg_pos_embedding(trg_pos)
            trg_emb = trg_emb + trg_pos_emb
            dec_input = F.dropout(trg_emb, p=self.dropout, training=self.training) if self.dropout else trg_emb

            # TODO(gongenlei): do cast in op
            if self.use_fp16_decoder:
                dec_input = paddle.cast(dec_input, "float16")

            dec_output, self_cache_key, self_cache_value, mem_cache = self.decoder(
                from_tensor=dec_input,
                memory_tensor=enc_output,
                mem_seq_len=mem_seq_lens,
                self_cache_key=self_cache_key,
                self_cache_value=self_cache_value,
                mem_cache=mem_cache,
                step=i,
                memory_hidden_dim=memory_hidden_dim,
                is_fuse_qkv=False,
            )

            if self.use_fp16_decoder:
                dec_output = paddle.cast(dec_output, "float32")

            dec_output = paddle.reshape(dec_output, shape=[-1, dec_output.shape[-1]])

            logits = self.linear(dec_output)
            step_log_probs = paddle.log(F.softmax(logits, axis=-1))
            log_probs = paddle.add(x=step_log_probs, y=log_probs)
            scores = log_probs
            topk_scores, topk_indices = paddle.topk(x=scores, k=1)
            finished = paddle.equal(topk_indices, end_token_tensor)
            trg_word = topk_indices
            log_probs = topk_scores

            predict_ids.append(topk_indices)

            # TODO(gongenlei): support static graph
            if paddle.all(finished).numpy():
                break

        predict_ids = paddle.stack(predict_ids, axis=0)
        finished_seq = paddle.transpose(predict_ids, [1, 2, 0])
        finished_scores = topk_scores

        return finished_seq, finished_scores

    def load(self, init_from_params):
        # Load the trained model
        assert init_from_params, "Please set init_from_params to load the infer model."

        model_dict = paddle.load(init_from_params, return_numpy=True)

        # Set weight[padding_idx] to 0.
        model_dict["trg_word_embedding.word_embedding.weight"][self.bos_id] = [0] * self.d_model

        # To avoid a longer length than training, reset the size of position
        # encoding to max_length
        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(self.max_length, self.d_model)
        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(self.max_length, self.d_model)

        if self.use_fp16_decoder:
            for item in self.state_dict():
                if "decoder.layers" in item:
                    model_dict[item] = np.float16(model_dict[item])

        self.load_dict(model_dict)
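

# A minimal usage sketch: the hyperparameters and checkpoint path below are illustrative
# placeholders, not values shipped with any particular PaddleNLP example. A real run needs
# a GPU and the FastGeneration custom op, which `InferTransformerDecoder` loads or
# JIT-compiles when the model is constructed.
if __name__ == "__main__":
    paddle.set_device("gpu")
    model = FasterDecoder(
        src_vocab_size=30000,
        trg_vocab_size=30000,
        max_length=256,
        num_encoder_layers=6,
        num_decoder_layers=6,
        n_head=8,
        d_model=512,
        d_inner_hid=2048,
        dropout=0.1,
        weight_sharing=True,
        max_out_len=64,
    )
    model.eval()
    # model.load("trained_models/transformer.pdparams")  # hypothetical checkpoint path

    # Random token ids standing in for a tokenized source batch of shape [batch, seq_len].
    src_word = paddle.randint(low=2, high=30000, shape=[4, 16], dtype="int64")
    with paddle.no_grad():
        finished_seq, finished_scores = model(src_word)
    print(finished_seq.shape, finished_scores.shape)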