# Source code for paddlenlp.transformers.ernie.modeling

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import Tensor

# TODO(guosheng): update this workaround import for in_declarative_mode
from paddle.nn.layer.layers import in_declarative_mode

from ...layers import Linear as TransposedLinear
from ...utils.env import CONFIG_NAME
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from .configuration import (
    ERNIE_PRETRAINED_INIT_CONFIGURATION,
    ERNIE_PRETRAINED_RESOURCE_FILES_MAP,
    ErnieConfig,
)

# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "ErnieModel",
    "ErniePretrainedModel",
    "ErnieForSequenceClassification",
    "ErnieForTokenClassification",
    "ErnieForQuestionAnswering",
    "ErnieForPretraining",
    "ErniePretrainingCriterion",
    "ErnieForMaskedLM",
    "ErnieForMultipleChoice",
    "UIE",
    "UTC",
]


class ErnieEmbeddings(nn.Layer):
    r"""
    Include embeddings from word, position and token_type embeddings.

    The word and position embeddings are always summed. The token-type
    embeddings are added when ``config.type_vocab_size > 0``, and the
    task-type embeddings when ``config.use_task_id`` is set. The sum is then
    passed through LayerNorm and dropout.

    Args:
        config (ErnieConfig):
            Model configuration.
        weight_attr (paddle.ParamAttr):
            Parameter attribute shared by every embedding table created here.
    """

    def __init__(self, config: ErnieConfig, weight_attr):
        super(ErnieEmbeddings, self).__init__()

        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, weight_attr=weight_attr
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, weight_attr=weight_attr
        )
        self.type_vocab_size = config.type_vocab_size
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(
                config.type_vocab_size, config.hidden_size, weight_attr=weight_attr
            )
        self.use_task_id = config.use_task_id
        self.task_id = config.task_id
        if self.use_task_id:
            self.task_type_embeddings = nn.Embedding(
                config.task_type_vocab_size, config.hidden_size, weight_attr=weight_attr
            )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        task_type_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values_length: int = 0,
    ):
        r"""
        Compute the summed and normalized input embeddings.

        Args:
            input_ids (Tensor, optional):
                Token ids looked up in ``word_embeddings``. When given, any
                ``inputs_embeds`` argument is ignored.
            token_type_ids (Tensor, optional):
                Segment ids, only used when ``type_vocab_size > 0``.
                Defaults to all zeros.
            position_ids (Tensor, optional):
                Position ids. Defaults to ``0, 1, ..., n - 1`` along axis 1,
                shifted by ``past_key_values_length``.
            task_type_ids (Tensor, optional):
                Task ids, only used when ``use_task_id`` is set. Defaults to
                ``config.task_id`` for every token.
            inputs_embeds (Tensor, optional):
                Precomputed word embeddings, used when ``input_ids`` is None.
            past_key_values_length (int, optional):
                Number of cached positions preceding the current input.
                Defaults to 0.

        Returns:
            Tensor: The summed embeddings after LayerNorm and dropout.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if input_ids is not None:
            inputs_embeds = self.word_embeddings(input_ids)
        elif inputs_embeds is None:
            # Fail with a clear message instead of an AttributeError on `None.shape` below.
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Declarative (to_static) mode reads `.shape`; dynamic mode uses the `paddle.shape` op.
        input_shape = inputs_embeds.shape[:-1] if in_declarative_mode() else paddle.shape(inputs_embeds)[:-1]

        if position_ids is None:
            # maybe need use shape op to unify static graph and dynamic graph
            ones = paddle.ones(input_shape, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=1)
            position_ids = seq_length - ones

            if past_key_values_length > 0:
                # Continue numbering after the cached positions.
                position_ids = position_ids + past_key_values_length

            position_ids.stop_gradient = True

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings

        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = paddle.zeros(input_shape, dtype="int64")
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings

        if self.use_task_id:
            if task_type_ids is None:
                task_type_ids = paddle.ones(input_shape, dtype="int64") * self.task_id
            task_type_embeddings = self.task_type_embeddings(task_type_ids)
            embeddings = embeddings + task_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class ErniePooler(nn.Layer):
    """
    Reduce a batch of hidden-state sequences to one vector per sequence.

    The hidden state at position 0 of each sequence is fed through a
    ``hidden_size -> hidden_size`` dense layer followed by ``tanh``.

    Args:
        config (ErnieConfig):
            Model configuration; only ``hidden_size`` is read.
        weight_attr (paddle.ParamAttr):
            Parameter attribute of the dense projection.
    """

    def __init__(self, config: ErnieConfig, weight_attr):
        super(ErniePooler, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width, weight_attr=weight_attr)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Pool by keeping only the first token (index 0 along axis 1).
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))


class ErniePretrainedModel(PretrainedModel):
    r"""
    An abstract class for pretrained ERNIE models. It provides ERNIE related
    `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
    `pretrained_resource_files_map`, `base_model_prefix` for downloading and
    loading pretrained models.

    Refer to :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
    """

    model_config_file = CONFIG_NAME
    config_class = ErnieConfig
    resource_files_names = {"model_state": "model_state.pdparams"}
    base_model_prefix = "ernie"

    pretrained_init_configuration = ERNIE_PRETRAINED_INIT_CONFIGURATION
    pretrained_resource_files_map = ERNIE_PRETRAINED_RESOURCE_FILES_MAP

    def _init_weights(self, layer):
        """
        Initialization hook applied to each sub-layer.

        * ``nn.Linear`` / ``nn.Embedding``: when the weight is a ``paddle.Tensor``,
          it is overwritten with samples from a normal distribution
          (mean 0, std ``config.initializer_range``).
        * ``nn.LayerNorm``: its epsilon is set to ``1e-12``.
        * Any other layer is left unchanged.
        """
        if not isinstance(layer, (nn.Linear, nn.Embedding)):
            if isinstance(layer, nn.LayerNorm):
                layer._epsilon = 1e-12
            return

        weight = layer.weight
        # only support dygraph, use truncated_normal and make it inplace
        # and configurable later
        if not isinstance(weight, paddle.Tensor):
            return
        fresh_values = paddle.tensor.normal(
            mean=0.0,
            std=self.config.initializer_range,
            shape=weight.shape,
        )
        weight.set_value(fresh_values)
@register_base_model
class ErnieModel(ErniePretrainedModel):
    r"""
    The bare ERNIE Model transformer outputting raw hidden-states.

    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
    Refer to the superclass documentation for the generic methods.

    This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
    /docs/zh/api/paddle/nn/Layer_cn.html>`__ subclass. Use it as a regular Paddle Layer
    and refer to the Paddle documentation for all matter related to general usage and behavior.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieModel
    """

    def __init__(self, config: ErnieConfig):
        super(ErnieModel, self).__init__(config)
        self.pad_token_id = config.pad_token_id
        self.initializer_range = config.initializer_range
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.TruncatedNormal(mean=0.0, std=self.initializer_range)
        )
        self.embeddings = ErnieEmbeddings(config=config, weight_attr=weight_attr)
        encoder_layer = nn.TransformerEncoderLayer(
            config.hidden_size,
            config.num_attention_heads,
            config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation=config.hidden_act,
            attn_dropout=config.attention_probs_dropout_prob,
            act_dropout=0,
            weight_attr=weight_attr,
            normalize_before=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
        self.pooler = ErniePooler(config, weight_attr)

    def get_input_embeddings(self):
        """Return the word-embedding table of the embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table of the embedding layer with `value`."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        task_type_ids: Optional[Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
        inputs_embeds: Optional[Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Exactly one of `input_ids` and `inputs_embeds` must be given.
            token_type_ids (Tensor, optional):
                Segment token indices to indicate different portions of the inputs.
                Selected in the range ``[0, type_vocab_size - 1]``.
                If `type_vocab_size` is 2, which means the inputs have two portions.
                Indices can either be 0 or 1:

                - 0 corresponds to a *sentence A* token,
                - 1 corresponds to a *sentence B* token.

                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to `None`, in which case every token is treated as segment 0
                (only relevant when `type_vocab_size > 0`).
            position_ids (Tensor, optional):
                Indices of positions of each input sequence tokens in the position embeddings.
                Selected in the range ``[0, max_position_embeddings - 1]``.
                Shape as `[batch_size, num_tokens]` and dtype as int64.
                Defaults to `None`, in which case positions ``0 .. sequence_length - 1``
                are used, offset by the length of `past_key_values`.
            attention_mask (Tensor, optional):
                Mask used in multi-head attention to avoid performing attention on to some unwanted positions,
                usually the paddings or the subsequent positions.
                Its data type can be int, float and bool.
                When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
                When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
                When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length],
                [batch_size, num_attention_heads, sequence_length, sequence_length].
                A 2-D mask is converted into an additive mask (`0` kept, `-1e4` masked).
                We use whole-word-mask in ERNIE, so the whole word will have the same value. For example, "使用" as a word,
                "使" and "用" will have the same value.
                Defaults to `None`. When `input_ids` is given, the default masks every
                token equal to `pad_token_id`; when only `inputs_embeds` is given, no position is masked.
            task_type_ids (Tensor, optional):
                Task type indices, only used when `config.use_task_id` is True.
                Defaults to `None`, in which case `config.task_id` is used for every token.
            inputs_embeds (Tensor, optional):
                If you want to control how to convert `inputs_ids` indices into associated vectors, you can
                pass an embedded representation directly instead of passing `inputs_ids`.
            past_key_values (tuple(tuple(Tensor)), optional):
                The length of tuple equals to the number of layers, and each inner
                tuple has 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`)
                which contains precomputed key and value hidden states of the attention blocks.
                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                don't have their past key value states given to this model) of shape `(batch_size, 1)`
                instead of all `input_ids` of shape `(batch_size, sequence_length)`.
            use_cache (`bool`, optional):
                If set to `True`, `past_key_values` key value states are returned.
                Defaults to `None` (treated as `False`).
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
                will be a tuple of tensors. Defaults to `None`, which falls back to `config.use_return_dict`.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if
            `return_dict=True`. Otherwise it returns a tuple of tensors corresponding
            to ordered and not None (depending on the input arguments) fields of
            :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are given.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieModel, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
                model = ErnieModel.from_pretrained('ernie-1.0')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                sequence_output, pooled_output = model(**inputs)

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        if input_ids is None and inputs_embeds is None:
            # Neither the default attention mask nor the embeddings can be built without an input.
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # init the default bool value
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else False
        use_cache = use_cache if use_cache is not None else False

        past_key_values_length = 0
        if past_key_values is not None:
            # Axis 2 of a cached key is the cached sequence length (see the docstring layout).
            past_key_values_length = past_key_values[0][0].shape[2]

        if attention_mask is None:
            if input_ids is not None:
                # Additive mask of shape [batch, 1, 1, seq]: -1e4 on padding tokens, 0 elsewhere.
                attention_mask = paddle.unsqueeze(
                    (input_ids == self.pad_token_id).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2]
                )
            else:
                # Padding cannot be recovered from embeddings alone, so attend to every position.
                input_shape = inputs_embeds.shape[:-1] if in_declarative_mode() else paddle.shape(inputs_embeds)[:-1]
                attention_mask = paddle.unsqueeze(
                    paddle.zeros(input_shape, dtype=self.pooler.dense.weight.dtype), axis=[1, 2]
                )
            if past_key_values is not None:
                # Cached positions are always attendable.
                batch_size = past_key_values[0][0].shape[0]
                past_mask = paddle.zeros([batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype)
                attention_mask = paddle.concat([past_mask, attention_mask], axis=-1)
        # For 2D attention_mask from tokenizer
        elif attention_mask.ndim == 2:
            attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
            attention_mask = (1.0 - attention_mask) * -1e4
        attention_mask.stop_gradient = True

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        self.encoder._use_cache = use_cache  # To be consistent with HF
        encoder_outputs = self.encoder(
            embedding_output,
            src_mask=attention_mask,
            cache=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if isinstance(encoder_outputs, type(embedding_output)):
            # NOTE(review): the encoder appears to return a bare Tensor when no extra
            # outputs are requested; a plain tuple is returned here regardless of return_dict.
            sequence_output = encoder_outputs
            pooled_output = self.pooler(sequence_output)
            return (sequence_output, pooled_output)
        else:
            sequence_output = encoder_outputs[0]
            pooled_output = self.pooler(sequence_output)
            if not return_dict:
                return (sequence_output, pooled_output) + encoder_outputs[1:]
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
class ErnieForSequenceClassification(ErniePretrainedModel):
    r"""
    Ernie Model with a linear layer on top of the output layer,
    designed for sequence classification/regression tasks like GLUE tasks.

    Dropout followed by a ``hidden_size -> num_labels`` linear layer is
    applied to the pooled output of :class:`ErnieModel`.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForSequenceClassification.
    """

    def __init__(self, config):
        super(ErnieForSequenceClassification, self).__init__(config)
        self.ernie = ErnieModel(config)
        self.num_labels = config.num_labels
        # Prefer the classifier-specific dropout rate when one is configured.
        if config.classifier_dropout is None:
            dropout_prob = config.hidden_dropout_prob
        else:
            dropout_prob = config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            inputs_embeds(Tensor, optional):
                See :class:`ErnieModel`.
            labels (Tensor of shape `(batch_size,)`, optional):
                Labels for computing the sequence classification/regression loss.
                Indices should be in `[0, ..., num_labels - 1]`. If `num_labels == 1`
                a regression loss is computed (Mean-Square loss), If `num_labels > 1`
                a classification loss is computed (Cross-Entropy).
                When `config.problem_type` is unset it is inferred on the first call
                and stored on the config.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If
                `False`, the output will be a tuple of tensors. If `None`, `config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`.
            Otherwise it returns a tuple of tensors corresponding to ordered and
            not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
                model = ErnieForSequenceClassification.from_pretrained('ernie-1.0')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled (first-token) representation.
        logits = self.classifier(self.dropout(outputs[1]))
        loss = self._sequence_loss(logits, labels) if labels is not None else None

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_output = (logits,) + outputs[2:]
        if loss is not None:
            return (loss,) + tuple_output
        return tuple_output[0] if len(tuple_output) == 1 else tuple_output

    def _sequence_loss(self, logits, labels):
        """Compute the loss for `labels`, inferring and caching `config.problem_type` when unset.

        Returns None when `config.problem_type` holds an unrecognized value.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        problem_type = self.config.problem_type
        if problem_type == "regression":
            mse = paddle.nn.MSELoss()
            if self.num_labels == 1:
                return mse(logits.squeeze(), labels.squeeze())
            return mse(logits, labels)
        if problem_type == "single_label_classification":
            cross_entropy = paddle.nn.CrossEntropyLoss()
            return cross_entropy(logits.reshape((-1, self.num_labels)), labels.reshape((-1,)))
        if problem_type == "multi_label_classification":
            bce = paddle.nn.BCEWithLogitsLoss()
            return bce(logits, labels)
        return None
class ErnieForQuestionAnswering(ErniePretrainedModel):
    """
    Ernie Model with a linear layer on top of the hidden-states
    output to compute `span_start_logits` and `span_end_logits`,
    designed for question-answering tasks like SQuAD.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForQuestionAnswering.
    """

    def __init__(self, config):
        super(ErnieForQuestionAnswering, self).__init__(config)
        self.ernie = ErnieModel(config)
        # One output unit for the span start and one for the span end.
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        start_positions: Optional[Tensor] = None,
        end_positions: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            inputs_embeds(Tensor, optional):
                See :class:`ErnieModel`.
            start_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the start of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            end_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the end of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If
                `False`, the output will be a tuple of tensors. If `None`, `config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` if `return_dict=True`.
            Otherwise it returns a tuple of tensors corresponding to ordered and
            not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieForQuestionAnswering, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
                model = ErnieForQuestionAnswering.from_pretrained('ernie-1.0')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output)
        # Move the 2-way output axis to the front so it can be unstacked into start/end logits.
        logits = paddle.transpose(logits, perm=[2, 0, 1])
        start_logits, end_logits = paddle.unstack(x=logits, axis=0)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Labels may arrive with a trailing singleton axis (e.g. [batch_size, 1]); flatten each one
            # independently. Each tensor must test its own rank.
            if start_positions.ndim > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.ndim > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = paddle.shape(start_logits)[1]
            start_positions = start_positions.clip(0, ignored_index)
            end_positions = end_positions.clip(0, ignored_index)

            loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ErnieForTokenClassification(ErniePretrainedModel):
    r"""
    ERNIE Model with a linear layer on top of the hidden-states output layer,
    designed for token classification tasks like NER tasks.

    Dropout followed by a ``hidden_size -> num_labels`` linear layer is
    applied to every position of the sequence output of :class:`ErnieModel`.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForTokenClassification.
    """

    def __init__(self, config: ErnieConfig):
        super(ErnieForTokenClassification, self).__init__(config)
        self.ernie = ErnieModel(config)
        self.num_labels = config.num_labels
        # Prefer the classifier-specific dropout rate when one is configured.
        if config.classifier_dropout is None:
            dropout_prob = config.hidden_dropout_prob
        else:
            dropout_prob = config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            inputs_embeds(Tensor, optional):
                See :class:`ErnieModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the token classification loss. Indices should be in `[0, ..., num_labels - 1]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
                `False`, the output will be a tuple of tensors. If `None`, `config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`.
            Otherwise it returns a tuple of tensors corresponding to ordered and
            not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
                model = ErnieForTokenClassification.from_pretrained('ernie-1.0')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.dropout(outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Flatten batch and sequence axes so every token is one classification example.
            flat_logits = logits.reshape((-1, self.num_labels))
            flat_labels = labels.reshape((-1,))
            loss = paddle.nn.CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_output = (logits,) + outputs[2:]
        if loss is None:
            return tuple_output[0] if len(tuple_output) == 1 else tuple_output
        return (loss,) + tuple_output
class ErnieLMPredictionHead(nn.Layer):
    r"""
    Ernie Model with a `language modeling` head on top.

    Hidden states are transformed by ``dense -> activation -> LayerNorm`` and
    then projected onto the vocabulary by ``decoder``.

    Args:
        config (ErnieConfig):
            Reads ``hidden_size``, ``hidden_act`` and ``vocab_size``.
        weight_attr (paddle.ParamAttr, optional):
            Parameter attribute of the transform layer. Defaults to `None`.
    """

    def __init__(
        self,
        config: ErnieConfig,
        weight_attr=None,
    ):
        super(ErnieLMPredictionHead, self).__init__()
        hidden_size = config.hidden_size
        self.transform = nn.Linear(hidden_size, hidden_size, weight_attr=weight_attr)
        self.activation = getattr(nn.functional, config.hidden_act)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.decoder = TransposedLinear(hidden_size, config.vocab_size)
        # link bias to load pretrained weights
        self.decoder_bias = self.decoder.bias

    def forward(self, hidden_states, masked_positions=None):
        if masked_positions is not None:
            # gather masked tokens might be more quick: flatten to 2-D rows, then keep only
            # the rows listed in `masked_positions`.
            flat_states = paddle.reshape(hidden_states, [-1, hidden_states.shape[-1]])
            hidden_states = paddle.tensor.gather(flat_states, masked_positions)
        normalized = self.layer_norm(self.activation(self.transform(hidden_states)))
        return self.decoder(normalized)


class ErniePretrainingHeads(nn.Layer):
    """
    Masked-LM prediction head plus a 2-way sentence-relationship classifier.

    Args:
        config (ErnieConfig):
            Model configuration passed to :class:`ErnieLMPredictionHead`.
        weight_attr (paddle.ParamAttr, optional):
            Parameter attribute shared by both heads. Defaults to `None`.
    """

    def __init__(
        self,
        config: ErnieConfig,
        weight_attr=None,
    ):
        super(ErniePretrainingHeads, self).__init__()
        self.predictions = ErnieLMPredictionHead(config, weight_attr)
        self.seq_relationship = nn.Linear(config.hidden_size, 2, weight_attr=weight_attr)

    def forward(self, sequence_output, pooled_output, masked_positions=None):
        """Return ``(prediction_scores, seq_relationship_score)``."""
        return (
            self.predictions(sequence_output, masked_positions),
            self.seq_relationship(pooled_output),
        )


@dataclass
class ErnieForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`ErnieForPretraining`.

    Args:
        loss (*optional*, returned when `labels` is provided, `paddle.Tensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`paddle.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `paddle.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[paddle.Tensor] = None
    prediction_logits: paddle.Tensor = None
    seq_relationship_logits: paddle.Tensor = None
    hidden_states: Optional[Tuple[paddle.Tensor]] = None
    attentions: Optional[Tuple[paddle.Tensor]] = None
class ErnieForPretraining(ErniePretrainedModel):
    r"""
    Ernie Model with a `masked language modeling` head and a `sentence order prediction` head
    on top.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForPretraining.
    """

    def __init__(self, config: ErnieConfig):
        super(ErnieForPretraining, self).__init__(config)
        self.ernie = ErnieModel(config)
        head_initializer = nn.initializer.TruncatedNormal(mean=0.0, std=self.ernie.initializer_range)
        self.cls = ErniePretrainingHeads(
            config=config,
            weight_attr=paddle.ParamAttr(initializer=head_initializer),
        )

        self.tie_weights()

    def get_output_embeddings(self):
        """Return the vocabulary projection layer of the masked-LM head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        masked_positions: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        next_sentence_label: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            masked_positions (Tensor, optional):
                Indices of the rows to predict, passed to the LM head, which gathers them from the
                sequence output flattened to 2-D. Defaults to `None`, meaning every position is predicted.
            inputs_embeds(Tensor, optional):
                See :class:`ErnieModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
                vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
                loss is only computed for the tokens with labels in `[0, ..., vocab_size]`.
            next_sentence_label (Tensor of shape `(batch_size,)`, optional):
                Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
                pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.

                The loss is only computed when both `labels` and `next_sentence_label` are given.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.ernie.modeling.ErnieForPreTrainingOutput` object. If
                `False`, the output will be a tuple of tensors. If `None`, `config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.ernie.modeling.ErnieForPreTrainingOutput` if `return_dict=True`.
            Otherwise it returns a tuple of tensors corresponding to ordered and
            not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.ernie.modeling.ErnieForPreTrainingOutput`.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        with paddle.static.amp.fp16_guard():
            outputs = self.ernie(
                input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output, pooled_output = outputs[:2]
            prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output, masked_positions)

            total_loss = None
            if labels is not None and next_sentence_label is not None:
                criterion = paddle.nn.CrossEntropyLoss()
                vocab_dim = paddle.shape(prediction_scores)[-1]
                masked_lm_loss = criterion(prediction_scores.reshape((-1, vocab_dim)), labels.reshape((-1,)))
                next_sentence_loss = criterion(
                    seq_relationship_score.reshape((-1, 2)), next_sentence_label.reshape((-1,))
                )
                total_loss = masked_lm_loss + next_sentence_loss

            if return_dict:
                return ErnieForPreTrainingOutput(
                    loss=total_loss,
                    prediction_logits=prediction_scores,
                    seq_relationship_logits=seq_relationship_score,
                    hidden_states=outputs.hidden_states,
                    attentions=outputs.attentions,
                )
            tuple_output = (prediction_scores, seq_relationship_score) + outputs[2:]
            if total_loss is None:
                return tuple_output
            return (total_loss,) + tuple_output
class ErniePretrainingCriterion(paddle.nn.Layer):
    r"""
    The loss output of Ernie Model during the pretraining:
    a `masked language modeling` head and a `next sentence prediction (classification)` head.

    Args:
        with_nsp_loss (bool, optional):
            Whether to also compute the next-sentence-prediction loss. Defaults to `True`.
    """

    def __init__(self, with_nsp_loss=True):
        super(ErniePretrainingCriterion, self).__init__()
        self.with_nsp_loss = with_nsp_loss

    def forward(self, prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels=None):
        """
        Compute the pretraining losses.

        Args:
            prediction_scores(Tensor):
                The scores of masked token prediction. Its data type should be float32.
                If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size].
                Otherwise, its shape is [batch_size, mask_token_num, vocab_size]
            seq_relationship_score(Tensor):
                The scores of next sentence prediction. Its data type should be float32 and
                its shape is [batch_size, 2]
            masked_lm_labels(Tensor):
                The labels of the masked language modeling, its dimensionality is equal to `prediction_scores`.
                Its data type should be int64. Positions labelled `-1` are ignored.
                If `masked_positions` is None, its shape is [batch_size, sequence_length, 1].
                Otherwise, its shape is [batch_size, mask_token_num, 1]
            next_sentence_labels(Tensor, optional):
                The labels of the next sentence prediction task, the dimensionality of `next_sentence_labels`
                is equal to `seq_relationship_score`. Its data type should be int64 and
                its shape is [batch_size, 1]. Required when `with_nsp_loss` is True.

        Returns:
            Tensor or tuple(Tensor, Tensor):
                If `with_nsp_loss` is False, the mean of the per-token masked LM loss.
                Otherwise a tuple `(masked_lm_loss, next_sentence_loss)` holding the mean of each loss;
                the two are not summed here.

        Raises:
            ValueError: If `with_nsp_loss` is True and `next_sentence_labels` is None.
        """
        with paddle.static.amp.fp16_guard():
            # reduction="none" keeps per-token losses. NOTE(review): ignored (-1) positions are presumably
            # zero here yet still counted by paddle.mean below — confirm this averaging is intended.
            masked_lm_loss = F.cross_entropy(prediction_scores, masked_lm_labels, ignore_index=-1, reduction="none")

            if not self.with_nsp_loss:
                return paddle.mean(masked_lm_loss)

            if next_sentence_labels is None:
                raise ValueError("`next_sentence_labels` is required when `with_nsp_loss` is True.")

            next_sentence_loss = F.cross_entropy(seq_relationship_score, next_sentence_labels, reduction="none")
            return paddle.mean(masked_lm_loss), paddle.mean(next_sentence_loss)
class ErnieOnlyMLMHead(nn.Layer):
    """Masked-LM-only head: a thin wrapper around a single :class:`ErnieLMPredictionHead`."""

    def __init__(self, config: ErnieConfig):
        super(ErnieOnlyMLMHead, self).__init__()
        self.predictions = ErnieLMPredictionHead(config=config)

    def forward(self, sequence_output, masked_positions=None):
        """Score `sequence_output` with the prediction head, forwarding `masked_positions` unchanged."""
        return self.predictions(sequence_output, masked_positions)
class ErnieForMaskedLM(ErniePretrainedModel):
    """
    Ernie Model with a `masked language modeling` head on top.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForMaskedLM.
    """

    def __init__(self, config: ErnieConfig):
        super().__init__(config)
        self.ernie = ErnieModel(config)
        self.cls = ErnieOnlyMLMHead(config=config)
        # Called after `self.cls` exists; presumably ties get_output_embeddings() to the input embeddings.
        self.tie_weights()

    def get_output_embeddings(self):
        """Return the output projection (`decoder`) of the masked-LM prediction head."""
        lm_head = self.cls.predictions
        return lm_head.decoder

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        masked_positions: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run the Ernie backbone and score tokens with the masked-LM head.

        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            masked_positions (Tensor, optional):
                Positions of the masked tokens, forwarded to the masked-LM head. Defaults to `None`.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the masked language modeling loss. Indices should be in
                `[-100, 0, ..., vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100`
                are ignored (masked); the loss is only computed for the tokens with labels in
                `[0, ..., vocab_size - 1]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Passed unchanged to :class:`ErnieModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers. Passed unchanged to
                :class:`ErnieModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` object.
                If `False`, the output will be a tuple of tensors. Defaults to `None`, in which case
                `self.config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if `return_dict` is true.
            Otherwise the prediction scores alone when nothing else is produced, or a tuple of the
            (optional) loss, the prediction scores and any extra backbone outputs.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieForMaskedLM, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
                model = ErnieForMaskedLM.from_pretrained('ernie-1.0')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}

                logits = model(**inputs)
                print(logits.shape)
                # [1, 17, 18000]
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.cls(encoder_outputs[0], masked_positions=masked_positions)

        masked_lm_loss = None
        if labels is not None:
            # CrossEntropyLoss's default ignore_index (-100) skips padded label positions.
            criterion = paddle.nn.CrossEntropyLoss()
            vocab_dim = paddle.shape(prediction_scores)[-1]
            flat_scores = prediction_scores.reshape((-1, vocab_dim))
            flat_labels = labels.reshape((-1,))
            masked_lm_loss = criterion(flat_scores, flat_labels)

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tuple_output = (prediction_scores,) + encoder_outputs[2:]
        if masked_lm_loss is not None:
            return (masked_lm_loss,) + tuple_output
        if len(tuple_output) == 1:
            return tuple_output[0]
        return tuple_output
class ErnieForMultipleChoice(ErniePretrainedModel):
    """
    Ernie Model with a linear layer on top of the hidden-states output layer,
    designed for multiple choice tasks like RocStories/SWAG tasks.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct ErnieForMultipleChoice
    """

    def __init__(self, config: ErnieConfig):
        super(ErnieForMultipleChoice, self).__init__(config)
        self.ernie = ErnieModel(config)
        # Fallback number of choices, used only when it cannot be read from the input shape.
        self.num_choices = config.num_choices if config.num_choices is not None else 2
        self.dropout = nn.Dropout(
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.classifier = nn.Linear(config.hidden_size, 1)

    def _infer_num_choices(self, input_ids, inputs_embeds):
        """Return the number of choices per example, read from the un-flattened input when possible."""
        num_choices = None
        if input_ids is not None and len(input_ids.shape) == 3:
            # input_ids: [bs, num_choice, seq_l]
            num_choices = input_ids.shape[1]
        elif inputs_embeds is not None and len(inputs_embeds.shape) == 4:
            # inputs_embeds: [bs, num_choice, seq_l, hidden]
            num_choices = inputs_embeds.shape[1]
        # Already-flattened inputs, or an unknown dimension (static graphs may report -1), use the config value.
        if num_choices is None or num_choices < 0:
            num_choices = self.num_choices
        return num_choices

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        The ErnieForMultipleChoice forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            token_type_ids(Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            position_ids(Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            attention_mask (Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            inputs_embeds(Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length, hidden_size].
            labels (Tensor of shape `(batch_size, )`, optional):
                Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
                num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
                `input_ids` above)
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Passed unchanged to :class:`ErnieModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers. Passed unchanged to
                :class:`ErnieModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` object.
                If `False`, the output will be a tuple of tensors. Defaults to `None`, in which case
                `self.config.use_return_dict` is used.

        Returns:
            An instance of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if
            `return_dict` is true. Otherwise it returns a tuple of tensors corresponding to ordered and not None
            (depending on the input arguments) fields of
            :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`.

        Note:
            The number of choices used to regroup the logits is taken from the second dimension of the 3-D
            `input_ids` (or 4-D `inputs_embeds`); `config.num_choices` is only a fallback for already-flattened
            inputs.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Must be read before the inputs are flattened below.
        num_choices = self._infer_num_choices(input_ids, inputs_embeds)

        # input_ids: [bs, num_choice, seq_l] -> flat_input_ids: [bs*num_choice, seq_l]
        if input_ids is not None:
            input_ids = input_ids.reshape(shape=(-1, input_ids.shape[-1]))
        if position_ids is not None:
            position_ids = position_ids.reshape(shape=(-1, position_ids.shape[-1]))
        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape(shape=(-1, token_type_ids.shape[-1]))
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(shape=(-1, attention_mask.shape[-1]))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.reshape(shape=(-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))

        outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)  # logits: (bs*num_choice, 1)
        reshaped_logits = logits.reshape(shape=(-1, num_choices))  # logits: (bs, num_choice)

        loss = None
        if labels is not None:
            loss_fct = paddle.nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else (output[0] if len(output) == 1 else output)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class UIE(ErniePretrainedModel):
    """
    Ernie Model with two linear layer on top of the hidden-states output to compute
    `start_prob` and `end_prob`, designed for Universal Information Extraction.

    Args:
        config (:class:`ErnieConfig`):
            An instance of ErnieConfig used to construct UIE
    """

    def __init__(self, config: ErnieConfig):
        super(UIE, self).__init__(config)
        self.ernie = ErnieModel(config)
        self.linear_start = paddle.nn.Linear(config.hidden_size, 1)
        self.linear_end = paddle.nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Compute per-token start and end probabilities for span extraction.

        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieModel`.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieModel`.
            return_dict (bool, optional):
                Forwarded to :class:`ErnieModel`. Defaults to `None`, in which case `self.config.use_return_dict`
                is used. It only affects the backbone call; this method always returns a tuple.

        Returns:
            tuple(Tensor, Tensor): `(start_prob, end_prob)`, the sigmoid of a per-token linear score over the
            backbone's sequence output with the trailing size-1 dimension squeezed away
            (presumably shaped `[batch_size, sequence_length]`).

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import UIE, ErnieTokenizer

                tokenizer = ErnieTokenizer.from_pretrained('uie-base')
                model = UIE.from_pretrained('uie-base')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                start_prob, end_prob = model(**inputs)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.ernie(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=return_dict,
        )
        # Index rather than tuple-unpack: works whether the backbone returns a tuple (of any length)
        # or a model-output object, matching how the other Ernie heads read `outputs[0]`.
        sequence_output = outputs[0]

        start_logits = self.linear_start(sequence_output)
        start_logits = paddle.squeeze(start_logits, -1)
        start_prob = self.sigmoid(start_logits)

        end_logits = self.linear_end(sequence_output)
        end_logits = paddle.squeeze(end_logits, -1)
        end_prob = self.sigmoid(end_logits)
        return start_prob, end_prob
class UTC(ErniePretrainedModel):
    """
    Ernie Model with two linear layer on the top of the hidden-states output to compute
    probability of candidate labels, designed for Unified Tag Classification.
    """

    def __init__(self, config: ErnieConfig):
        super(UTC, self).__init__(config)
        self.ernie = ErnieModel(config)
        # Width of the query/key projections; logits are scaled by sqrt(predict_size) in forward().
        self.predict_size = 64
        self.linear_q = paddle.nn.Linear(config.hidden_size, self.predict_size)
        self.linear_k = paddle.nn.Linear(config.hidden_size, self.predict_size)

    def forward(
        self,
        input_ids,
        token_type_ids,
        position_ids,
        attention_mask,
        omask_positions,
        cls_positions,
        inputs_embeds: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Score every candidate option against the second [CLS] token.

        Args:
            input_ids (Tensor):
                See :class:`ErnieModel`.
            token_type_ids (Tensor):
                See :class:`ErnieModel`.
            position_ids (Tensor):
                See :class:`ErnieModel`.
            attention_mask (Tensor):
                See :class:`ErnieModel`.
            omask_positions (Tensor of shape `(batch_size, max_option)`):
                Masked positions of [O-MASK] tokens padded with 0. Entries equal to 0 are treated as padding
                and their logits are lowered by `1e12`.
            cls_positions (Tensor of shape `(batch_size)`):
                Masked positions of the second [CLS] token.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieModel`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers. Passed unchanged to :class:`ErnieModel`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers. Passed unchanged to
                :class:`ErnieModel`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`.
                Defaults to `None`, in which case `self.config.use_return_dict` is used.

        Returns:
            A :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if `return_dict` is true,
            whose `logits` are the option scores of shape `(batch_size, max_option)` and whose `loss` is always
            `None` (this head takes no labels). Otherwise the option logits alone, or a tuple of the option
            logits followed by the extra backbone outputs (hidden states / attentions) when those are present.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        batch_size, seq_len, hidden_size = sequence_output.shape

        # Flatten to [batch_size * seq_len, hidden_size]; per-example positions are shifted by
        # `row * seq_len` so they index into the flattened rows.
        flat_sequence_output = paddle.reshape(sequence_output, [-1, hidden_size])
        flat_length = paddle.arange(batch_size) * seq_len
        flat_length = flat_length.unsqueeze(axis=1).astype("int64")

        cls_output = paddle.tensor.gather(flat_sequence_output, cls_positions + flat_length.squeeze(1))
        q = self.linear_q(cls_output)

        option_output = paddle.tensor.gather(flat_sequence_output, paddle.reshape(omask_positions + flat_length, [-1]))
        option_output = paddle.reshape(option_output, [batch_size, -1, hidden_size])
        k = self.linear_k(option_output)

        option_logits = paddle.matmul(q.unsqueeze(1), k, transpose_y=True).squeeze(1)
        option_logits = option_logits / self.predict_size**0.5

        # Lower the logits of padded option slots (omask_positions == 0) by 1e12. One out-of-place,
        # broadcasted subtraction over the whole batch replaces the former per-row in-place loop, so no
        # dygraph/static in-place consistency workaround is needed.
        padding_mask = 1 - (omask_positions > 0).astype(option_logits.dtype)
        option_logits = option_logits - padding_mask * 1e12

        loss = None
        if not return_dict:
            # Same convention as the other Ernie heads: any extra backbone outputs follow the logits.
            output = (option_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else (output[0] if len(output) == 1 else output)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=option_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )