# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
from paddlenlp.experimental import FasterPretrainedModel, FasterTokenizer
from paddlenlp.transformers.ernie.modeling import ErnieEmbeddings, ErniePooler
from paddlenlp.transformers.model_utils import register_base_model
__all__ = ["FasterErnieModel", "FasterErnieForSequenceClassification", "FasterErnieForTokenClassification"]
class FasterErniePretrainedModel(FasterPretrainedModel):
r"""
    An abstract class for pretrained ERNIE models. It provides the ERNIE-related
    `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`,
    `pretrained_init_configuration` and `base_model_prefix` attributes for
    downloading and loading pretrained models.
Refer to :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
"""
model_config_file = "model_config.json"
pretrained_init_configuration = {
"ernie-1.0": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "relu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"max_position_embeddings": 513,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 18000,
"pad_token_id": 0,
"do_lower_case": True,
},
"ernie-2.0-en": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 4,
"vocab_size": 30522,
"pad_token_id": 0,
"do_lower_case": True,
},
"ernie-2.0-en-finetuned-squad": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 4,
"vocab_size": 30522,
"pad_token_id": 0,
"do_lower_case": True,
},
"ernie-2.0-large-en": {
"attention_probs_dropout_prob": 0.1,
"intermediate_size": 4096, # special for ernie-2.0-large-en
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 4,
"vocab_size": 30522,
"pad_token_id": 0,
"do_lower_case": True,
},
}
resource_files_names = {"model_state": "model_state.pdparams", "vocab_file": "vocab.txt"}
pretrained_resource_files_map = {
"model_state": {
"ernie-1.0": "https://bj.bcebos.com/paddlenlp/models/transformers/faster_ernie/faster_ernie_v1_chn_base.pdparams",
"ernie-2.0-en": "https://bj.bcebos.com/paddlenlp/models/transformers/faster_ernie_v2_base/faster_ernie_v2_eng_base.pdparams",
"ernie-2.0-en-finetuned-squad": "https://bj.bcebos.com/paddlenlp/models/transformers/faster_ernie_v2_base/faster_ernie_v2_eng_base_finetuned_squad.pdparams",
"ernie-2.0-large-en": "https://bj.bcebos.com/paddlenlp/models/transformers/faster_ernie_v2_large/faster_ernie_v2_eng_large.pdparams",
},
"vocab_file": {
"ernie-1.0": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie/vocab.txt",
"ernie-2.0-en": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
"ernie-2.0-en-finetuned-squad": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
"ernie-2.0-large-en": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_large/vocab.txt",
},
}
base_model_prefix = "ernie"
def init_weights(self, layer):
"""Initialization hook"""
if isinstance(layer, (nn.Linear, nn.Embedding)):
            # Only dygraph mode is supported for now; switching to an in-place,
            # configurable truncated-normal initializer is left for later.
if isinstance(layer.weight, paddle.Tensor):
layer.weight.set_value(
paddle.tensor.normal(
mean=0.0,
std=self.initializer_range
if hasattr(self, "initializer_range")
else self.ernie.config["initializer_range"],
shape=layer.weight.shape,
)
)
elif isinstance(layer, nn.LayerNorm):
layer._epsilon = 1e-12
@register_base_model
class FasterErnieModel(FasterErniePretrainedModel):
r"""
The bare ERNIE Model transformer outputting raw hidden-states.
This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
Refer to the superclass documentation for the generic methods.
    This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
    /docs/en/api/paddle/fluid/dygraph/layers/Layer_en.html>`__ subclass. Use it as a regular Paddle Layer
    and refer to the Paddle documentation for all matters related to general usage and behavior.
Args:
        vocab_size (int):
            Vocabulary size of `inputs_ids` in `FasterErnieModel`. It also determines the size of the
            token embedding matrix, i.e. the number of distinct tokens that can be represented by the
            `inputs_ids` passed when calling `FasterErnieModel`.
        vocab_file (str):
            Path of the vocabulary file (one token per line) used to build the built-in
            `FasterTokenizer`.
hidden_size (int, optional):
Dimensionality of the embedding layer, encoder layers and pooler layer. Defaults to `768`.
num_hidden_layers (int, optional):
Number of hidden layers in the Transformer encoder. Defaults to `12`.
num_attention_heads (int, optional):
Number of attention heads for each attention layer in the Transformer encoder.
Defaults to `12`.
intermediate_size (int, optional):
            Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors
            to ff layers are first projected from `hidden_size` to `intermediate_size`,
            and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`.
Defaults to `3072`.
        hidden_act (str, optional):
            The non-linear activation function in the feed-forward layer.
            ``"gelu"``, ``"relu"`` and any other activation function supported by Paddle
            are allowed. Defaults to `"gelu"`.
hidden_dropout_prob (float, optional):
The dropout probability for all fully connected layers in the embeddings and encoder.
Defaults to `0.1`.
        attention_probs_dropout_prob (float, optional):
            The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention targets.
            Defaults to `0.1`.
        max_position_embeddings (int, optional):
            The maximum length of the position encodings, which determines the maximum supported length of an
            input sequence. Defaults to `512`.
type_vocab_size (int, optional):
The vocabulary size of the `token_type_ids`.
Defaults to `2`.
initializer_range (float, optional):
The standard deviation of the normal initializer for initializing all weight matrices.
Defaults to `0.02`.
            .. note::
                A normal initializer initializes weight matrices from a normal distribution.
                See :meth:`FasterErniePretrainedModel.init_weights` for how weights are
                initialized in `FasterErnieModel`.
        pad_token_id (int, optional):
            The index of the padding token in the token vocabulary.
            Defaults to `0`.
        do_lower_case (bool, optional):
            Whether the built-in tokenizer lowercases the input text.
            Defaults to `True`.
        is_split_into_words (bool, optional):
            Whether the input text has already been split into words.
            Defaults to `False`.
        max_seq_len (int, optional):
            The maximum sequence length used by the built-in tokenizer; longer inputs are truncated.
            Defaults to `512`.
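    Examples:
        A minimal usage sketch; the model name and input text are illustrative, and
        `to_tensor` is assumed to be the string-tensor helper from `paddlenlp.experimental`:

        .. code-block::

            from paddlenlp.experimental import FasterErnieModel, to_tensor

            model = FasterErnieModel.from_pretrained("ernie-2.0-en")
            # The model consumes raw strings and tokenizes them internally.
            text = to_tensor(["a plain example sentence"])
            sequence_output, pooled_output = model(text)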
"""
def __init__(
self,
vocab_size,
vocab_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
pad_token_id=0,
do_lower_case=True,
is_split_into_words=False,
max_seq_len=512,
):
super(FasterErnieModel, self).__init__()
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the "
"vocabulary from a pretrained model please use "
"`model = FasterErnieModel.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.do_lower_case = do_lower_case
self.vocab = self.load_vocabulary(vocab_file)
self.max_seq_len = max_seq_len
self.tokenizer = FasterTokenizer(
self.vocab, do_lower_case=self.do_lower_case, is_split_into_words=is_split_into_words
)
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
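        # Share a single normal-initializer ParamAttr across the embeddings,
        # encoder and pooler so all weights use the same initializer_range.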
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range))
self.embeddings = ErnieEmbeddings(
vocab_size,
hidden_size,
hidden_dropout_prob,
max_position_embeddings,
type_vocab_size,
pad_token_id,
weight_attr,
)
        # FusedTransformerEncoderLayer is only available in PaddlePaddle >= 2.2.0.
        # Import it in local scope so that merely importing this module does not
        # fail on older Paddle versions; ensure paddle >= 2.2.0 is installed
        # before instantiating this model.
from paddle.incubate.nn import FusedTransformerEncoderLayer
encoder_layer = FusedTransformerEncoderLayer(
hidden_size,
num_attention_heads,
intermediate_size,
dropout_rate=hidden_dropout_prob,
activation=hidden_act,
attn_dropout_rate=attention_probs_dropout_prob,
act_dropout_rate=0,
weight_attr=weight_attr,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
self.pooler = ErniePooler(hidden_size, weight_attr)
self.apply(self.init_weights)
    def forward(self, text, text_pair=None):
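        r"""
        Runs the built-in `FasterTokenizer` on the raw input text and encodes the
        result. The exact accepted input formats follow `FasterTokenizer`.

        Args:
            text (Tensor):
                The raw input text as a string Tensor (a batch of untokenized strings).
            text_pair (Tensor, optional):
                The optional second sequence for sentence-pair tasks. Defaults to `None`.

        Returns:
            tuple: `(sequence_output, pooled_output)`: the hidden states of the last
            encoder layer for every token, and the pooled representation of the
            first ([CLS]) token.
        """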
input_ids, token_type_ids = self.tokenizer(text=text, text_pair=text_pair, max_seq_len=self.max_seq_len)
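        # Build an additive attention mask: padding positions receive a large
        # negative value (-1e4) so their attention weights vanish after softmax;
        # unsqueeze so the mask broadcasts over the head and query-length dimensions.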
attention_mask = paddle.unsqueeze(
(input_ids == self.pad_token_id).astype(self.pooler.dense.weight.dtype) * -1e4, axis=[1, 2]
)
embedding_output = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output, attention_mask)
sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output)
return sequence_output, pooled_output
class FasterErnieForSequenceClassification(FasterErniePretrainedModel):
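    r"""
    FasterErnie Model with a linear layer on top of the pooled output,
    designed for sequence-level classification tasks.

    Args:
        ernie (:class:`FasterErnieModel`):
            An instance of `FasterErnieModel`.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
        dropout (float, optional):
            The dropout probability for the output layer. If `None`, the
            `hidden_dropout_prob` of the underlying `ernie` model is used instead.
            Defaults to `None`.

    Examples:
        A minimal usage sketch; the model name and input text are illustrative, and
        `to_tensor` is assumed to be the string-tensor helper from `paddlenlp.experimental`:

        .. code-block::

            from paddlenlp.experimental import FasterErnieForSequenceClassification, to_tensor

            model = FasterErnieForSequenceClassification.from_pretrained("ernie-2.0-en")
            logits, predictions = model(to_tensor(["a plain example sentence"]))
    """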
def __init__(self, ernie, num_classes=2, dropout=None):
super(FasterErnieForSequenceClassification, self).__init__()
self.num_classes = num_classes
self.ernie = ernie # allow ernie to be config
self.dropout = nn.Dropout(dropout if dropout is not None else self.ernie.config["hidden_dropout_prob"])
self.classifier = nn.Linear(self.ernie.config["hidden_size"], num_classes)
self.apply(self.init_weights)
    def forward(self, text, text_pair=None):
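        r"""
        Args:
            text (Tensor):
                The raw input text; forwarded to :meth:`FasterErnieModel.forward`.
            text_pair (Tensor, optional):
                The optional second sequence. Defaults to `None`.

        Returns:
            tuple: `(logits, predictions)`: `logits` of shape `[batch_size, num_classes]`
            and `predictions`, the argmax class id per example.
        """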
_, pooled_output = self.ernie(text, text_pair)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
predictions = paddle.argmax(logits, axis=-1)
return logits, predictions
class FasterErnieForTokenClassification(FasterErniePretrainedModel):
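    r"""
    FasterErnie Model with a linear layer on top of the hidden states of each
    token, designed for token-level classification tasks such as NER.

    Args:
        ernie (:class:`FasterErnieModel`):
            An instance of `FasterErnieModel`.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
        dropout (float, optional):
            The dropout probability for the output layer. If `None`, the
            `hidden_dropout_prob` of the underlying `ernie` model is used instead.
            Defaults to `None`.

    Examples:
        A minimal usage sketch; the model name and input text are illustrative, and
        `to_tensor` is assumed to be the string-tensor helper from `paddlenlp.experimental`:

        .. code-block::

            from paddlenlp.experimental import FasterErnieForTokenClassification, to_tensor

            model = FasterErnieForTokenClassification.from_pretrained("ernie-2.0-en")
            logits, predictions = model(to_tensor(["a plain example sentence"]))
    """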
def __init__(self, ernie, num_classes=2, dropout=None):
super(FasterErnieForTokenClassification, self).__init__()
self.num_classes = num_classes
self.ernie = ernie # allow ernie to be config
self.dropout = nn.Dropout(dropout if dropout is not None else self.ernie.config["hidden_dropout_prob"])
self.classifier = nn.Linear(self.ernie.config["hidden_size"], num_classes)
self.apply(self.init_weights)
    def forward(self, text, text_pair=None):
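        r"""
        Args:
            text (Tensor):
                The raw input text; forwarded to :meth:`FasterErnieModel.forward`.
            text_pair (Tensor, optional):
                The optional second sequence. Defaults to `None`.

        Returns:
            tuple: `(logits, predictions)`: `logits` of shape
            `[batch_size, sequence_length, num_classes]` and `predictions`, the
            argmax class id for every token.
        """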
sequence_output, _ = self.ernie(text, text_pair)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
predictions = paddle.argmax(logits, axis=-1)
return logits, predictions