Source code for paddlenlp.transformers.gpt.modeling

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list

from .. import PretrainedModel, register_base_model

__all__ = [
    'GPTModel',
    'GPTPretrainedModel',
    'GPTForPretraining',
    'GPTPretrainingCriterion',
    'GPTForGreedyGeneration',
    'GPTLMHeadModel',
    'GPTForTokenClassification',
    'GPTForSequenceClassification',
    'GPTForCausalLM',
]


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention computations in parallel
    to jointly attend to information from different representation subspaces.

    """

    Cache = collections.namedtuple("Cache", ["k", "v"])
    StaticCache = collections.namedtuple("StaticCache", ["k", "v"])

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False,
                 weight_attr=None,
                 bias_attr=None,
                 topo=None,
                 fuse=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights
        self.fuse = fuse

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self.fuse:
            assert self.kdim == embed_dim
            assert self.vdim == embed_dim
            self.qkv_proj = nn.Linear(
                embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
        else:
            self.q_proj = nn.Linear(
                embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
            self.k_proj = nn.Linear(
                self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
            self.v_proj = nn.Linear(
                self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)

    def _fuse_prepare_qkv(self, query):
        mix_layer = self.qkv_proj(query)
        mix_layer = paddle.reshape_(mix_layer,
                                    [0, 0, self.num_heads, 3 * self.head_dim])
        mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
        q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
        return q, k, v

    def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
        r"""
        Prepares the linearly projected queries, keys and values for the
        subsequent multi-head attention. If `cache` is not None, cached results
        are reused to reduce redundant calculations.

        """
        q = self.q_proj(query)
        q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
        q = tensor.transpose(x=q, perm=[0, 2, 1, 3])

        if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference with cached static key/value
            k, v = cache.k, cache.v
        else:
            k, v = self.compute_kv(key, value)

        if isinstance(cache, self.Cache):
            # for decoder self-attention in inference
            k = tensor.concat([cache.k, k], axis=2)
            v = tensor.concat([cache.v, v], axis=2)
        if use_cache is True:
            cache = self.Cache(k, v)

        return (q, k, v) if use_cache is False else (q, k, v, cache)

    def compute_kv(self, key, value):
        r"""
        Applies linear projection on input keys and values, then splits heads
        (reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for subsequent multiple
        parallel attention.

        It is part of calculations in multi-head attention, and is provided as
        a method to pre-compute and prefetch these results, so that they can be
        used to construct the cache for inference.

        """
        k = self.k_proj(key)
        v = self.v_proj(value)
        k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
        k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
        v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
        v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v

    def gen_cache(self, key, value=None, type=Cache):
        """
        Generates cache for `forward` usage in inference according to the arguments.
        The generated cache is an instance of `MultiHeadAttention.Cache` or an
        instance of `MultiHeadAttention.StaticCache`.
        """
        if type == MultiHeadAttention.StaticCache:  # static_kv
            k, v = self.compute_kv(key, value)
            return self.StaticCache(k, v)
        elif value is None:  # incremental_state
            k = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0)
            v = layers.fill_constant_batch_size_like(
                input=key,
                shape=[-1, self.num_heads, 0, self.head_dim],
                dtype=key.dtype,
                value=0)
            return self.Cache(k, v)
        else:
            # incremental_state with initial value, mainly for usage like UniLM
            return self.Cache(key, value)

    def forward(self,
                query,
                key,
                value,
                attn_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                               cache)
        # scale dot product attention
        product = layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

        if attn_mask is not None:
            product = product + attn_mask

        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)
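

# Illustrative usage sketch (not part of the original paddlenlp source): a minimal
# check of MultiHeadAttention with an additive causal mask, following the masking
# convention used in this file. All shapes, sizes and names here are assumptions.
def _example_multi_head_attention():
    batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
    attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)
    x = paddle.randn([batch_size, seq_len, embed_dim])
    # 0 on allowed positions, a large negative value on future positions
    causal_mask = paddle.triu(
        paddle.ones([seq_len, seq_len]) * -1e4, diagonal=1)
    out = attn(x, x, x, attn_mask=causal_mask)  # shape: [2, 4, 8]
    return out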


class TransformerDecoder(nn.Layer):
    """
    TransformerDecoder is a stack of N decoder layers.
    """

    def __init__(self,
                 decoder_layers,
                 num_layers,
                 norm=None,
                 hidden_size=None,
                 topo=None):
        super(TransformerDecoder, self).__init__()

        self.topo = topo
        self.num_layers = num_layers
        self.layers = decoder_layers
        self.norm = norm
        if norm == "LayerNorm":
            self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
        elif norm is not None:
            raise ValueError("Only support LayerNorm")
        self.checkpoints = []

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Applies a stack of N Transformer decoder layers on inputs. If `norm` is
        provided, also applies layer normalization on the output of last decoder
        layer.
        """
        output = tgt
        new_caches = []
        self.checkpoints = []

        for i, mod in enumerate(self.layers):
            if cache is None:
                if use_cache:
                    output, new_cache = mod(output,
                                            memory,
                                            tgt_mask=tgt_mask,
                                            use_cache=use_cache,
                                            cache=cache)
                    new_caches.append(new_cache)
                else:
                    output = mod(output,
                                 memory,
                                 tgt_mask=tgt_mask,
                                 use_cache=use_cache,
                                 cache=cache)

            else:
                output, new_cache = mod(output,
                                        memory,
                                        tgt_mask=tgt_mask,
                                        use_cache=use_cache,
                                        cache=cache[i])
                new_caches.append(new_cache)
            self.checkpoints.append(output.name)

        if self.norm is not None:
            output = self.norm(output)
        return output if use_cache is False else (output, new_caches)

    def gen_cache(self, memory, do_zip=False):
        r"""
        Generates cache for `forward` usage. The generated cache is a list, and
        each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
        produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
        for more details. If `do_zip` is True, apply `zip` on these tuples to get
        a list with two elements.
       """
        cache = [layer.gen_cache(memory) for layer in self.layers]
        if do_zip:
            cache = list(zip(*cache))
        return cache


class TransformerDecoderLayer(nn.Layer):
    """
    The transformer decoder layer.

    It contains a multi-head attention sublayer and a position-wise feed-forward (linear) sublayer.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward,
                 dropout=0.1,
                 activation="gelu",
                 attn_dropout=None,
                 act_dropout=None,
                 normalize_before=True,
                 weight_attr=None,
                 bias_attr=None,
                 topo=None):
        self._config = locals()
        self._config.pop("self")
        self._config.pop("__class__", None)  # py3

        super(TransformerDecoderLayer, self).__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
        act_dropout = dropout if act_dropout is None else act_dropout
        self.normalize_before = normalize_before

        weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

        self.self_attn = MultiHeadAttention(
            d_model,
            nhead,
            dropout=attn_dropout,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0],
            topo=topo)
        self.linear1 = nn.Linear(
            d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
        self.linear2 = nn.Linear(
            dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])

        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
        self.activation = getattr(F, activation)

    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
        residual = tgt

        if self.normalize_before:
            tgt = self.norm1(tgt)

        if use_cache is False:
            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        else:
            tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
                                                    use_cache, cache)
        tgt = residual + self.dropout1(tgt)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        tgt = self.dropout2(
            self.linear2(F.gelu(
                self.linear1(tgt), approximate=True)))
        tgt = residual + tgt

        if not self.normalize_before:
            tgt = self.norm2(tgt)

        return tgt if use_cache is False else (tgt, incremental_cache)

    def gen_cache(self, memory):
        incremental_cache = self.self_attn.gen_cache(
            memory, type=self.self_attn.Cache)
        return incremental_cache
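

# Illustrative sketch (not part of the original paddlenlp source): stack two
# TransformerDecoderLayer instances inside TransformerDecoder and run one pass with
# use_cache=True. The sizes are assumptions; `memory` is None because this GPT
# decoder is decoder-only, as in GPTModel below.
def _example_decoder_with_cache():
    d_model, nhead, seq_len = 16, 2, 4
    layer_list = nn.LayerList([
        TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model)
        for _ in range(2)
    ])
    decoder = TransformerDecoder(
        layer_list, num_layers=2, norm="LayerNorm", hidden_size=d_model)
    tgt = paddle.randn([1, seq_len, d_model])
    causal_mask = paddle.triu(
        paddle.ones([seq_len, seq_len]) * -1e4, diagonal=1)
    output, caches = decoder(
        tgt, memory=None, tgt_mask=causal_mask, use_cache=True)
    # `caches` holds one MultiHeadAttention.Cache (k, v) per layer; feeding it
    # back via `cache=caches` lets the next step attend to past keys/values.
    return output, caches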


class GPTEmbeddings(nn.Layer):
    """
    Include embeddings from word and position embeddings.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 hidden_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 topo=None):
        super(GPTEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
                mean=0.0, std=initializer_range)))

        self.position_embeddings = nn.Embedding(
            max_position_embeddings,
            hidden_size,
            weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
                mean=0.0, std=initializer_range)))

        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
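

# Illustrative sketch (not part of the original paddlenlp source): how the default
# position ids are derived in GPTEmbeddings.forward when `position_ids` is None.
# The token ids and sizes below are arbitrary example values.
def _example_default_position_ids():
    input_ids = paddle.to_tensor([[15, 8, 42, 3]], dtype="int64")
    ones = paddle.ones_like(input_ids, dtype="int64")
    # cumulative sum minus one yields [[0, 1, 2, 3]]: one position index per token
    position_ids = paddle.cumsum(ones, axis=-1) - ones
    emb = GPTEmbeddings(vocab_size=64, hidden_size=16, max_position_embeddings=16)
    return emb(input_ids, position_ids)  # shape: [1, 4, 16]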


class GPTPretrainedModel(PretrainedModel):
    """
    An abstract class for pretrained GPT models. It provides GPT related
    `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`,
    `pretrained_init_configuration`, `base_model_prefix` for downloading and
    loading pretrained models.
    See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.

    """

    model_config_file = "model_config.json"

    pretrained_init_configuration = {
        "gpt-cpm-large-cn": {  # 2.6B
            "vocab_size": 30000,
            "hidden_size": 2560,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "intermediate_size": 10240,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "pad_token_id": 0,
            "eos_token_id": 7,
            "bos_token_id": 0,
            "eol_token_id": 3,
        },
        "gpt-cpm-small-cn-distill": {  # 109M
            "vocab_size": 30000,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "pad_token_id": 0,
            "eos_token_id": 7,
            "bos_token_id": 0,
            "eol_token_id": 3,
        },
        "gpt3-13B-en": {  # 13B
            "vocab_size": 50304,
            "hidden_size": 5120,
            "num_hidden_layers": 40,
            "num_attention_heads": 128,
            "intermediate_size": 20480,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt3-1.3B-en": {  # 1.3B
            "vocab_size": 50304,
            "hidden_size": 2048,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
            "intermediate_size": 8192,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt2-xl-en": {  # 1558M
            "vocab_size": 50257,
            "hidden_size": 1600,
            "num_hidden_layers": 48,
            "num_attention_heads": 25,
            "intermediate_size": 6400,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt2-large-en": {  # 774M
            "vocab_size": 50257,
            "hidden_size": 1280,
            "num_hidden_layers": 36,
            "num_attention_heads": 20,
            "intermediate_size": 5120,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt2-medium-en": {  # 345M
            "vocab_size": 50304,
            "hidden_size": 1024,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
            "intermediate_size": 4096,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt2-en": {  # 117M
            "vocab_size": 50257,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
        "gpt2-small-en": {  # config for CE
            "vocab_size": 50304,
            "hidden_size": 1024,
            "num_hidden_layers": 4,
            "num_attention_heads": 4,
            "intermediate_size": 4096,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 1024,
            "type_vocab_size": 1,  # no use
            "initializer_range": 0.02,
            "eos_token_id": 50256,
            "eol_token_id": 198,
        },
    }

    resource_files_names = {"model_state": "model_state.pdparams"}

    pretrained_resource_files_map = {
        "model_state": {
            "gpt-cpm-large-cn":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-cpm-large-cn.pdparams",
            "gpt-cpm-small-cn-distill":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-cpm-small-cn-distill.pdparams",
            "gpt2-en":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-en.pdparams",
            "gpt2-medium-en":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-medium-en.pdparams",
            "gpt2-large-en":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-large-en.pdparams",
            "gpt2-xl-en":
            "https://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt2-xl-en.pdparams",
        }
    }

    base_model_prefix = "gpt"

    def init_weights(self, layer):
        """ Initialization hook """
        # no hook return
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            # In the dygraph mode, use the `set_value` to reset the parameter directly,
            # and reset the `state_dict` to update parameter in static mode.
            if isinstance(layer.weight, paddle.Tensor):
                layer.weight.set_value(
                    paddle.tensor.normal(
                        mean=0.0,
                        std=self.initializer_range
                        if hasattr(self, "initializer_range") else
                        self.gpt.config["initializer_range"],
                        shape=layer.weight.shape))


@register_base_model
class GPTModel(GPTPretrainedModel):
    r"""
    The bare GPT Model transformer outputting raw hidden-states.

    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
    Refer to the superclass documentation for the generic methods.

    This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
    /docs/en/api/paddle/fluid/dygraph/layers/Layer_en.html>`__ subclass. Use it as a regular
    Paddle Layer and refer to the Paddle documentation for all matters related to general
    usage and behavior.

    Args:
        vocab_size (int):
            Vocabulary size of `inputs_ids` in `GPTModel`. It is also the vocabulary size of
            the token embedding matrix, and defines the number of different tokens that can
            be represented by the `inputs_ids` passed when calling `GPTModel`.
        hidden_size (int, optional):
            Dimensionality of the embedding layer and decoder layer. Defaults to `768`.
        num_hidden_layers (int, optional):
            Number of hidden layers in the Transformer decoder. Defaults to `12`.
        num_attention_heads (int, optional):
            Number of attention heads for each attention layer in the Transformer decoder.
            Defaults to `12`.
        intermediate_size (int, optional):
            Dimensionality of the feed-forward (ff) layer in the decoder. Input tensors
            to ff layers are firstly projected from `hidden_size` to `intermediate_size`,
            and then projected back to `hidden_size`. Typically `intermediate_size` is
            larger than `hidden_size`. Defaults to `3072`.
        hidden_act (str, optional):
            The non-linear activation function in the feed-forward layer.
            ``"gelu"``, ``"relu"`` and any other paddle supported activation functions
            are supported. Defaults to `"gelu"`.
        hidden_dropout_prob (float, optional):
            The dropout probability for all fully connected layers in the embeddings and decoder.
            Defaults to `0.1`.
        attention_probs_dropout_prob (float, optional):
            The dropout probability used in MultiHeadAttention in all decoder layers to drop
            some attention target. Defaults to `0.1`.
        max_position_embeddings (int, optional):
            The maximum value of the dimensionality of position encoding, which dictates the
            maximum supported length of an input sequence. Defaults to `512`.
        type_vocab_size (int, optional):
            The vocabulary size of the `token_type_ids`.
            Defaults to `16`.

            .. note::
                Please do NOT use `type_vocab_size`, as it will be obsolete in the future.

        initializer_range (float, optional):
            The standard deviation of the normal initializer.
            Defaults to `0.02`.

            .. note::
                A normal_initializer initializes weight matrices as normal distributions.
                See :meth:`GPTPretrainedModel.init_weights()` for how weights are initialized
                in `GPTModel`.

        pad_token_id(int, optional):
            The index of padding token in the token vocabulary.
            Defaults to `0`.

    """

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 pad_token_id=0,
                 eos_token_id=7,
                 bos_token_id=0,
                 eol_token_id=3,
                 topo=None):
        super(GPTModel, self).__init__()
        self.pad_token_id = pad_token_id
        self.initializer_range = initializer_range
        self.topo = topo
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embeddings = GPTEmbeddings(
            vocab_size, hidden_size, hidden_dropout_prob,
            max_position_embeddings, type_vocab_size, self.initializer_range,
            topo)
        decoder_layers = nn.LayerList()
        for i in range(num_hidden_layers):
            decoder_layers.append(
                TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=num_attention_heads,
                    dim_feedforward=intermediate_size,
                    dropout=hidden_dropout_prob,
                    activation=hidden_act,
                    attn_dropout=attention_probs_dropout_prob,
                    act_dropout=hidden_dropout_prob,
                    weight_attr=paddle.ParamAttr(
                        initializer=nn.initializer.Normal(
                            mean=0.0, std=self.initializer_range)),
                    bias_attr=None,
                    topo=topo))
        self.decoder = TransformerDecoder(
            decoder_layers,
            num_hidden_layers,
            norm="LayerNorm",
            hidden_size=hidden_size,
            topo=topo)
        self.apply(self.init_weights)
        self.checkpoints = []

    def forward(self,
                input_ids,
                position_ids=None,
                attention_mask=None,
                use_cache=False,
                cache=None):
        r'''
        The GPTModel forward method, overrides the `__call__()` special method.

        Args:
            input_ids (Tensor):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
            position_ids(Tensor, optional):
                Indices of positions of each input sequence tokens in the position embeddings.
                Selected in the range ``[0, max_position_embeddings - 1]``.
                Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`.
            attention_mask (Tensor, optional):
                Mask used in self attention to avoid performing attention to some unwanted
                positions, usually the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                For example, its shape can be [batch_size, sequence_length],
                [batch_size, sequence_length, sequence_length],
                [batch_size, num_attention_heads, sequence_length, sequence_length].
                Its data type should be float32.
                The `masked` tokens have `-1e4` values, and the `unmasked` tokens have `0` values.
                Defaults to `None`, which means no positions are masked.
            use_cache (bool, optional):
                Whether or not to use cache. Defaults to `False`. If set to `True`, key value
                states will be returned and can be used to speed up decoding.
            cache (list, optional):
                It is a list, and each element in the list is a tuple `(incremental_cache, static_cache)`.
                See `TransformerDecoder.gen_cache <https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/nn/layer/transformer.py#L1060>`__ for more details.
                It is only used for inference and should be None for training.
                Defaults to `None`.

        Returns:
            Tensor: Returns tensor `encoder_output`, which is the output at the last layer of the model.
            Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size].

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import GPTModel, GPTTokenizer

                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                model = GPTModel.from_pretrained('gpt2-medium-en')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False)
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                output = model(**inputs)
        '''

        self.checkpoints = []
        if position_ids is None:
            past_length = 0
            if cache is not None:
                past_length = paddle.shape(cache[0].k)[-2]
            position_ids = paddle.arange(
                past_length,
                paddle.shape(input_ids)[-1] + past_length,
                dtype=input_ids.dtype)
            position_ids = position_ids.unsqueeze(0)
            # .expand_as(input_ids)
            position_ids = paddle.fluid.layers.expand_as(position_ids,
                                                         input_ids)
        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids)

        # TODO, use registered buffer
        causal_mask = paddle.tensor.triu(
            paddle.ones((paddle.shape(input_ids)[-1],
                         paddle.shape(input_ids)[-1])) * -1e4,
            diagonal=1)

        if attention_mask is not None:
            if len(attention_mask.shape) == 2:
                attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask + causal_mask
        else:
            attention_mask = causal_mask

        # The tensor returned by triu is not in the static graph.
        attention_mask.stop_gradient = True

        encoder_outputs = self.decoder(
            embedding_output,
            memory=None,
            tgt_mask=attention_mask,
            use_cache=use_cache,
            cache=cache)
        self.checkpoints.extend(self.decoder.checkpoints)
        return encoder_outputs


class GPTForPretraining(GPTPretrainedModel):
    """
    GPT Model with pretraining tasks on top.

    Args:
        gpt (:class:`GPTModel`):
            An instance of :class:`GPTModel`.

    """

    def __init__(self, gpt):
        super(GPTForPretraining, self).__init__()
        self.gpt = gpt
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                position_ids=None,
                attention_mask=None,
                masked_positions=None,
                use_cache=False,
                cache=None):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`GPTModel`.
            position_ids (Tensor, optional):
                See :class:`GPTModel`.
            attention_mask (Tensor, optional):
                See :class:`GPTModel`.
            use_cache (bool, optional):
                See :class:`GPTModel`.
            cache (Tensor, optional):
                See :class:`GPTModel`.

        Returns:
            Tensor or tuple: Returns tensor `logits` or tuple `(logits, cached_kvs)`. If `use_cache` is True,
            tuple (`logits, cached_kvs`) will be returned. Otherwise, tensor `logits` will be returned.
            `logits` is the output of the gpt model.
            `cache_kvs` is the cache output of gpt model if `use_cache` is True.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import GPTForPretraining, GPTTokenizer

                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                model = GPTForPretraining.from_pretrained('gpt2-medium-en')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False)
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                output = model(**inputs, use_cache=True)

                logits = output[0]
                cached_kvs = output[1]

        """
        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
                           use_cache=use_cache,
                           cache=cache)
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs
        logits = paddle.matmul(
            encoder_outputs,
            self.gpt.embeddings.word_embeddings.weight,
            transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        else:
            return logits


class GPTPretrainingCriterion(paddle.nn.Layer):
    """
    Criterion for GPT. It calculates the final loss.
    """

    def __init__(self, topo=None):
        super(GPTPretrainingCriterion, self).__init__()
        self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

    def forward(self, prediction_scores, masked_lm_labels, loss_mask):
        """
        Args:
            prediction_scores(Tensor):
                The logits of masked token prediction. Its data type should be float32 and
                its shape is [batch_size, sequence_length, vocab_size].
            masked_lm_labels(Tensor):
                The labels of the masked language modeling, the dimensionality of `masked_lm_labels`
                is equal to `prediction_scores`. Its data type should be int64 and
                its shape is [batch_size, sequence_length, 1].
            loss_mask(Tensor):
                Mask used for calculating the loss of the masked language modeling to avoid
                calculating some unwanted tokens.
                Its data type should be float32 and its shape is [batch_size, sequence_length, 1].

        Returns:
            Tensor: The pretraining loss. Its data type should be float32 and its shape is [1].

        """
        masked_lm_loss = self.loss_func(prediction_scores,
                                        masked_lm_labels.unsqueeze(2))

        loss_mask = loss_mask.reshape([-1])
        masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
        loss = masked_lm_loss / loss_mask.sum()
        return loss
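

# Illustrative sketch (not part of the original paddlenlp source): the criterion
# above computes a token-level cross entropy, zeroes out unwanted positions with
# `loss_mask`, and averages over the number of kept tokens. Sizes are assumptions.
def _example_pretraining_criterion():
    batch_size, seq_len, vocab_size = 2, 4, 50
    criterion = GPTPretrainingCriterion()
    prediction_scores = paddle.randn([batch_size, seq_len, vocab_size])
    masked_lm_labels = paddle.randint(0, vocab_size, [batch_size, seq_len])
    # keep every position; set entries to 0.0 to exclude them from the loss
    loss_mask = paddle.ones([batch_size, seq_len], dtype="float32")
    loss = criterion(prediction_scores, masked_lm_labels, loss_mask)
    return loss  # scalar float32 tensor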


class GPTForGreedyGeneration(GPTPretrainedModel):
    """
    The generation model for GPT-2. It uses the greedy strategy and generates
    the output sequence with the highest probability.

    Args:
        gpt (:class:`GPTModel`):
            An instance of `paddlenlp.transformers.GPTModel`.
        max_predict_len(int):
            The max length of the prediction.

    """

    def __init__(self, gpt, max_predict_len, eol_token_id=3):
        super(GPTForGreedyGeneration, self).__init__()
        self.gpt = gpt
        self.max_predict_len = max_predict_len
        self.eol_token_id = eol_token_id
        self.apply(self.init_weights)

    def model(self,
              input_ids,
              position_ids=None,
              attention_mask=None,
              masked_positions=None,
              use_cache=False,
              cache=None):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`GPTModel`.
            position_ids (Tensor, optional):
                See :class:`GPTModel`.
            attention_mask (Tensor, optional):
                See :class:`GPTModel`.
            use_cache (bool, optional):
                See :class:`GPTModel`.
            cache (Tensor, optional):
                See :class:`GPTModel`.

        Returns:
            Tensor or tuple: Returns tensor `logits` or tuple `(logits, cached_kvs)`. If `use_cache` is True,
            tuple (`logits, cached_kvs`) will be returned. Otherwise, tensor `logits` will be returned.
            `logits` is the output of the gpt model.
            `cache_kvs` is the cache output of gpt model if `use_cache` is True.

        """
        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
                           use_cache=use_cache,
                           cache=cache)
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs
        logits = paddle.matmul(
            encoder_outputs,
            self.gpt.embeddings.word_embeddings.weight,
            transpose_y=True)

        if use_cache:
            return logits, cached_kvs
        else:
            return logits

    def forward(self, input_ids):
        """
        Args:
            input_ids(Tensor):
                See :class:`GPTModel`.

        Returns:
            Tensor: Returns tensor `src_ids`, which means the indices of output sequence tokens in the vocabulary.
            They are numerical representations of tokens that build the output sequence.

        """
        output, cached_kvs = self.model(input_ids, use_cache=True, cache=None)
        src_ids = input_ids
        nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1])
        src_ids = paddle.concat([src_ids, nid], axis=1)
        cur_len = 0
        while (cur_len < self.max_predict_len):
            output, cached_kvs = self.model(
                nid, use_cache=True, cache=cached_kvs)
            nid = paddle.argmax(output[:, -1, :], axis=-1).reshape([-1, 1])
            src_ids = paddle.concat([src_ids, nid], axis=1)
            cur_len += 1
            if paddle.max(nid) == self.eol_token_id:
                break
        return src_ids
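

# Illustrative sketch (not part of the original paddlenlp source): greedy decoding
# with a randomly initialized, deliberately tiny GPTModel just to show the wiring;
# real usage would load pretrained weights via `from_pretrained`. All sizes are
# assumptions chosen only for the example.
def _example_greedy_generation():
    tiny_gpt = GPTModel(
        vocab_size=64,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64)
    generator = GPTForGreedyGeneration(tiny_gpt, max_predict_len=8)
    prompt_ids = paddle.randint(0, 64, [1, 5])
    # returns the prompt ids followed by up to `max_predict_len` greedily chosen ids
    return generator(prompt_ids)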


class GPTLMHead(nn.Layer):
    def __init__(self, hidden_size, vocab_size, embedding_weights=None):
        super(GPTLMHead, self).__init__()
        self.decoder_weight = self.create_parameter(
            shape=[vocab_size, hidden_size],
            dtype=paddle.get_default_dtype(),
            is_bias=True) if embedding_weights is None else embedding_weights

    def forward(self, hidden_states):
        logits = paddle.tensor.matmul(
            hidden_states, self.decoder_weight, transpose_y=True)
        return logits


class GPTLMHeadModel(GPTPretrainedModel):
    """
    The GPT Model with a `language modeling` head on top.

    Args:
        gpt (:class:`GPTModel`):
            An instance of :class:`GPTModel`.

    """

    def __init__(self, gpt):
        super(GPTLMHeadModel, self).__init__()
        self.gpt = gpt
        self.lm_head = GPTLMHead(self.gpt.config["hidden_size"],
                                 self.gpt.config["vocab_size"],
                                 self.gpt.embeddings.word_embeddings.weight)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                position_ids=None,
                attention_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`GPTModel`.
            position_ids (Tensor, optional):
                See :class:`GPTModel`.
            attention_mask (Tensor, optional):
                See :class:`GPTModel`.
            use_cache (bool, optional):
                See :class:`GPTModel`.
            cache (Tensor, optional):
                See :class:`GPTModel`.

        Returns:
            Tensor or tuple: Returns tensor `logits` or tuple `(logits, cached_kvs)`. If `use_cache` is True,
            tuple (`logits, cached_kvs`) will be returned. Otherwise, tensor `logits` will be returned.
            `logits` is the output of the gpt model.
            `cache_kvs` is the cache output of gpt model if `use_cache` is True.

        """
        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
                           use_cache=use_cache,
                           cache=cache)

        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
            encoder_outputs = outputs

        logits = self.lm_head(encoder_outputs)

        if use_cache:
            return logits, cached_kvs
        else:
            return logits

    def prepare_faster_entry(self, kwargs):
        from paddlenlp.ops import FasterGPT
        use_fp16_decoding = kwargs.get('use_fp16_decoding', False)
        decode_strategy = kwargs.get('decode_strategy')
        if decode_strategy == "beam_search":
            raise AttributeError(
                "'beam_search' is not supported yet in the faster version of GPT"
            )
        # Currently, FasterTransformer only supports restricted size_per_head.
        size_per_head = self.gpt.config["hidden_size"] // self.gpt.config[
            "num_attention_heads"]
        if size_per_head not in [32, 64, 80, 96, 128]:
            raise AttributeError(
                "'size_per_head = %d' is not supported yet in the faster version of GPT"
                % size_per_head)
        if kwargs['forced_bos_token_id'] is not None:
            # not support for forced_bos_token_id yet in the faster version
            raise AttributeError(
                "'forced_bos_token_id != None' is not supported yet in the faster version"
            )
        if kwargs['min_length'] != 0:
            # not support for min_length yet in the faster version
            raise AttributeError(
                "'min_length != 0' is not supported yet in the faster version")
        self._faster_entry = FasterGPT(
            self, use_fp16_decoding=use_fp16_decoding).forward
        return self._faster_entry

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      use_cache=False,
                                      cache=None,
                                      **kwargs):
        # only last token for inputs_ids if cache is defined in kwargs
        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            if len(attention_mask.shape) == 4:
                attention_mask = attention_mask[:, -1, -1, :]
            if "int" in paddle.fluid.data_feeder.convert_dtype(
                    attention_mask.dtype):
                attention_mask = (1.0 - attention_mask) * -1e4
        if cache is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if position_ids is not None:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "cache": cache
        }

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError as e:
            try:
                return getattr(getattr(self, self.base_model_prefix), name)
            except AttributeError:
                try:
                    return getattr(self, self.base_model_prefix).config[name]
                except KeyError:
                    raise e


class GPTForTokenClassification(GPTPretrainedModel):
    """
    GPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.

    Args:
        gpt (:class:`GPTModel`):
            An instance of GPTModel.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
        dropout (float, optional):
            The dropout probability for output of GPT.
            If None, use the same value as `hidden_dropout_prob` of `GPTModel`
            instance `gpt`. Defaults to None.

    """

    def __init__(self, gpt, num_classes=2, dropout=None):
        super(GPTForTokenClassification, self).__init__()
        self.num_classes = num_classes
        self.gpt = gpt  # allow gpt to be config
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.gpt.config["hidden_dropout_prob"])
        self.classifier = nn.Linear(self.gpt.config["hidden_size"],
                                    num_classes)
        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, attention_mask=None):
        r"""
        The GPTForTokenClassification forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`GPTModel`.
            position_ids(Tensor, optional):
                See :class:`GPTModel`.
            attention_mask (list, optional):
                See :class:`GPTModel`.

        Returns:
            Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
            Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import GPTForTokenClassification, GPTTokenizer

                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                model = GPTForTokenClassification.from_pretrained('gpt2-medium-en')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False)
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)

        """
        sequence_output = self.gpt(input_ids,
                                   position_ids=position_ids,
                                   attention_mask=attention_mask)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits


class GPTForSequenceClassification(GPTPretrainedModel):
    """
    GPT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
    for GLUE tasks.

    Args:
        gpt (:class:`GPTModel`):
            An instance of GPTModel.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.

    """

    def __init__(self, gpt, num_classes=2):
        super(GPTForSequenceClassification, self).__init__()
        self.gpt = gpt  # allow gpt to be config
        self.score = nn.Linear(
            self.gpt.config["hidden_size"], num_classes, bias_attr=False)
        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, attention_mask=None):
        r"""
        The GPTForSequenceClassification forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`GPTModel`.
            position_ids(Tensor, optional):
                See :class:`GPTModel`.
            attention_mask (list, optional):
                See :class:`GPTModel`.

        Returns:
            Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
            Shape as `[batch_size, num_classes]` and dtype as float32.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import GPTForSequenceClassification, GPTTokenizer

                tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
                model = GPTForSequenceClassification.from_pretrained('gpt2-medium-en')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False)
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)

        """
        # sequence_output shape [bs, seq_len, hidden_size]
        sequence_output = self.gpt(input_ids,
                                   position_ids=position_ids,
                                   attention_mask=attention_mask)
        # logits shape [bs, seq_len, num_class]
        logits = self.score(sequence_output)
        # padding index maybe 0
        eos_token_id = self.gpt.config.get("eos_token_id", 0)
        # sequence_lengths shape [bs,]
        sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(
            axis=-1) - 1
        pooled_logits = logits.gather_nd(
            paddle.stack(
                [paddle.arange(sequence_output.shape[0]), sequence_lengths],
                axis=-1))

        return pooled_logits


GPTForCausalLM = GPTLMHeadModel
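

# Illustrative sketch (not part of the original paddlenlp source): how
# GPTForSequenceClassification.forward pools one logit vector per sequence by
# picking the last position before the eos/padding tokens. Ids and shapes are
# assumptions chosen only for the example.
def _example_last_token_pooling():
    eos_token_id = 50256
    input_ids = paddle.to_tensor(
        [[11, 22, 33, eos_token_id],
         [44, 55, eos_token_id, eos_token_id]])
    logits = paddle.randn([2, 4, 3])  # [batch_size, seq_len, num_classes]
    # index of the last non-eos token in each row: [2, 1]
    sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(
        axis=-1) - 1
    pooled_logits = logits.gather_nd(
        paddle.stack(
            [paddle.arange(input_ids.shape[0]), sequence_lengths], axis=-1))
    return pooled_logits  # shape: [2, 3]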