# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import paddle
import paddle.nn as nn
from paddle import Tensor
from ...utils.env import CONFIG_NAME
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
tuple_output,
)
from .configuration import (
ERNIE_GRAM_PRETRAINED_INIT_CONFIGURATION,
ERNIE_GRAM_PRETRAINED_RESOURCE_FILES_MAP,
ErnieGramConfig,
)
# Public API of this module: the base model, its pretrained base class and the
# task heads. The embedding and pooler layers are internal building blocks.
__all__ = [
    "ErnieGramModel",
    "ErnieGramPretrainedModel",
    "ErnieGramForSequenceClassification",
    "ErnieGramForTokenClassification",
    "ErnieGramForQuestionAnswering",
]
class ErnieGramEmbeddings(nn.Layer):
    r"""
    Build the input representation of ERNIE-Gram.

    The result is the sum of word, position and token-type embeddings,
    followed by layer normalization and dropout.

    Args:
        config (:class:`ErnieGramConfig`):
            Model configuration providing vocabulary/embedding sizes, the
            padding token id and the hidden dropout probability.
    """

    def __init__(self, config: ErnieGramConfig):
        super().__init__()
        width = config.embedding_size
        self.word_embeddings = nn.Embedding(config.vocab_size, width, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)
        # Only created when both sizes are set; `forward` below never reads it.
        # NOTE(review): presumably kept so checkpoints containing it still load — confirm.
        if config.rel_pos_size and config.num_attention_heads:
            self.rel_pos_embeddings = nn.Embedding(config.rel_pos_size, config.num_attention_heads)
        self.layer_norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values_length: int = 0,
    ):
        """
        Compute the embedding output.

        Args:
            input_ids: token ids, looked up only when `inputs_embeds` is None.
            token_type_ids: segment ids; an all-zero tensor is used when None.
            position_ids: position ids; when None they are 0, 1, 2, ... along
                axis 1, offset by `past_key_values_length`.
            inputs_embeds: precomputed word embeddings used instead of `input_ids`.
            past_key_values_length: number of positions already held in the cache.

        Returns:
            Tensor: normalized (and dropped-out) sum of the three embeddings.
        """
        word_embeds = self.word_embeddings(input_ids) if inputs_embeds is None else inputs_embeds
        # Every dimension except the embedding width, i.e. the token-id layout.
        id_shape = paddle.shape(word_embeds)[:-1]

        if position_ids is None:
            # maybe need use shape op to unify static graph and dynamic graph
            ones = paddle.ones(id_shape, dtype="int64")
            position_ids = paddle.cumsum(ones, axis=1) - ones
            if past_key_values_length > 0:
                position_ids = position_ids + past_key_values_length
            position_ids.stop_gradient = True

        if token_type_ids is None:
            token_type_ids = paddle.zeros(id_shape, dtype="int64")

        total = word_embeds + self.position_embeddings(position_ids) + self.token_type_embeddings(token_type_ids)
        return self.dropout(self.layer_norm(total))
class ErnieGramPooler(nn.Layer):
    """
    Summarize a sequence by its first token.

    The hidden state at position 0 is passed through a hidden_size -> hidden_size
    Linear layer followed by tanh.

    Args:
        config (:class:`ErnieGramConfig`): provides ``hidden_size``.
        weight_attr: optional parameter attribute forwarded to the Linear layer.
    """

    def __init__(self, config: ErnieGramConfig, weight_attr=None):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, weight_attr=weight_attr)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return tanh(dense(hidden_states[:, 0]))."""
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
class ErnieGramPretrainedModel(PretrainedModel):
    r"""
    Abstract base class for pretrained ERNIE-Gram models.

    It supplies the ERNIE-Gram specific ``model_config_file``,
    ``resource_files_names``, ``pretrained_resource_files_map``,
    ``pretrained_init_configuration`` and ``base_model_prefix`` used by
    :class:`~paddlenlp.transformers.model_utils.PretrainedModel` to download
    and load pretrained models. See that class for more details.
    """

    base_model_prefix = "ernie_gram"
    config_class = ErnieGramConfig
    model_config_file = CONFIG_NAME
    resource_files_names = {"model_state": "model_state.pdparams"}
    pretrained_init_configuration = ERNIE_GRAM_PRETRAINED_INIT_CONFIGURATION
    pretrained_resource_files_map = ERNIE_GRAM_PRETRAINED_RESOURCE_FILES_MAP

    def _init_weights(self, layer):
        """
        Initialization hook for a single sub-layer.

        Linear and Embedding weights are re-sampled from a normal distribution
        (mean 0, std ``self.config.initializer_range``); LayerNorm layers get
        their epsilon set to 1e-5. Other layers are left untouched.
        """
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            weight = layer.weight
            # only support dygraph, use truncated_normal and make it inplace
            # and configurable later
            if isinstance(weight, paddle.Tensor):
                sampled = paddle.tensor.normal(
                    mean=0.0,
                    std=self.config.initializer_range,
                    shape=weight.shape,
                )
                weight.set_value(sampled)
        elif isinstance(layer, nn.LayerNorm):
            layer._epsilon = 1e-5
@register_base_model
class ErnieGramModel(ErnieGramPretrainedModel):
    r"""
    The bare ERNIE-Gram Model transformer outputting raw hidden-states.

    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
    Refer to the superclass documentation for the generic methods.

    This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
    /docs/zh/api/paddle/nn/Layer_cn.html>`__ subclass. Use it as a regular Paddle Layer
    and refer to the Paddle documentation for all matter related to general usage and behavior.

    Args:
        config (:class:`ErnieGramConfig`):
            An instance of ErnieGramConfig used to construct ErnieGramModel.
    """

    def __init__(self, config: ErnieGramConfig):
        super(ErnieGramModel, self).__init__(config)
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.initializer_range = config.initializer_range
        self.embeddings = ErnieGramEmbeddings(config)
        encoder_layer = nn.TransformerEncoderLayer(
            config.hidden_size,
            config.num_attention_heads,
            config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation=config.hidden_act,
            attn_dropout=config.attention_probs_dropout_prob,
            act_dropout=0,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.num_hidden_layers)
        self.pooler = ErnieGramPooler(config)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Exactly one of `input_ids` and `inputs_embeds` must be given.
            token_type_ids (Tensor, optional):
                Segment token indices to indicate first and second portions of the inputs.
                Indices can be either 0 or 1:

                - 0 corresponds to a **sentence A** token,
                - 1 corresponds to a **sentence B** token.

                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to `None`, which means every token is treated as a **sentence A** token
                (token type 0).
            position_ids (Tensor, optional):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
                config.max_position_embeddings - 1]``.
                Defaults to `None`, which means positions 0, 1, 2, ... (offset by the cached length when
                `past_key_values` is given). Shape as `(batch_size, num_tokens)` and dtype as `int32` or `int64`.
            attention_mask (Tensor, optional):
                Mask used in multi-head attention to avoid performing attention on some unwanted positions,
                usually the paddings or the subsequent positions.
                Its data type can be int, float and bool.
                When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
                When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
                When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length],
                [batch_size, num_attention_heads, sequence_length, sequence_length].
                We use whole-word-mask in ERNIE, so the whole word will have the same value. For example, "使用" as a word,
                "使" and "用" will have the same value.
                Defaults to `None`, which masks the positions of `input_ids` equal to `pad_token_id`
                (or masks nothing when only `inputs_embeds` is given).
            inputs_embeds (Tensor, optional):
                If you want to control how to convert `input_ids` indices into associated vectors, you can
                pass an embedded representation directly instead of passing `input_ids`.
            past_key_values (tuple(tuple(Tensor)), optional):
                The length of the tuple equals the number of layers, and each inner
                tuple has 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`)
                which contains precomputed key and value hidden states of the attention blocks.
                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
                `input_ids` of shape `(batch_size, sequence_length)`.
            use_cache (`bool`, optional):
                If set to `True`, `past_key_values` key value states are returned.
                Defaults to `None`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
                will be a tuple of tensors. Defaults to `False`.

        Returns:
            tuple: Returns tuple (``sequence_output``, ``pooled_output``), followed by any extra
            encoder outputs (cache / hidden states / attentions) that were requested.

            With the fields:

            - `sequence_output` (Tensor):
                Sequence of hidden-states at the last layer of the model.
                Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].

            - `pooled_output` (Tensor):
                The output of first token (`[CLS]`) in sequence.
                We "pool" the model by simply taking the hidden state corresponding to the first token.
                Its data type should be float32 and its shape is [batch_size, hidden_size].

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieGramModel, ErnieGramTokenizer

                tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
                model = ErnieGramModel.from_pretrained('ernie-gram-zh')

                inputs = tokenizer("欢迎使用百度飞桨!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                sequence_output, pooled_output = model(**inputs)
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds.")

        # init the default bool value
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = return_dict if return_dict is not None else False
        use_cache = use_cache if use_cache is not None else False

        past_key_values_length = 0
        if past_key_values is not None:
            # Cached keys are indexed [batch, heads, cached_len, ...]; axis 2 is the cached length.
            past_key_values_length = past_key_values[0][0].shape[2]

        if attention_mask is None:
            mask_dtype = self.pooler.dense.weight.dtype
            if input_ids is not None:
                # Additive mask: -1e4 at padding tokens, 0 elsewhere, broadcast over heads/queries.
                attention_mask = paddle.unsqueeze(
                    (input_ids == self.pad_token_id).astype(mask_dtype) * -1e4, axis=[1, 2]
                )
            else:
                # Only embeddings were given, so padding cannot be detected from
                # token ids: attend to every position.
                attention_mask = paddle.unsqueeze(
                    paddle.zeros(paddle.shape(inputs_embeds)[:-1], dtype=mask_dtype), axis=[1, 2]
                )
            if past_key_values is not None:
                # Cached positions are always visible.
                batch_size = past_key_values[0][0].shape[0]
                past_mask = paddle.zeros([batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype)
                attention_mask = paddle.concat([past_mask, attention_mask], axis=-1)
        # For 2D attention_mask from tokenizer
        elif attention_mask.ndim == 2:
            # 1 = keep, 0 = mask  ->  additive mask with -1e4 at masked positions.
            attention_mask = paddle.unsqueeze(attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
            attention_mask = (1.0 - attention_mask) * -1e4
        attention_mask.stop_gradient = True

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        self.encoder._use_cache = use_cache  # To be consistent with HF
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            cache=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The encoder returns a bare tensor when no extra outputs were requested.
        # Compare against the embedding output's type: `input_ids` may be None
        # (inputs_embeds path), and indexing a bare tensor with [0] would slice
        # off the batch dimension instead of selecting the sequence output.
        if isinstance(encoder_outputs, type(embedding_output)):
            sequence_output = encoder_outputs
            pooled_output = self.pooler(sequence_output)
            return (sequence_output, pooled_output)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class ErnieGramForTokenClassification(ErnieGramPretrainedModel):
    r"""
    ERNIE-Gram Model with a linear layer on top of the hidden-states output layer,
    designed for token classification tasks like NER tasks.

    Args:
        config (:class:`ErnieGramConfig`):
            An instance of ErnieGramConfig used to construct ErnieGramForTokenClassification.
    """

    def __init__(self, config: ErnieGramConfig):
        super(ErnieGramForTokenClassification, self).__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.ernie_gram = ErnieGramModel(config)  # allow ernie_gram to be config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            weight_attr=paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=config.initializer_range)),
        )

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieGramModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieGramModel`.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieGramModel`.
            labels (Tensor of shape `(batch_size, sequence_length)`, optional):
                Labels for computing the token classification loss. Indices should be in `[0, ..., num_labels - 1]`.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
                `False`, the output will be a tuple of tensors. Defaults to `False`.

        Returns:
            Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
            Shape as `[batch_size, sequence_length, num_labels]` and dtype as `float32`.
            When `labels` is given, the cross-entropy loss is combined with the outputs
            via `tuple_output` (tuple mode) or stored in `loss` (dict mode).

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieGramForTokenClassification, ErnieGramTokenizer

                tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
                model = ErnieGramForTokenClassification.from_pretrained('ernie-gram-zh')

                inputs = tokenizer("欢迎使用百度飞桨!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        outputs = self.ernie_gram(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            # Previously dropped: forward precomputed embeddings like the other heads do.
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = paddle.nn.CrossEntropyLoss()
            # Flatten [batch, seq, num_labels] / [batch, seq] into per-token rows.
            loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1,)))

        if not return_dict:
            # outputs[0:2] are (sequence_output, pooled_output); keep any extras after them.
            output = (logits,) + outputs[2:]
            return tuple_output(output, loss)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ErnieGramForQuestionAnswering(ErnieGramPretrainedModel):
    """
    ERNIE-Gram Model with a linear layer on top of the hidden-states
    output to compute `span_start_logits` and `span_end_logits`,
    designed for question-answering tasks like SQuAD.

    Args:
        config (:class:`ErnieGramConfig`):
            An instance of ErnieGramConfig used to construct ErnieGramForQuestionAnswering.
    """

    def __init__(self, config: ErnieGramConfig):
        super(ErnieGramForQuestionAnswering, self).__init__(config)
        self.config = config
        self.ernie_gram = ErnieGramModel(config)
        # Two outputs per token: start logit and end logit.
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        start_positions: Optional[Tensor] = None,
        end_positions: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieGramModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieGramModel`.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieGramModel`.
            start_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the start of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            end_positions (Tensor of shape `(batch_size,)`, optional):
                Labels for position (index) of the end of the labelled span for computing the token classification loss.
                Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
                are not taken into account for computing the loss.
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If
                `False`, the output will be a tuple of tensors. Defaults to `False`.

        Returns:
            tuple: Returns tuple (`start_logits`, `end_logits`).

            With the fields:

            - `start_logits` (Tensor):
                A tensor of the input token classification logits, indicates the start position of the labelled span.
                Its data type should be float32 and its shape is [batch_size, sequence_length].

            - `end_logits` (Tensor):
                A tensor of the input token classification logits, indicates the end position of the labelled span.
                Its data type should be float32 and its shape is [batch_size, sequence_length].

            When both `start_positions` and `end_positions` are given, the mean of the
            start and end cross-entropy losses is also returned (via `tuple_output`
            in tuple mode, or in `loss` in dict mode).

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieGramForQuestionAnswering, ErnieGramTokenizer

                tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
                model = ErnieGramForQuestionAnswering.from_pretrained('ernie-gram-zh')

                inputs = tokenizer("欢迎使用百度飞桨!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        outputs = self.ernie_gram(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(outputs[0])
        # [batch, seq, 2] -> [2, batch, seq] so the two logit maps can be unstacked.
        logits = paddle.transpose(logits, perm=[2, 0, 1])
        start_logits, end_logits = paddle.unstack(x=logits, axis=0)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Positions may arrive with an extra trailing dimension (e.g. [batch_size, 1]);
            # squeeze each tensor based on its own rank.
            if start_positions.ndim > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.ndim > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = paddle.shape(start_logits)[1]
            start_positions = start_positions.clip(0, ignored_index)
            end_positions = end_positions.clip(0, ignored_index)

            loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # outputs[0:2] are (sequence_output, pooled_output); keep any extras after them.
            output = (start_logits, end_logits) + outputs[2:]
            return tuple_output(output, total_loss)

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ErnieGramForSequenceClassification(ErnieGramPretrainedModel):
    r"""
    ERNIE-Gram Model with a linear layer on top of the output layer,
    designed for sequence classification/regression tasks like GLUE tasks.

    Args:
        config (:class:`ErnieGramConfig`):
            An instance of ErnieGramConfig used to construct ErnieGramForSequenceClassification.
    """

    def __init__(self, config: ErnieGramConfig):
        super(ErnieGramForSequenceClassification, self).__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.ernie_gram = ErnieGramModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`ErnieGramModel`.
            token_type_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            position_ids (Tensor, optional):
                See :class:`ErnieGramModel`.
            attention_mask (Tensor, optional):
                See :class:`ErnieGramModel`.
            inputs_embeds (Tensor, optional):
                See :class:`ErnieGramModel`.
            labels (Tensor of shape `(batch_size,)`, optional):
                Labels for computing the sequence classification/regression loss.
                Indices should be in `[0, ..., num_labels - 1]`. If `num_labels == 1`
                a regression loss is computed (Mean-Square loss), If `num_labels > 1`
                a classification loss is computed (Cross-Entropy). When
                `config.problem_type` is unset, it is inferred on the first call:
                integer labels with `num_labels > 1` select single-label
                classification, other label dtypes select multi-label
                classification (BCE with logits).
            output_hidden_states (bool, optional):
                Whether to return the hidden states of all layers.
                Defaults to `False`.
            output_attentions (bool, optional):
                Whether to return the attentions tensors of all attention layers.
                Defaults to `False`.
            return_dict (bool, optional):
                Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If
                `False`, the output will be a tuple of tensors. Defaults to `False`.

        Returns:
            Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
            Shape as `[batch_size, num_labels]` and dtype as float32.
            When `labels` is given, the loss is combined with the outputs via
            `tuple_output` (tuple mode) or stored in `loss` (dict mode).

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import ErnieGramForSequenceClassification, ErnieGramTokenizer

                tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
                model = ErnieGramForSequenceClassification.from_pretrained('ernie-gram-zh')

                inputs = tokenizer("欢迎使用百度飞桨!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                logits = model(**inputs)
        """
        outputs = self.ernie_gram(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[1] is the pooled ([CLS]) representation.
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the task type once and cache it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = paddle.nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = paddle.nn.CrossEntropyLoss()
                loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1,)))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = paddle.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # outputs[0:2] are (sequence_output, pooled_output); keep any extras after them.
            output = (logits,) + outputs[2:]
            return tuple_output(output, loss)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )