# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .. import PretrainedModel, register_base_model
from ..attention_utils import _convert_param_attr_to_list
from .configuration import (
ERNIE_DOC_PRETRAINED_INIT_CONFIGURATION,
ERNIE_DOC_PRETRAINED_RESOURCE_FILES_MAP,
ErnieDocConfig,
)
__all__ = [
"ErnieDocModel",
"ErnieDocPretrainedModel",
"ErnieDocForSequenceClassification",
"ErnieDocForTokenClassification",
"ErnieDocForQuestionAnswering",
]
class PointwiseFFN(nn.Layer):
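# Position-wise feed-forward sub-layer: project to the inner hidden size, apply the
# activation and dropout, then project back to the model hidden size.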
def __init__(self, d_inner_hid, d_hid, dropout_rate, hidden_act, weight_attr=None, bias_attr=None):
super(PointwiseFFN, self).__init__()
self.linear1 = nn.Linear(d_hid, d_inner_hid, weight_attr, bias_attr=bias_attr)
self.dropout = nn.Dropout(dropout_rate, mode="upscale_in_train")
self.linear2 = nn.Linear(d_inner_hid, d_hid, weight_attr, bias_attr=bias_attr)
self.activation = getattr(F, hidden_act)
def forward(self, x):
return self.linear2(self.dropout(self.activation(self.linear1(x))))
class MultiHeadAttention(nn.Layer):
def __init__(
self,
d_key,
d_value,
d_model,
n_head=1,
r_w_bias=None,
r_r_bias=None,
r_t_bias=None,
dropout_rate=0.0,
weight_attr=None,
bias_attr=None,
):
super(MultiHeadAttention, self).__init__()
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.n_head = n_head
assert d_key * n_head == d_model, "d_model must equal d_key * n_head"
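# Besides the usual query/key/value projections, ERNIE-Doc keeps two extra projections:
# `r_proj` for the relative-position embeddings and `t_proj` for the task (segment) embeddings.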
self.q_proj = nn.Linear(d_model, d_key * n_head, weight_attr=weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(d_model, d_key * n_head, weight_attr=weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(d_model, d_value * n_head, weight_attr=weight_attr, bias_attr=bias_attr)
self.r_proj = nn.Linear(d_model, d_key * n_head, weight_attr=weight_attr, bias_attr=bias_attr)
self.t_proj = nn.Linear(d_model, d_key * n_head, weight_attr=weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(d_model, d_model, weight_attr=weight_attr, bias_attr=bias_attr)
self.r_w_bias = r_w_bias
self.r_r_bias = r_r_bias
self.r_t_bias = r_t_bias
self.dropout = nn.Dropout(dropout_rate, mode="upscale_in_train") if dropout_rate else None
def __compute_qkv(self, queries, keys, values, rel_pos, rel_task):
q = self.q_proj(queries)
k = self.k_proj(keys)
v = self.v_proj(values)
r = self.r_proj(rel_pos)
t = self.t_proj(rel_task)
return q, k, v, r, t
def __split_heads(self, x, d_model, n_head):
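# Shape abbreviations used in the comments below: B = batch size, T = sequence length,
# M = memory length, N = number of attention heads, H = hidden size (d_model),
# HH = per-head size (d_model // n_head).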
# x shape: [B, T, H]
x = x.reshape(shape=[0, 0, n_head, d_model // n_head])
# shape: [B, N, T, HH]
return paddle.transpose(x=x, perm=[0, 2, 1, 3])
def __rel_shift(self, x, klen=-1):
"""
To perform relative attention, the attention score matrix needs to be relatively shifted.
See more details at: https://github.com/kimiyoung/transformer-xl/issues/8#issuecomment-454458852
"""
# input shape: [B, N, T, 2 * T + M]
x_shape = x.shape
x = x.reshape([x_shape[0], x_shape[1], x_shape[3], x_shape[2]])
x = x[:, :, 1:, :]
x = x.reshape([x_shape[0], x_shape[1], x_shape[2], x_shape[3] - 1])
# output shape: [B, N, T, T + M]
return x[:, :, :, :klen]
def __scaled_dot_product_attention(self, q, k, v, r, t, attn_mask):
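# The attention score is a sum of three terms, Transformer-XL style: a content term
# (q_w against the keys), a relative-position term (q_r against the relative-position
# embeddings, followed by the relative shift) and a task/segment term (q_t against the
# task embeddings), scaled by 1 / sqrt(d_key).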
q_w, q_r, q_t = q
score_w = paddle.matmul(q_w, k, transpose_y=True)
score_r = paddle.matmul(q_r, r, transpose_y=True)
score_r = self.__rel_shift(score_r, k.shape[2])
score_t = paddle.matmul(q_t, t, transpose_y=True)
score = score_w + score_r + score_t
score = score * (self.d_key**-0.5)
if attn_mask is not None:
score += attn_mask
weights = F.softmax(score)
if self.dropout:
weights = self.dropout(weights)
out = paddle.matmul(weights, v)
return out
def __combine_heads(self, x):
if len(x.shape) == 3:
return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
# x shape: [B, N, T, HH]
x = paddle.transpose(x, [0, 2, 1, 3])
# target shape:[B, T, H]
return x.reshape([0, 0, x.shape[2] * x.shape[3]])
def forward(self, queries, keys, values, rel_pos, rel_task, memory, attn_mask):
if memory is not None and len(memory.shape) > 1:
cat = paddle.concat([memory, queries], 1)
else:
cat = queries
keys, values = cat, cat
if not (
len(queries.shape)
== len(keys.shape)
== len(values.shape)
== len(rel_pos.shape)
== len(rel_task.shape)
== 3
):
raise ValueError("Inputs: quries, keys, values, rel_pos and rel_task should all be 3-D tensors.")
q, k, v, r, t = self.__compute_qkv(queries, keys, values, rel_pos, rel_task)
q_w, q_r, q_t = list(map(lambda x: q + x.unsqueeze([0, 1]), [self.r_w_bias, self.r_r_bias, self.r_t_bias]))
q_w, q_r, q_t = list(map(lambda x: self.__split_heads(x, self.d_model, self.n_head), [q_w, q_r, q_t]))
k, v, r, t = list(map(lambda x: self.__split_heads(x, self.d_model, self.n_head), [k, v, r, t]))
ctx_multiheads = self.__scaled_dot_product_attention([q_w, q_r, q_t], k, v, r, t, attn_mask)
out = self.__combine_heads(ctx_multiheads)
out = self.out_proj(out)
return out
class ErnieDocEncoderLayer(nn.Layer):
def __init__(
self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
normalize_before=False,
epsilon=1e-5,
rel_pos_params_sharing=False,
r_w_bias=None,
r_r_bias=None,
r_t_bias=None,
weight_attr=None,
bias_attr=None,
):
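# Capture the constructor arguments so that ErnieDocEncoder can instantiate additional
# identical layers via type(encoder_layer)(**encoder_layer._config).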
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
super(ErnieDocEncoderLayer, self).__init__()
if not rel_pos_params_sharing:
r_w_bias, r_r_bias, r_t_bias = list(
map(
lambda x: self.create_parameter(shape=[n_head * d_key], dtype="float32"),
["r_w_bias", "r_r_bias", "r_t_bias"],
)
)
weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
bias_attrs = _convert_param_attr_to_list(bias_attr, 2)
self.attn = MultiHeadAttention(
d_key,
d_value,
d_model,
n_head,
r_w_bias,
r_r_bias,
r_t_bias,
attention_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
)
self.ffn = PointwiseFFN(
d_inner_hid, d_model, relu_dropout, hidden_act, weight_attr=weight_attrs[1], bias_attr=bias_attrs[1]
)
self.norm1 = nn.LayerNorm(d_model, epsilon=epsilon)
self.norm2 = nn.LayerNorm(d_model, epsilon=epsilon)
self.dropout1 = nn.Dropout(prepostprocess_dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(prepostprocess_dropout, mode="upscale_in_train")
self.d_model = d_model
self.epsilon = epsilon
self.normalize_before = normalize_before
def forward(self, enc_input, memory, rel_pos, rel_task, attn_mask):
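# `normalize_before` switches between pre-LayerNorm (normalize the input of each
# sub-layer) and post-LayerNorm (normalize after the residual connection) for both
# the attention and the feed-forward sub-layers.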
residual = enc_input
if self.normalize_before:
enc_input = self.norm1(enc_input)
attn_output = self.attn(enc_input, enc_input, enc_input, rel_pos, rel_task, memory, attn_mask)
attn_output = residual + self.dropout1(attn_output)
if not self.normalize_before:
attn_output = self.norm1(attn_output)
residual = attn_output
if self.normalize_before:
attn_output = self.norm2(attn_output)
ffn_output = self.ffn(attn_output)
output = residual + self.dropout2(ffn_output)
if not self.normalize_before:
output = self.norm2(output)
return output
class ErnieDocEncoder(nn.Layer):
def __init__(self, num_layers, encoder_layer, mem_len):
super(ErnieDocEncoder, self).__init__()
self.layers = nn.LayerList(
[(encoder_layer if i == 0 else type(encoder_layer)(**encoder_layer._config)) for i in range(num_layers)]
)
self.num_layers = num_layers
self.normalize_before = self.layers[0].normalize_before
self.mem_len = mem_len
def _cache_mem(self, curr_out, prev_mem):
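# Recurrence memory: keep the last `mem_len` hidden states of this layer (optionally
# concatenated with the previous memory) and detach them from the graph, so the next
# segment can attend to them without backpropagating through earlier segments.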
if self.mem_len is None or self.mem_len == 0:
return None
if prev_mem is None:
new_mem = curr_out[:, -self.mem_len :, :]
else:
new_mem = paddle.concat([prev_mem, curr_out], 1)[:, -self.mem_len :, :]
new_mem.stop_gradient = True
return new_mem
def forward(self, enc_input, memories, rel_pos, rel_task, attn_mask):
# No need to normalize enc_input here: it has already been normalized outside.
new_mem = []
for i, encoder_layer in enumerate(self.layers):
enc_input = encoder_layer(enc_input, memories[i], rel_pos, rel_task, attn_mask)
new_mem += [self._cache_mem(enc_input, memories[i])]
# free the old memories explicitly to save gpu memory
memories[i] = None
return enc_input, new_mem
class ErnieDocPretrainedModel(PretrainedModel):
"""
An abstract class for pretrained ErnieDoc models. It provides ErnieDoc related
`model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
`pretrained_resource_files_map`, `base_model_prefix` for downloading
and loading pretrained models.
See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
"""
base_model_prefix = "ernie_doc"
config_class = ErnieDocConfig
resource_files_names = {"model_state": "model_state.pdparams"}
pretrained_init_configuration = ERNIE_DOC_PRETRAINED_INIT_CONFIGURATION
pretrained_resource_files_map = ERNIE_DOC_PRETRAINED_RESOURCE_FILES_MAP
def _init_weights(self, layer):
# Initialization hook
if isinstance(layer, (nn.Linear, nn.Embedding)):
# In the dygraph mode, use the `set_value` to reset the parameter directly,
# and reset the `state_dict` to update parameter in static mode.
if isinstance(layer.weight, paddle.Tensor):
layer.weight.set_value(
paddle.tensor.normal(
mean=0.0,
std=self.config.initializer_range,
shape=layer.weight.shape,
)
)
class ErnieDocEmbeddings(nn.Layer):
def __init__(self, config: ErnieDocConfig):
super(ErnieDocEmbeddings, self).__init__()
self.word_emb = nn.Embedding(config.vocab_size, config.hidden_size)
self.pos_emb = nn.Embedding(config.max_position_embeddings * 2 + config.memory_len, config.hidden_size)
self.token_type_emb = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
self.memory_len = config.memory_len
self.dropouts = nn.LayerList([nn.Dropout(config.hidden_dropout_prob) for i in range(3)])
self.norms = nn.LayerList([nn.LayerNorm(config.hidden_size) for i in range(3)])
def forward(self, input_ids, token_type_ids, position_ids):
# input_embeddings: [B, T, H]
input_embeddings = self.word_emb(input_ids.squeeze(-1))
# position_embeddings: [B, 2 * T + M, H]
position_embeddings = self.pos_emb(position_ids.squeeze(-1))
batch_size = input_ids.shape[0]
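# Prepend segment ids for the memory slots, filled with the segment id taken from the
# first token, so that the token-type embeddings cover the full [B, M + T, H] span.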
token_type_ids = paddle.concat(
[
paddle.zeros(shape=[batch_size, self.memory_len, 1], dtype="int64") + token_type_ids[0, 0, 0],
token_type_ids,
],
axis=1,
)
token_type_ids.stop_gradient = True
# token_type_embeddings: [B, M + T, H]
token_type_embeddings = self.token_type_emb(token_type_ids.squeeze(-1))
embs = [input_embeddings, position_embeddings, token_type_embeddings]
for i in range(len(embs)):
embs[i] = self.dropouts[i](self.norms[i](embs[i]))
return embs
class ErnieDocPooler(nn.Layer):
"""
Gets the pooled output by taking the hidden state of the `[CLS]` token.
"""
def __init__(self, config: ErnieDocConfig):
super(ErnieDocPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.cls_token_idx = config.cls_token_idx
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the last token.
cls_token_tensor = hidden_states[:, self.cls_token_idx]
pooled_output = self.dense(cls_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
@register_base_model
class ErnieDocModel(ErnieDocPretrainedModel):
def __init__(self, config: ErnieDocConfig):
super(ErnieDocModel, self).__init__(config)
d_key = config.hidden_size // config.num_attention_heads
d_value = config.hidden_size // config.num_attention_heads
d_inner_hid = config.hidden_size * 4
r_w_bias, r_r_bias, r_t_bias = None, None, None
if config.rel_pos_params_sharing:
# When the relative-position parameters are shared across layers, create the shared
# content, position and task biases once here and pass them to every encoder layer.
r_w_bias, r_r_bias, r_t_bias = list(
map(
lambda x: self.create_parameter(shape=[config.num_attention_heads * d_key], dtype="float32"),
["r_w_bias", "r_r_bias", "r_t_bias"],
)
)
encoder_layer = ErnieDocEncoderLayer(
config.num_attention_heads,
d_key,
d_value,
config.hidden_size,
d_inner_hid,
config.hidden_dropout_prob,
config.attention_dropout_prob,
config.relu_dropout,
config.hidden_act,
normalize_before=config.normalize_before,
epsilon=config.epsilon,
rel_pos_params_sharing=config.rel_pos_params_sharing,
r_w_bias=r_w_bias,
r_r_bias=r_r_bias,
r_t_bias=r_t_bias,
)
self.initializer_range = config.initializer_range
self.n_head = config.num_attention_heads
self.hidden_size = config.hidden_size
self.memory_len = config.memory_len
self.encoder = ErnieDocEncoder(config.num_hidden_layers, encoder_layer, config.memory_len)
self.pad_token_id = config.pad_token_id
self.embeddings = ErnieDocEmbeddings(config)
self.pooler = ErnieDocPooler(config)
def _create_n_head_attn_mask(self, attn_mask, batch_size):
# attn_mask shape: [B, T, 1]
# concat an data_mask, shape: [B, M + T, 1]
data_mask = paddle.concat(
[paddle.ones(shape=[batch_size, self.memory_len, 1], dtype=attn_mask.dtype), attn_mask], axis=1
)
data_mask.stop_gradient = True
# create a self_attn_mask, shape: [B, T, M + T]
self_attn_mask = paddle.matmul(attn_mask, data_mask, transpose_y=True)
self_attn_mask = (self_attn_mask - 1) * 1e8
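# After `(mask - 1) * 1e8`, positions that may be attended to contribute 0 and masked
# positions contribute a large negative value, which is added to the attention scores
# before the softmax.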
n_head_self_attn_mask = paddle.stack([self_attn_mask] * self.n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
return n_head_self_attn_mask
def forward(self, input_ids, memories, token_type_ids, position_ids, attn_mask):
r"""
The ErnieDocModel forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length, 1].
memories (List[Tensor]):
A list of length `n_layers` with each Tensor being a pre-computed hidden-state for each layer.
Each Tensor has a dtype of `float32` and a shape of [batch_size, memory_length, hidden_size].
token_type_ids (Tensor):
Segment token indices to indicate first and second portions of the inputs.
Indices can be either 0 or 1:
- 0 corresponds to a **sentence A** token,
- 1 corresponds to a **sentence B** token.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length, 1].
position_ids (Tensor):
Indices of positions of each input sequence token in the position embeddings.
Its data type should be `int64` and it has a shape of [batch_size, 2 * sequence_length + memory_length, 1].
attn_mask (Tensor):
Mask used in multi-head attention to avoid performing attention on unwanted positions,
usually the paddings.
The positions to attend to have the value `1` and the masked (padding) positions have
the value `0`; the memory part of the mask is filled in internally.
Its data type should be `float32` and it has a shape of [batch_size, sequence_length, 1].
Returns:
tuple : Returns tuple (``encoder_output``, ``pooled_output``, ``new_mem``).
With the fields:
- `encoder_output` (Tensor):
Sequence of hidden-states at the last layer of the model.
Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
- `pooled_output` (Tensor):
The output of the `[CLS]` token in the sequence.
We "pool" the model by taking the hidden state corresponding to the `[CLS]` token at `cls_token_idx`.
Its data type should be float32 and its shape is [batch_size, hidden_size].
- `new_mem` (List[Tensor]):
A list of pre-computed hidden-states. The length of the list is `n_layers`.
Each element in the list is a Tensor with dtype `float32` and shape as [batch_size, memory_length, hidden_size].
Example:
.. code-block::
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocModel
from paddlenlp.transformers import ErnieDocTokenizer
def get_related_pos(insts, seq_len, memory_len=128):
beg = seq_len + seq_len + memory_len
r_position = [list(range(beg - 1, seq_len - 1, -1)) + \
list(range(0, seq_len)) for i in range(len(insts))]
return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])
tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocModel.from_pretrained('ernie-doc-base-zh')
inputs = tokenizer("欢迎使用百度飞桨!")
inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])
inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask
outputs = model(**inputs)
encoder_output = outputs[0]
pooled_output = outputs[1]
new_mem = outputs[2]
"""
input_embeddings, position_embeddings, token_embeddings = self.embeddings(
input_ids, token_type_ids, position_ids
)
batch_size = input_embeddings.shape[0]
# [B, N, T, M + T]
n_head_self_attn_mask = self._create_n_head_attn_mask(attn_mask, batch_size)
# memories contains one tensor per layer, each with shape [B, M, H]
encoder_output, new_mem = self.encoder(
enc_input=input_embeddings,
memories=memories,
rel_pos=position_embeddings,
rel_task=token_embeddings,
attn_mask=n_head_self_attn_mask,
)
pooled_output = self.pooler(encoder_output)
return encoder_output, pooled_output, new_mem
class ErnieDocForSequenceClassification(ErnieDocPretrainedModel):
"""
ErnieDoc Model with a linear layer on top of the output layer,
designed for sequence classification/regression tasks like GLUE tasks.
Args:
config (:class:`ErnieDocConfig`):
An instance of ErnieDocConfig used to construct ErnieDocForSequenceClassification.
"""
def __init__(self, config: ErnieDocConfig):
super(ErnieDocForSequenceClassification, self).__init__(config)
self.ernie_doc = ErnieDocModel(config)
self.num_labels = config.num_labels
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob,
mode="upscale_in_train",
)
self.linear = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, input_ids, memories, token_type_ids, position_ids, attn_mask):
r"""
The ErnieDocForSequenceClassification forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
See :class:`ErnieDocModel`.
memories (List[Tensor]):
See :class:`ErnieDocModel`.
token_type_ids (Tensor):
See :class:`ErnieDocModel`.
position_ids (Tensor):
See :class:`ErnieDocModel`.
attn_mask (Tensor):
See :class:`ErnieDocModel`.
Returns:
tuple : Returns tuple (`logits`, `mem`).
With the fields:
- `logits` (Tensor):
The classification logits computed from the pooled output of the `[CLS]` token.
Its data type is `float32` and its shape is [batch_size, num_labels].
- `mem` (List[Tensor]):
A list of pre-computed hidden-states. The length of the list is `n_layers`.
Each element in the list is a Tensor with dtype `float32` and has a shape of
[batch_size, memory_length, hidden_size].
Example:
.. code-block::
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForSequenceClassification
from paddlenlp.transformers import ErnieDocTokenizer
def get_related_pos(insts, seq_len, memory_len=128):
beg = seq_len + seq_len + memory_len
r_position = [list(range(beg - 1, seq_len - 1, -1)) + \
list(range(0, seq_len)) for i in range(len(insts))]
return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])
tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForSequenceClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)
inputs = tokenizer("欢迎使用百度飞桨!")
inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])
inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask
outputs = model(**inputs)
logits = outputs[0]
mem = outputs[1]
"""
_, pooled_output, mem = self.ernie_doc(input_ids, memories, token_type_ids, position_ids, attn_mask)
pooled_output = self.dropout(pooled_output)
logits = self.linear(pooled_output)
return logits, mem
class ErnieDocForTokenClassification(ErnieDocPretrainedModel):
"""
ErnieDoc Model with a linear layer on top of the hidden-states output layer,
designed for token classification tasks like NER tasks.
Args:
config (:class:`ErnieDocConfig`):
An instance of ErnieDocConfig used to construct ErnieDocForTokenClassification.
"""
def __init__(self, config: ErnieDocConfig):
super(ErnieDocForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
self.ernie_doc = ErnieDocModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob,
mode="upscale_in_train",
)
self.linear = nn.Linear(config.hidden_size, self.num_labels)
def forward(self, input_ids, memories, token_type_ids, position_ids, attn_mask):
r"""
The ErnieDocForTokenClassification forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
See :class:`ErnieDocModel`.
memories (List[Tensor]):
See :class:`ErnieDocModel`.
token_type_ids (Tensor):
See :class:`ErnieDocModel`.
position_ids (Tensor):
See :class:`ErnieDocModel`.
attn_mask (Tensor):
See :class:`ErnieDocModel`.
Returns:
tuple : Returns tuple (`logits`, `mem`).
With the fields:
- `logits` (Tensor):
The classification logits for every token in the input sequence.
Its data type is `float32` and its shape is [batch_size, sequence_length, num_labels].
- `mem` (List[Tensor]):
A list of pre-computed hidden-states. The length of the list is `n_layers`.
Each element in the list is a Tensor with dtype `float32` and has a shape of
[batch_size, memory_length, hidden_size].
Example:
.. code-block::
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForTokenClassification
from paddlenlp.transformers import ErnieDocTokenizer
def get_related_pos(insts, seq_len, memory_len=128):
beg = seq_len + seq_len + memory_len
r_position = [list(range(beg - 1, seq_len - 1, -1)) + \
list(range(0, seq_len)) for i in range(len(insts))]
return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])
tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForTokenClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)
inputs = tokenizer("欢迎使用百度飞桨!")
inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])
inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask
outputs = model(**inputs)
logits = outputs[0]
mem = outputs[1]
"""
sequence_output, _, mem = self.ernie_doc(input_ids, memories, token_type_ids, position_ids, attn_mask)
sequence_output = self.dropout(sequence_output)
logits = self.linear(sequence_output)
return logits, mem
class ErnieDocForQuestionAnswering(ErnieDocPretrainedModel):
"""
ErnieDoc Model with a linear layer on top of the hidden-states
output to compute `span_start_logits` and `span_end_logits`,
designed for question-answering tasks like SQuAD.
Args:
config (:class:`ErnieDocConfig`):
An instance of ErnieDocConfig used to construct ErnieDocForQuestionAnswering.
"""
def __init__(self, config: ErnieDocConfig):
super(ErnieDocForQuestionAnswering, self).__init__(config)
self.ernie_doc = ErnieDocModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob,
mode="upscale_in_train",
)
self.linear = nn.Linear(config.hidden_size, 2)
def forward(self, input_ids, memories, token_type_ids, position_ids, attn_mask):
r"""
The ErnieDocForQuestionAnswering forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
See :class:`ErnieDocModel`.
memories (List[Tensor]):
See :class:`ErnieDocModel`.
token_type_ids (Tensor):
See :class:`ErnieDocModel`.
position_ids (Tensor):
See :class:`ErnieDocModel`.
attn_mask (Tensor):
See :class:`ErnieDocModel`.
Returns:
tuple : Returns tuple (`start_logits`, `end_logits`, `mem`).
With the fields:
- `start_logits` (Tensor):
A tensor of per-token classification logits indicating the start position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
- `end_logits` (Tensor):
A tensor of per-token classification logits indicating the end position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
- `mem` (List[Tensor]):
A list of pre-computed hidden-states. The length of the list is `n_layers`.
Each element in the list is a Tensor with dtype `float32` and has a shape of
[batch_size, memory_length, hidden_size].
Example:
.. code-block::
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForQuestionAnswering
from paddlenlp.transformers import ErnieDocTokenizer
def get_related_pos(insts, seq_len, memory_len=128):
beg = seq_len + seq_len + memory_len
r_position = [list(range(beg - 1, seq_len - 1, -1)) + \
list(range(0, seq_len)) for i in range(len(insts))]
return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])
tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForQuestionAnswering.from_pretrained('ernie-doc-base-zh')
inputs = tokenizer("欢迎使用百度飞桨!")
inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])
inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask
outputs = model(**inputs)
start_logits = outputs[0]
end_logits = outputs[1]
mem = outputs[2]
"""
sequence_output, _, mem = self.ernie_doc(input_ids, memories, token_type_ids, position_ids, attn_mask)
sequence_output = self.dropout(sequence_output)
logits = self.linear(sequence_output)
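# logits has shape [B, T, 2]; transposing to [2, B, T] lets the two slices be unpacked
# directly into start and end logits of shape [B, T].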
start_logits, end_logits = paddle.transpose(logits, perm=[2, 0, 1])
return start_logits, end_logits, mem