# Source code for paddlenlp.transformers.roberta.modeling

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Optional, Tuple

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor

from ...layers import Linear as TransposedLinear
from ...utils.converter import StateDictNameMapping, init_name_mappings
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from .configuration import PRETRAINED_INIT_CONFIGURATION, RobertaConfig

__all__ = [
    "RobertaModel",
    "RobertaPretrainedModel",
    "RobertaForSequenceClassification",
    "RobertaForTokenClassification",
    "RobertaForQuestionAnswering",
    "RobertaForMaskedLM",
    "RobertaForMultipleChoice",
    "RobertaForCausalLM",
]


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length):
    """
    Derive position ids from token ids while leaving padding tokens untouched.

    Each non-padding token receives its running (1-based) index along axis 1, shifted by
    `past_key_values_length`, plus `padding_idx`; real positions therefore start at
    `padding_idx + 1`. Padding tokens are assigned `padding_idx`. Adapted from fairseq's
    `utils.make_positions`.

    Args:
        input_ids (paddle.Tensor): Token ids; the cumulative sum is taken over axis 1.
        padding_idx (int): Id of the padding token.
        past_key_values_length (int, optional): Offset added to the running index of every
            non-padding token. `None` is treated as `0`.

    Returns:
        paddle.Tensor: Position ids, computed element-wise from `input_ids`.
    """
    offset = 0 if past_key_values_length is None else past_key_values_length
    # The casts and type conversions below are deliberately arranged so that the function
    # works with both ONNX export and XLA.
    non_padding = (input_ids != padding_idx).cast("int64")
    running_index = paddle.cumsum(non_padding, axis=1)
    # Multiplying by the mask zeroes out the padding slots again after the offset is added.
    shifted_index = (running_index + offset) * non_padding
    return shifted_index + padding_idx


class RobertaEmbeddings(nn.Layer):
    r"""
    Sum word, position and token-type embeddings, then apply layer normalization and dropout.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.padding_idx = config.pad_token_id
        self.cls_token_id = config.cls_token_id

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values_length: Optional[int] = None,
    ):
        """
        Compute the embedding output for a batch.

        Args:
            input_ids (Tensor, optional): Token ids. When given, they are looked up in
                `word_embeddings` and the result replaces `inputs_embeds`.
            token_type_ids (Tensor, optional): Segment ids. When omitted, all-zero int64 ids
                matching the leading dimensions of the word embeddings are used.
            position_ids (Tensor, optional): Position ids. When omitted, they are derived from
                `input_ids` (padding-aware) or, failing that, generated sequentially from
                `inputs_embeds`.
            inputs_embeds (Tensor, optional): Pre-computed word embeddings, used only when
                `input_ids` is None.
            past_key_values_length (int, optional): Position offset forwarded to
                `create_position_ids_from_input_ids`.

        Returns:
            Tensor: The summed embeddings after layer normalization and dropout.
        """
        if input_ids is not None:
            inputs_embeds = self.word_embeddings(input_ids)

        if position_ids is None:
            if input_ids is None:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
            else:
                position_ids = create_position_ids_from_input_ids(
                    input_ids, padding_idx=self.padding_idx, past_key_values_length=past_key_values_length
                )
            # Generated position ids are not trainable inputs.
            position_ids.stop_gradient = True

        if token_type_ids is None:
            leading_shape = paddle.shape(inputs_embeds)[:-1]
            token_type_ids = paddle.zeros(leading_shape, dtype="int64")

        position_part = self.position_embeddings(position_ids)
        token_type_part = self.token_type_embeddings(token_type_ids)

        summed = inputs_embeds + position_part + token_type_part
        return self.dropout(self.layer_norm(summed))

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        Generate sequential position ids when only embeddings are available.

        Padding cannot be detected from embeddings, so every slot gets a consecutive position
        starting at `padding_idx + 1`.

        Args:
            inputs_embeds (paddle.Tensor): Embedded inputs; all dimensions except the last
                define the shape of the result, and dimension 1 is the sequence length.

        Returns:
            paddle.Tensor: int64 position ids expanded to the leading shape of `inputs_embeds`.
        """
        leading_shape = paddle.shape(inputs_embeds)[:-1]
        seq_len = leading_shape[1]

        sequential_ids = paddle.arange(self.padding_idx + 1, seq_len + self.padding_idx + 1, dtype="int64")
        return sequential_ids.unsqueeze(0).expand(leading_shape)


class RobertaPooler(nn.Layer):
    """Pool a sequence by passing its first token's hidden state through a dense layer and tanh."""

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (Tensor): Encoder output; index 0 along axis 1 is taken as the
                summary token.

        Returns:
            Tensor: The activated projection of the first token's hidden state.
        """
        # "Pooling" here simply means selecting the hidden state of the first token.
        summary_token = hidden_states[:, 0]
        return self.activation(self.dense(summary_token))


class RobertaPretrainedModel(PretrainedModel):
    r"""
    An abstract class for pretrained RoBerta models. It provides RoBerta related
    `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
    `pretrained_resource_files_map`, `base_model_prefix` for downloading and
    loading pretrained models.
    See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
    """

    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    config_class = RobertaConfig

    # Download locations of the pretrained weights, keyed by model name.
    pretrained_resource_files_map = {
        "model_state": {
            "hfl/roberta-wwm-ext": "https://bj.bcebos.com/paddlenlp/models/transformers/roberta_base/roberta_chn_base.pdparams",
            "hfl/roberta-wwm-ext-large": "https://bj.bcebos.com/paddlenlp/models/transformers/roberta_large/roberta_chn_large.pdparams",
            "hfl/rbt6": "https://bj.bcebos.com/paddlenlp/models/transformers/rbt6/rbt6_chn_large.pdparams",
            "hfl/rbt4": "https://bj.bcebos.com/paddlenlp/models/transformers/rbt4/rbt4_chn_large.pdparams",
            "hfl/rbt3": "https://bj.bcebos.com/paddlenlp/models/transformers/rbt3/rbt3_chn_large.pdparams",
            "hfl/rbtl3": "https://bj.bcebos.com/paddlenlp/models/transformers/rbtl3/rbtl3_chn_large.pdparams",
        }
    }
    base_model_prefix = "roberta"

    @classmethod
    def _get_name_mappings(cls, config: RobertaConfig) -> list[StateDictNameMapping]:
        """
        Build the parameter-name mappings used when converting a checkpoint for this model.

        Each entry is either a bare name or a list `[source_name, target_name, action]`; the
        optional third element `"transpose"` is passed through to `StateDictNameMapping`.
        Source names follow the `encoder.layer.N.attention...` scheme and target names the
        `encoder.layers.N.self_attn...` scheme used by this file's encoder.

        Args:
            config (RobertaConfig): Supplies `num_hidden_layers` (one mapping group per layer)
                and `architectures` (selects prefixing and the task-head mappings).

        Returns:
            list[StateDictNameMapping]: One mapping object per entry.
        """
        # NOTE(review): the block nesting below was reconstructed from text whose indentation
        # had been lost; in particular the pooler `extend` is assumed to sit inside the
        # `cls.__name__` check — confirm against the upstream file.
        mappings = [
            "embeddings.word_embeddings.weight",
            "embeddings.position_embeddings.weight",
            "embeddings.token_type_embeddings.weight",
            ["embeddings.LayerNorm.weight", "embeddings.layer_norm.weight"],
            ["embeddings.LayerNorm.bias", "embeddings.layer_norm.bias"],
        ]

        for layer_index in range(config.num_hidden_layers):
            layer_mappings = [
                [
                    f"encoder.layer.{layer_index}.attention.self.query.weight",
                    f"encoder.layers.{layer_index}.self_attn.q_proj.weight",
                    "transpose",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.self.query.bias",
                    f"encoder.layers.{layer_index}.self_attn.q_proj.bias",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.self.key.weight",
                    f"encoder.layers.{layer_index}.self_attn.k_proj.weight",
                    "transpose",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.self.key.bias",
                    f"encoder.layers.{layer_index}.self_attn.k_proj.bias",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.self.value.weight",
                    f"encoder.layers.{layer_index}.self_attn.v_proj.weight",
                    "transpose",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.self.value.bias",
                    f"encoder.layers.{layer_index}.self_attn.v_proj.bias",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.output.dense.weight",
                    f"encoder.layers.{layer_index}.self_attn.out_proj.weight",
                    "transpose",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.output.dense.bias",
                    f"encoder.layers.{layer_index}.self_attn.out_proj.bias",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.output.LayerNorm.weight",
                    f"encoder.layers.{layer_index}.norm1.weight",
                ],
                [
                    f"encoder.layer.{layer_index}.attention.output.LayerNorm.bias",
                    f"encoder.layers.{layer_index}.norm1.bias",
                ],
                [
                    f"encoder.layer.{layer_index}.intermediate.dense.weight",
                    f"encoder.layers.{layer_index}.linear1.weight",
                    "transpose",
                ],
                [f"encoder.layer.{layer_index}.intermediate.dense.bias", f"encoder.layers.{layer_index}.linear1.bias"],
                [
                    f"encoder.layer.{layer_index}.output.dense.weight",
                    f"encoder.layers.{layer_index}.linear2.weight",
                    "transpose",
                ],
                [f"encoder.layer.{layer_index}.output.dense.bias", f"encoder.layers.{layer_index}.linear2.bias"],
                [f"encoder.layer.{layer_index}.output.LayerNorm.weight", f"encoder.layers.{layer_index}.norm2.weight"],
                [f"encoder.layer.{layer_index}.output.LayerNorm.bias", f"encoder.layers.{layer_index}.norm2.bias"],
            ]
            mappings.extend(layer_mappings)

        # presumably normalizes bare-string entries into [source, target, ...] lists in place,
        # since the loops below index and assign `mapping[0]` / `mapping[1]` — verify against
        # `init_name_mappings`.
        init_name_mappings(mappings)
        # Other than RobertaModel, other architectures will prepend model prefix
        if config.architectures is not None and "RobertaModel" not in config.architectures:
            for mapping in mappings:
                mapping[0] = "roberta." + mapping[0]

        # Any subclass other than the bare model holds the base model under `roberta.`, so the
        # target names get that prefix as well.
        if cls.__name__ != "RobertaModel":
            for mapping in mappings:
                mapping[1] = "roberta." + mapping[1]

            mappings.extend(
                [
                    ["pooler.dense.weight", "roberta.pooler.dense.weight", "transpose"],
                    ["pooler.dense.bias", "roberta.pooler.dense.bias"],
                ]
            )

        # Task-specific head parameters, added according to the architectures named in the config.
        if config.architectures is not None:
            if "RobertaForSequenceClassification" in config.architectures:
                mappings.extend(
                    [
                        ["classifier.out_proj.weight", None, "transpose"],
                        "classifier.out_proj.bias",
                        ["classifier.dense.weight", None, "transpose"],
                        "classifier.dense.bias",
                    ]
                )
            if "RobertaForMaskedLM" in config.architectures:
                mappings.extend(
                    [
                        "lm_head.bias",
                        "lm_head.dense.weight",
                        "lm_head.dense.bias",
                        "lm_head.layer_norm.weight",
                        "lm_head.layer_norm.bias",
                    ]
                )
            if (
                "RobertaForTokenClassification" in config.architectures
                or "RobertaForMultipleChoice" in config.architectures
            ):
                mappings.extend(
                    [
                        ["classifier.weight", None, "transpose"],
                        "classifier.bias",
                    ]
                )
            if "RobertaForQuestionAnswering" in config.architectures:
                mappings.extend(
                    [
                        ["qa_outputs.weight", "classifier.weight", "transpose"],
                        ["qa_outputs.bias", "classifier.bias"],
                    ]
                )

        # Second pass so that the entries appended after the first call are normalized too
        # (assumed, see the note on the first call).
        init_name_mappings(mappings)
        return [StateDictNameMapping(*mapping) for mapping in mappings]

    def _init_weights(self, layer):
        """
        Initialization hook applied to a single sub-layer.

        `nn.Linear` and `nn.Embedding` weights are overwritten with samples from a normal
        distribution (mean 0, std `config.initializer_range`); `nn.LayerNorm` layers get their
        epsilon set to `config.layer_norm_eps`. Other layer types are left untouched.
        """
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            # Only dygraph mode is supported. TODO: switch to truncated_normal and make the
            # initialization in-place and configurable later.
            layer.weight.set_value(
                paddle.tensor.normal(
                    mean=0.0,
                    std=self.config.initializer_range,
                    shape=layer.weight.shape,
                )
            )
        elif isinstance(layer, nn.LayerNorm):
            layer._epsilon = self.config.layer_norm_eps
@register_base_model
class RobertaModel(RobertaPretrainedModel):
    r"""
    The bare Roberta Model outputting raw hidden-states.

    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
    Refer to the superclass documentation for the generic methods.

    This model is also a Paddle `paddle.nn.Layer
    <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Layer_cn.html>`__ subclass.
    Use it as a regular Paddle Layer and refer to the Paddle documentation for all matters
    related to general usage and behavior.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration. The fields below are read from it.
        add_pooling_layer (bool, optional):
            Whether to build the :class:`RobertaPooler` on top of the encoder. When `False`,
            `self.pooler` is `None` and the pooled output returned by `forward` is `None`.
            Defaults to `True`.

    Configuration fields:
        vocab_size (int):
            Vocabulary size of `inputs_ids` in `RobertaModel`. Also is the vocab size of token embedding matrix.
            Defines the number of different tokens that can be represented by the `inputs_ids` passed when
            calling `RobertaModel`.
        hidden_size (int, optional):
            Dimensionality of the embedding layer, encoder layers and pooler layer. Defaults to `768`.
        num_hidden_layers (int, optional):
            Number of hidden layers in the Transformer encoder. Defaults to `12`.
        num_attention_heads (int, optional):
            Number of attention heads for each attention layer in the Transformer encoder.
            Defaults to `12`.
        intermediate_size (int, optional):
            Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors
            to ff layers are firstly projected from `hidden_size` to `intermediate_size`,
            and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`.
            Defaults to `3072`.
        hidden_act (str, optional):
            The non-linear activation function in the feed-forward layer.
            ``"gelu"``, ``"relu"`` and any other paddle supported activation functions
            are supported. Defaults to ``"gelu"``.
        hidden_dropout_prob (float, optional):
            The dropout probability for all fully connected layers in the embeddings and encoder.
            Defaults to `0.1`.
        attention_probs_dropout_prob (float, optional):
            The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target.
            Defaults to `0.1`.
        max_position_embeddings (int, optional):
            The maximum value of the dimensionality of position encoding, which dictates the maximum supported
            length of an input sequence. Defaults to `512`.
        type_vocab_size (int, optional):
            The vocabulary size of the `token_type_ids` passed when calling `~transformers.RobertaModel`.
            Defaults to `2`.
        initializer_range (float, optional):
            The standard deviation of the normal initializer. Defaults to 0.02.

            .. note::
                A normal_initializer initializes weight matrices as normal distributions.
                See :meth:`RobertaPretrainedModel._init_weights()` for how weights are initialized in `RobertaModel`.

        pad_token_id (int, optional):
            The index of padding token in the token vocabulary.
            Defaults to `0`.
        cls_token_id (int, optional):
            The index of cls token in the token vocabulary.
            Defaults to `101`.
    """

    def __init__(self, config: RobertaConfig, add_pooling_layer=True):
        super(RobertaModel, self).__init__(config)

        self.pad_token_id = config.pad_token_id
        self.initializer_range = config.initializer_range
        self.layer_norm_eps = config.layer_norm_eps
        self.embeddings = RobertaEmbeddings(config)
        encoder_layer = nn.TransformerEncoderLayer(
            config.hidden_size,
            config.num_attention_heads,
            config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation=config.hidden_act,
            attn_dropout=config.attention_probs_dropout_prob,
            act_dropout=0,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
        self.pooler = RobertaPooler(config.hidden_size) if add_pooling_layer else None

    def get_input_embeddings(self):
        """Return the word-embedding layer of the model."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding layer of the model with `value`."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
        inputs_embeds: Optional[Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor, optional):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Exactly one of `input_ids` and `inputs_embeds` must be provided.
            token_type_ids (Tensor, optional):
                Segment token indices to indicate first and second portions of the inputs.
                Indices can be either 0 or 1:

                - 0 corresponds to a **sentence A** token,
                - 1 corresponds to a **sentence B** token.

                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to None, in which case all-zero segment ids are used.
            position_ids (Tensor, optional):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
                max_position_embeddings - 1]``.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to `None`.
            attention_mask (Tensor, optional):
                Mask used in multi-head attention to avoid performing attention to some unwanted positions,
                usually the paddings or the subsequent positions.
                Its data type can be int, float and bool.
                When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
                When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
                When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length],
                [batch_size, num_attention_heads, sequence_length, sequence_length].
                Defaults to `None`. In that case a padding mask is derived from `input_ids`
                (positions equal to `pad_token_id` are masked); if only `inputs_embeds` is
                given, no mask is applied because padding cannot be inferred.
            past_key_values (tuple(tuple(Tensor)), optional):
                The length of tuple equals to the number of layers, and each inner
                tuple has 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`)
                which contains precomputed key and value hidden states of the attention blocks.
                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
                `input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (Tensor, optional):
                If you want to control how to convert `inputs_ids` indices into associated vectors, you can
                pass an embedded representation directly instead of passing `inputs_ids`.
            use_cache (`bool`, optional):
                If set to `True`, `past_key_values` key value states are returned.
                Defaults to `None`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
                will be a tuple of tensors. If `None`, `self.config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if
            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
            to ordered and not None (depending on the input arguments) fields of
            :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are provided.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaModel, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaModel.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                sequence_output, pooled_output = model(**inputs)

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        if input_ids is None and inputs_embeds is None:
            # Fail early with a clear message instead of an obscure error further down.
            raise ValueError("You have to specify either input_ids or inputs_embeds.")

        past_key_values_length = None
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if attention_mask is None:
            # A padding mask can only be derived from token ids. With `inputs_embeds` alone
            # the mask stays None, i.e. no position is masked.
            if input_ids is not None:
                attention_mask = paddle.unsqueeze(
                    (input_ids == self.pad_token_id).astype(paddle.get_default_dtype()) * -1e4, axis=[1, 2]
                )
                if past_key_values is not None:
                    # Cached positions are always attended to, so their mask entries are zero.
                    batch_size = past_key_values[0][0].shape[0]
                    past_mask = paddle.zeros([batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype)
                    attention_mask = paddle.concat([past_mask, attention_mask], axis=-1)
        elif attention_mask.ndim == 2:
            # Convert a [batch_size, sequence_length] 1/0 mask into an additive mask that
            # broadcasts over heads and query positions.
            attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
            attention_mask = (1.0 - attention_mask) * -1e4

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        self.encoder._use_cache = use_cache  # To be consistent with HF
        encoder_outputs = self.encoder(
            embedding_output,
            src_mask=attention_mask,
            cache=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if isinstance(encoder_outputs, type(embedding_output)):
            # The encoder returned a bare tensor: there are no extra outputs to forward.
            sequence_output = encoder_outputs
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
            return (sequence_output, pooled_output)
        else:
            sequence_output = encoder_outputs[0]
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
            if not return_dict:
                return (sequence_output, pooled_output) + encoder_outputs[1:]
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
class RobertaForQuestionAnswering(RobertaPretrainedModel):
    r"""
    Roberta Model with a linear layer on top of the hidden-states output to
    compute `span_start_logits` and `span_end_logits`, designed for question-answering tasks like SQuAD.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration used to build the underlying :class:`RobertaModel`
            (created without a pooling layer) and the span classifier.
    """

    def __init__(self, config: RobertaConfig):
        super(RobertaForQuestionAnswering, self).__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # Two outputs per token: start logit and end logit.
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        start_positions: Optional[Tensor] = None,
        end_positions: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`RobertaModel`.
            token_type_ids (Tensor, optional):
                See :class:`RobertaModel`.
            position_ids (Tensor, optional):
                See :class:`RobertaModel`.
            attention_mask (Tensor, optional):
                See :class:`RobertaModel`.
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`.
            start_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the start of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            end_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the end of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If
                `False`, the output will be a tuple of tensors. If `None`, `self.config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` if `return_dict=True`.
            Otherwise it returns a tuple of tensors corresponding to ordered and
            not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForQuestionAnswering, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForQuestionAnswering.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                outputs = model(**inputs)
                start_logits, end_logits = outputs[0], outputs[1]

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # Move the 2-way (start/end) axis to the front so it can be split into two tensors.
        logits = self.classifier(sequence_output)
        logits = paddle.transpose(logits, perm=[2, 0, 1])
        start_logits, end_logits = paddle.unstack(x=logits, axis=0)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if start_positions.ndim > 1:
                start_positions = start_positions.squeeze(-1)
            # Each label tensor is checked on its own rank; the guard for `end_positions`
            # must not depend on the rank of `start_positions`.
            if end_positions.ndim > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = paddle.shape(start_logits)[1]
            start_positions = start_positions.clip(0, ignored_index)
            end_positions = end_positions.clip(0, ignored_index)

            loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class RobertaClassificationHead(nn.Layer):
    """Sentence-level classification head operating on the first token.

    The head reads the hidden state at sequence position 0 (the ``<s>`` token,
    RoBERTa's counterpart of ``[CLS]``) and applies
    ``dropout -> dense -> tanh -> dropout -> out_proj``.

    Args:
        config: Configuration object providing ``hidden_size``, ``num_labels``,
            ``classifier_dropout`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        # Use the classifier-specific dropout rate when configured, otherwise
        # fall back to the encoder's hidden dropout rate.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Project the first-token features to ``config.num_labels`` logits.

        Args:
            features (Tensor): Hidden states indexed as
                ``[batch, sequence, hidden]``; only position 0 of the sequence
                axis is used.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Tensor: Output of ``out_proj`` for each item in the batch.
        """
        first_token = features[:, 0, :]  # <s> token (equivalent to [CLS])
        hidden = self.dense(self.dropout(first_token))
        hidden = self.dropout(paddle.tanh(hidden))
        return self.out_proj(hidden)
class RobertaForSequenceClassification(RobertaPretrainedModel):
    r"""
    Roberta Model with a classification head on top of the sequence output,
    designed for sequence classification/regression tasks like GLUE tasks.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration. `num_labels` sets the size of the head output;
            `classifier_dropout` (falling back to `hidden_dropout_prob` when it is `None`)
            sets the dropout rate.
    """

    def __init__(self, config: RobertaConfig):
        super(RobertaForSequenceClassification, self).__init__(config)
        # The classification head reads the first token of the sequence output,
        # so the backbone is built without its pooling layer.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.classifier = RobertaClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`RobertaModel`.
            token_type_ids (Tensor, optional):
                See :class:`RobertaModel`.
            position_ids (Tensor, optional):
                See :class:`RobertaModel`.
            attention_mask (Tensor, optional):
                See :class:`RobertaModel`.
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`.
            labels (Tensor of shape `(batch_size,)`, optional):
                Labels for computing the sequence classification/regression loss.
                If `config.num_labels == 1` a regression loss is computed (Mean-Square loss).
                Otherwise, integer (`int32`/`int64`) labels use a classification loss (Cross-Entropy)
                and any other label dtype uses `BCEWithLogitsLoss`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Forwarded to :class:`RobertaModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Forwarded to :class:`RobertaModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`
                object. If `None`, `self.config.use_return_dict` is used. If false, the output is a
                tensor or a tuple of tensors.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if
            `return_dict` is true. Otherwise `(loss,) + (logits, ...)` when `labels` is given, the bare
            `logits` tensor when it is the only output, or the tuple `(logits, ...)`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForSequenceClassification, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForSequenceClassification.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                loss_fct = paddle.nn.MSELoss()
                # Flatten both operands: the head emits one value per sample along a trailing
                # axis of size 1, and subtracting `(batch_size,)` labels from that would
                # broadcast to a batch x batch matrix and yield a wrong loss.
                loss = loss_fct(logits.reshape((-1,)), labels.reshape((-1,)))
            elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
                loss_fct = paddle.nn.CrossEntropyLoss()
                loss = loss_fct(logits.reshape((-1, self.config.num_labels)), labels.reshape((-1,)))
            else:
                # Non-integer labels are treated as per-class targets.
                loss_fct = paddle.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else (output[0] if len(output) == 1 else output)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class RobertaForTokenClassification(RobertaPretrainedModel):
    r"""
    Roberta Model with a linear layer on top of the hidden-states output layer,
    designed for token classification tasks like NER tasks.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration. `num_labels` sets the classifier output size;
            `classifier_dropout` (falling back to `hidden_dropout_prob` when it is `None`)
            sets the dropout rate.
    """

    def __init__(self, config: RobertaConfig):
        super(RobertaForTokenClassification, self).__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)

        # Prefer the classifier-specific dropout rate when one is configured.
        dropout_rate = config.hidden_dropout_prob
        if config.classifier_dropout is not None:
            dropout_rate = config.classifier_dropout
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`RobertaModel`.
            token_type_ids (Tensor, optional):
                See :class:`RobertaModel`.
            position_ids (Tensor, optional):
                See :class:`RobertaModel`.
            attention_mask (Tensor, optional):
                See :class:`RobertaModel`.
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the token classification (Cross-Entropy) loss.
                Indices should be in `[0, ..., config.num_labels - 1]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Forwarded to :class:`RobertaModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Forwarded to :class:`RobertaModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`
                object. If `None`, `self.config.use_return_dict` is used. If false, the output is a
                tensor or a tuple of tensors.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if
            `return_dict` is true. Otherwise `(loss,) + (logits, ...)` when `labels` is given, the bare
            `logits` tensor when it is the only output, or the tuple `(logits, ...)`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForTokenClassification, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForTokenClassification.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)

        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Collapse every token position into one flat batch before scoring.
            flat_logits = logits.reshape((-1, self.config.num_labels))
            flat_labels = labels.reshape((-1,))
            criterion = paddle.nn.CrossEntropyLoss()
            loss = criterion(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: logits followed by the backbone outputs from index 2 onward.
        result = (logits,) + encoder_outputs[2:]
        if loss is not None:
            return (loss,) + result
        return result[0] if len(result) == 1 else result
class RobertaForMultipleChoice(RobertaPretrainedModel):
    """
    RoBerta Model with a linear layer on top of the pooled output,
    designed for multiple choice tasks like RocStories/SWAG tasks.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration. `classifier_dropout` (falling back to `hidden_dropout_prob`
            when it is `None`) sets the dropout rate applied to the pooled output.
    """

    def __init__(self, config: RobertaConfig):
        super(RobertaForMultipleChoice, self).__init__(config)
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        # One score per (sample, choice) pair; scores are regrouped per sample in `forward`.
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        The RobertaForMultipleChoice forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length].
            token_type_ids(Tensor, optional):
                See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length].
            position_ids(Tensor, optional):
                See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length].
            attention_mask (Tensor, optional):
                See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length].
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`. The number of choices is read from axis 1 and the last two
                axes are kept when the leading axes are flattened, i.e. the expected layout is
                [batch_size, num_choice, sequence_length, embedding_size].
            labels (Tensor of shape `(batch_size, )`, optional):
                Labels for computing the multiple choice classification loss.
                Indices should be in `[0, ..., num_choices-1]` where `num_choices` is the size of
                the second dimension of the input tensors. (See `input_ids` above)
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Forwarded to :class:`RobertaModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Forwarded to :class:`RobertaModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`
                object. If `None`, `self.config.use_return_dict` is used. If false, the output is a
                tensor or a tuple of tensors.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if
            `return_dict` is true. Otherwise `(loss,) + (reshaped_logits, ...)` when `labels` is given,
            the bare `reshaped_logits` tensor when it is the only output, or the tuple
            `(reshaped_logits, ...)`.

        Raises:
            ValueError: If neither `input_ids` nor `inputs_embeds` is provided.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForMultipleChoice, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForMultipleChoice.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                # Add the batch and choice axes: [1, 1, sequence_length].
                inputs = {k:paddle.to_tensor([[v]]) for (k, v) in inputs.items()}

                reshaped_logits = model(**inputs)
                print(reshaped_logits.shape)
                # [batch_size, num_choices]

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            num_choices = paddle.shape(input_ids)[1]
        elif inputs_embeds is not None:
            num_choices = paddle.shape(inputs_embeds)[1]
        else:
            # Without this guard `num_choices` would be unbound further down.
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Fold the choice axis into the batch axis so the backbone sees ordinary batched inputs.
        input_ids = input_ids.reshape((-1, input_ids.shape[-1])) if input_ids is not None else None
        inputs_embeds = (
            inputs_embeds.reshape((-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))
            if inputs_embeds is not None
            else None
        )
        position_ids = position_ids.reshape((-1, position_ids.shape[-1])) if position_ids is not None else None
        token_type_ids = token_type_ids.reshape((-1, token_type_ids.shape[-1])) if token_type_ids is not None else None
        attention_mask = attention_mask.reshape((-1, attention_mask.shape[-1])) if attention_mask is not None else None

        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)
        # Regroup the per-choice scores so each row holds the scores of one sample.
        reshaped_logits = logits.reshape((-1, num_choices))

        loss = None
        if labels is not None:
            loss_fct = paddle.nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else (output[0] if len(output) == 1 else output)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class RobertaForMaskedLM(RobertaPretrainedModel):
    """
    Roberta Model with a `masked language modeling` head on top.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration used to build the backbone and the LM head.
    """

    def __init__(self, config: RobertaConfig):
        super(RobertaForMaskedLM, self).__init__(config)

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        self.tie_weights()

    def get_output_embeddings(self):
        """Return the LM head's `decoder` layer."""
        head = self.lm_head
        return head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's `decoder` layer with `new_embeddings`."""
        head = self.lm_head
        head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`RobertaModel`.
            token_type_ids (Tensor, optional):
                See :class:`RobertaModel`.
            position_ids (Tensor, optional):
                See :class:`RobertaModel`.
            attention_mask (Tensor, optional):
                See :class:`RobertaModel`.
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the masked language modeling loss. Indices should be in
                `[-100, 0, ..., vocab_size]` (see `input_ids` docstring). Tokens with indices set to
                `-100` are ignored (masked), the loss is only computed for the tokens with labels
                in `[0, ..., vocab_size]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Forwarded to :class:`RobertaModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Forwarded to :class:`RobertaModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput`
                object. If `None`, `self.config.use_return_dict` is used. If false, the output is a
                tensor or a tuple of tensors.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if
            `return_dict` is true. Otherwise `(loss,) + (prediction_scores, ...)` when `labels` is
            given, the bare `prediction_scores` tensor when it is the only output, or the tuple
            `(prediction_scores, ...)`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForMaskedLM, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForMaskedLM.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}

                logits = model(**inputs)
                print(logits.shape)
                # [batch_size, sequence_length, vocab_size]

        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.lm_head(encoder_outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # Positions labelled -100 are meant to be skipped; this relies on the
            # loss layer's default `ignore_index`.
            criterion = paddle.nn.CrossEntropyLoss()
            flat_scores = prediction_scores.reshape((-1, prediction_scores.shape[-1]))
            flat_labels = labels.reshape((-1,))
            masked_lm_loss = criterion(flat_scores, flat_labels)

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: scores followed by the backbone outputs from index 2 onward.
        result = (prediction_scores,) + encoder_outputs[2:]
        if masked_lm_loss is not None:
            return (masked_lm_loss,) + result
        if len(result) == 1:
            return result[0]
        return result
class RobertaLMHead(nn.Layer):
    """Language-modeling head: ``dense -> gelu -> layer_norm -> decoder``.

    Args:
        config (RobertaConfig): Provides ``hidden_size``, ``vocab_size`` and
            ``layer_norm_eps``.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, epsilon=config.layer_norm_eps)
        self.decoder = TransposedLinear(hidden_size, config.vocab_size)
        # Expose the decoder bias on the head itself so pretrained weights
        # that store it under this name can be loaded.
        self.bias = self.decoder.bias

    def forward(self, features, **kwargs):
        """Map hidden states to vocabulary scores.

        Args:
            features (Tensor): Hidden states whose last axis has size
                ``config.hidden_size``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Tensor: Output of the decoder projection (with bias) back to the
            vocabulary size.
        """
        transformed = self.layer_norm(F.gelu(self.dense(features)))
        return self.decoder(transformed)
class RobertaForCausalLM(RobertaPretrainedModel):
    """
    Roberta Model with a `Causal language modeling` head on top.

    Args:
        config (:class:`RobertaConfig`):
            Model configuration used to build the backbone and the LM head.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__(config)

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        self.tie_weights()

    def get_output_embeddings(self):
        """Return the LM head's `decoder` layer."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's `decoder` layer with `new_embeddings`."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`RobertaModel`.
            token_type_ids (Tensor, optional):
                See :class:`RobertaModel`.
            position_ids (Tensor, optional):
                See :class:`RobertaModel`.
            attention_mask (Tensor, optional):
                See :class:`RobertaModel`.
            inputs_embeds (Tensor, optional):
                See :class:`RobertaModel`.
            past_key_values (tuple(tuple(Tensor)), optional):
                See :class:`RobertaModel`.
            use_cache (bool, optional):
                See :class:`RobertaModel`. Forced to `False` whenever `labels` is provided.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the left-to-right language modeling loss (next word prediction).
                Indices should be in `[-100, 0, ..., vocab_size]` (see `input_ids` docstring). Tokens
                with indices set to `-100` are ignored (masked), the loss is only computed for the
                tokens with labels in `[0, ..., vocab_size]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Forwarded to :class:`RobertaModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Forwarded to :class:`RobertaModel`.
            return_dict (bool, optional):
                Whether to return a
                :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions` object.
                If `None`, `self.config.use_return_dict` is used. If false, the output is a tensor or
                a tuple of tensors.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions`
            if `return_dict` is true. Otherwise `(loss,) + (prediction_scores, ...)` when `labels` is
            given, the bare `prediction_scores` tensor when it is the only output, or the tuple
            `(prediction_scores, ...)`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import RobertaForCausalLM, RobertaTokenizer

                tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                model = RobertaForCausalLM.from_pretrained('roberta-wwm-ext')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}

                logits = model(**inputs)
                print(logits.shape)
                # [batch_size, sequence_length, vocab_size]

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # A loss is being computed, so key/value caching is switched off.
            use_cache = False

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :]
            labels = labels[:, 1:]
            loss_fct = paddle.nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shifted_prediction_scores.reshape((-1, prediction_scores.shape[-1])), labels.reshape((-1,))
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else (output[0] if len(output) == 1 else output)

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        """Build the keyword arguments for one generation step.

        Args:
            input_ids (Tensor): Token ids generated so far.
            past (optional): Cached key/value states from earlier steps. When given,
                only the last token of `input_ids` is fed to the model.
            attention_mask (Tensor, optional): Attention mask. When `None`, a mask of
                ones matching `input_ids` in shape and dtype is created.
            **model_kwargs: Accepted for call compatibility and ignored.

        Returns:
            dict: `input_ids`, `attention_mask` and `past_key_values` entries.
        """
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            # `Tensor.new_ones` is a PyTorch method; `paddle.ones_like` is the Paddle call
            # that yields the same all-ones tensor (same shape and dtype as `input_ids`).
            attention_mask = paddle.ones_like(input_ids)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached state along axis 0 according to `beam_idx`.

        Args:
            past: Iterable of per-layer iterables of cached tensors.
            beam_idx (Tensor): Indices selecting which rows of axis 0 to keep, and in what order.

        Returns:
            tuple: Same nesting as `past`, with each tensor re-indexed.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past
        )