modeling
-
class ErnieDocModel(num_hidden_layers, num_attention_heads, hidden_size, hidden_dropout_prob, attention_dropout_prob, relu_dropout, hidden_act, memory_len, vocab_size, max_position_embeddings, task_type_vocab_size=3, normalize_before=False, epsilon=1e-05, rel_pos_params_sharing=False, initializer_range=0.02, pad_token_id=0, cls_token_idx=-1)
Bases: paddlenlp.transformers.ernie_doc.modeling.ErnieDocPretrainedModel
The bare ERNIE-Doc Model outputting raw hidden-states.
This model inherits from PretrainedModel. Refer to the superclass documentation for the generic methods. This model is also a paddle.nn.Layer subclass. Use it as a regular Paddle Layer and refer to the Paddle documentation for all matters related to general usage and behavior.
- Parameters
num_hidden_layers (int) -- The number of hidden layers in the Transformer encoder.
num_attention_heads (int) -- The number of attention heads for each attention layer in the Transformer encoder.
hidden_size (int) -- Dimensionality of the embedding layers, encoder layers and pooler layer.
hidden_dropout_prob (float) -- The dropout probability for all fully connected layers in the embeddings and encoder.
attention_dropout_prob (float) -- The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention targets.
relu_dropout (float) -- The dropout probability of the FFN.
hidden_act (str) -- The non-linear activation function of the FFN.
memory_len (int) -- The number of tokens to cache. If not 0, the last memory_len hidden states of each layer are cached into memory.
vocab_size (int) -- Vocabulary size of input_ids in ErnieDocModel. It is also the vocabulary size of the token embedding matrix and defines the number of different tokens that can be represented by the input_ids passed when calling ErnieDocModel.
max_position_embeddings (int) -- The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input sequence.
task_type_vocab_size (int, optional) -- The vocabulary size of token_type_ids. Defaults to 3.
normalize_before (bool, optional) -- Whether to apply layer normalization as preprocessing of the MHA and FFN sub-layers. If True, preprocessing is layer normalization, and postprocessing includes dropout and the residual connection. Otherwise, there is no preprocessing, and postprocessing includes dropout, the residual connection and layer normalization. Defaults to False.
epsilon (float, optional) -- The epsilon parameter used in paddle.nn.LayerNorm when initializing layer normalization layers. Defaults to 1e-5.
rel_pos_params_sharing (bool, optional) -- Whether to share the relative position parameters. Defaults to False.
initializer_range (float, optional) -- The standard deviation of the normal initializer for initializing all weight matrices. Defaults to 0.02.
pad_token_id (int, optional) -- The token id of the [PAD] token, whose parameters won't be updated during training. Defaults to 0.
cls_token_idx (int, optional) -- The position index of the [CLS] token in the sequence. Defaults to -1 (the last position).
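The memory_len description above can be illustrated with a small stdlib-only sketch. It assumes a Transformer-XL-style update in which, for each layer, the previous memory and the current segment's hidden states are concatenated along the sequence axis and only the last memory_len positions are kept. The helper names and the list-of-vectors representation are hypothetical and are not PaddleNLP APIs.

```python
from typing import List

Vector = List[float]   # one hidden state (length hidden_size)
Layer = List[Vector]   # hidden states along the sequence axis


def update_layer_memory(memory: Layer, hidden: Layer, memory_len: int) -> Layer:
    """Hypothetical helper: cache the last memory_len hidden states of one layer."""
    if memory_len == 0:
        # memory_len == 0 disables caching (combined[-0:] would keep everything)
        return []
    combined = memory + hidden       # previous memory followed by the current segment
    return combined[-memory_len:]    # keep only the most recent positions


def update_all_memories(memories: List[Layer], hiddens: List[Layer],
                        memory_len: int) -> List[Layer]:
    """Apply the per-layer update to every layer (n_layers in, n_layers out)."""
    return [update_layer_memory(m, h, memory_len) for m, h in zip(memories, hiddens)]


old = [[[0.0]], [[10.0]]]                   # 2 layers, 1 cached position each
new = [[[1.0], [2.0]], [[11.0], [12.0]]]    # 2 layers, 2 new positions each
print(update_all_memories(old, new, memory_len=2))
```

Because the cache keeps at most memory_len positions per layer, its size stays fixed no matter how many segments have been processed.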
-
forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocModel forward method, overrides the __call__() special method.
- Parameters
input_ids (Tensor) -- Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. Its data type should be int64 and it has a shape of [batch_size, sequence_length, 1].
memories (List[Tensor]) -- A list of length n_layers, with each Tensor being a pre-computed hidden state for each layer. Each Tensor has dtype float32 and a shape of [batch_size, memory_len, hidden_size].
token_type_ids (Tensor) -- Segment token indices to indicate the first and second portions of the inputs. Indices can be either 0 or 1:
0 corresponds to a sentence A token,
1 corresponds to a sentence B token.
Its data type should be int64 and it has a shape of [batch_size, sequence_length, 1]. Defaults to None, which means no segment embeddings are added to the token embeddings.
position_ids (Tensor) -- Indices of positions of each input sequence token in the position embeddings, selected in the range [0, config.max_position_embeddings - 1], with dtype int32 or int64. The examples on this page build relative positions with get_related_pos, which yields a tensor of shape [batch_size, 2 * sequence_length + memory_len, 1].
attn_mask (Tensor) -- Mask used in multi-head attention to avoid attending to unwanted positions, usually the paddings or the subsequent positions. Its data type can be int, float or bool. When the data type is bool, the masked tokens have False values and the others have True values. When the data type is int, the masked tokens have 0 values and the others have 1 values. When the data type is float, the masked tokens have -INF values and the others have 0 values. Its shape must be broadcastable to [batch_size, num_attention_heads, sequence_length, sequence_length]; for example, it can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length] or [batch_size, num_attention_heads, sequence_length, sequence_length]. We use whole-word masking in ERNIE, so all tokens of a word have the same value. For example, for the word "使用" ("use"), "使" and "用" have the same value. Defaults to None, which means no positions are masked.
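The three attn_mask encodings described above carry the same information. The stdlib-only sketch below, with hypothetical helper names, converts an int mask (1 = keep, 0 = masked) into the bool and float forms and shows why the float form works: adding -inf to a score before softmax drives that position's weight to exactly zero. It operates on a single row of scores and is not how Paddle implements attention internally.

```python
import math
from typing import List


def int_to_bool_mask(mask: List[int]) -> List[bool]:
    # int convention: 1 = attend, 0 = masked  ->  bool convention: True = attend
    return [m != 0 for m in mask]


def int_to_float_mask(mask: List[int]) -> List[float]:
    # float convention: 0.0 = attend, -inf = masked (added to the attention scores)
    return [0.0 if m != 0 else -math.inf for m in mask]


def masked_softmax(scores: List[float], float_mask: List[float]) -> List[float]:
    """Softmax over one row of scores after adding an additive float mask.

    Assumes at least one position is unmasked; otherwise the row is undefined.
    """
    shifted = [s + m for s, m in zip(scores, float_mask)]
    top = max(shifted)                           # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in shifted]  # exp(-inf) == 0.0 for masked positions
    total = sum(exps)
    return [e / total for e in exps]


padding_mask = [1, 1, 1, 0]   # e.g. 3 real tokens followed by 1 [PAD]
weights = masked_softmax([0.5, 1.0, 2.0, 9.0], int_to_float_mask(padding_mask))
print(int_to_bool_mask(padding_mask), weights)
```

Even though the padded position has the largest raw score, its weight is zero after masking.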
- Returns
Returns tuple (encoder_output, pooled_output, new_mem). With the fields:
encoder_output (Tensor): Sequence of hidden states at the last layer of the model. Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
pooled_output (Tensor): The output for the [CLS] token in the sequence. We "pool" the model by taking the hidden state at the [CLS] position given by cls_token_idx. Its data type should be float32 and its shape is [batch_size, hidden_size].
new_mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and shape [batch_size, memory_length, hidden_size].
- Return type
tuple
Example
    import numpy as np
    import paddle
    from paddlenlp.transformers import ErnieDocModel
    from paddlenlp.transformers import ErnieDocTokenizer

    def get_related_pos(insts, seq_len, memory_len=128):
        # Relative position ids: one row of 2 * seq_len + memory_len ids per instance
        beg = seq_len + seq_len + memory_len
        r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                      list(range(0, seq_len)) for i in range(len(insts))]
        return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

    tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
    model = ErnieDocModel.from_pretrained('ernie-doc-base-zh')

    # "Welcome to Baidu PaddlePaddle!"
    inputs = tokenizer("欢迎使用百度飞桨!")
    # Pad every field to 128 tokens with 0 and add a trailing axis: shape [1, 128, 1]
    inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

    # Zero-initialized memories, one [1, 128, 768] tensor per layer (12 layers here)
    memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
    # Shape [1, 128 + 128 + 128, 1]
    position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
    # All ones: no position is masked
    attn_mask = paddle.ones([1, 128, 1])

    inputs['memories'] = memories
    inputs['position_ids'] = position_ids
    inputs['attn_mask'] = attn_mask

    outputs = model(**inputs)

    encoder_output = outputs[0]
    pooled_output = outputs[1]
    new_mem = outputs[2]
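The get_related_pos helper in the example above builds the position_ids tensor. Restating its arithmetic in plain Python with small sizes makes the layout visible: each row holds 2 * seq_len + memory_len ids, first counting down from the largest value to seq_len, then counting up from 0 to seq_len - 1. The function name below is hypothetical and only mirrors the example; it says nothing further about how ErnieDocModel consumes these ids.

```python
from typing import List


def related_positions(seq_len: int, memory_len: int) -> List[int]:
    """Plain-Python restatement of one row built by get_related_pos in the example."""
    total = seq_len + seq_len + memory_len                # 'beg' in the example
    descending = list(range(total - 1, seq_len - 1, -1))  # total - 1, ..., seq_len
    ascending = list(range(0, seq_len))                   # 0, ..., seq_len - 1
    return descending + ascending


print(related_positions(seq_len=3, memory_len=2))  # [7, 6, 5, 4, 3, 0, 1, 2]
print(len(related_positions(128, 128)))            # 384
```

With the example's settings (seq_len = memory_len = 128) each row has 384 entries, which is why position_ids there has shape [1, 384, 1].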
-
class ErnieDocPretrainedModel(*args, **kwargs)
Bases: paddlenlp.transformers.model_utils.PretrainedModel
An abstract class for pretrained ErnieDoc models. It provides ErnieDoc-related model_config_file, pretrained_init_configuration, resource_files_names, pretrained_resource_files_map and base_model_prefix for downloading and loading pretrained models. See PretrainedModel for more details.
-
base_model_class
alias of paddlenlp.transformers.ernie_doc.modeling.ErnieDocModel
-
class ErnieDocForSequenceClassification(ernie_doc, num_classes=2, dropout=0.1)
Bases: paddlenlp.transformers.ernie_doc.modeling.ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the output layer, designed for sequence classification/regression tasks like GLUE tasks.
- Parameters
ernie_doc (ErnieDocModel) -- An instance of ErnieDocModel.
num_classes (int, optional) -- The number of classes. Defaults to 2.
dropout (float, optional) -- The dropout ratio of the last output. Defaults to 0.1.
-
forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForSequenceClassification forward method, overrides the __call__() special method.
- Parameters
input_ids (Tensor) -- See ErnieDocModel.
memories (List[Tensor]) -- See ErnieDocModel.
token_type_ids (Tensor) -- See ErnieDocModel.
position_ids (Tensor) -- See ErnieDocModel.
attn_mask (Tensor) -- See ErnieDocModel.
- Returns
Returns tuple (logits, mem). With the fields:
logits (Tensor): The classification logits computed from the [CLS] hidden state at the output of the last layer. Its data type is float32 and its shape is [batch_size, num_classes].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and shape [batch_size, memory_length, hidden_size].
- Return type
tuple
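The logits returned above are unnormalized scores of shape [batch_size, num_classes]. A common way to turn them into predictions is a softmax for probabilities and an argmax for the predicted class. The sketch below uses plain Python lists as a stand-in for the Tensor; the helper names are hypothetical.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    top = max(row)                            # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_labels(logits: List[List[float]]) -> List[int]:
    """Argmax over num_classes for each row of a [batch_size, num_classes] batch."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


batch_logits = [[0.2, 1.5], [3.0, -1.0]]      # batch_size = 2, num_classes = 2
print(predict_labels(batch_logits))           # [1, 0]
print([softmax(row) for row in batch_logits])
```

The argmax does not need the softmax, since softmax preserves the ordering of the scores. Compute probabilities only when calibrated scores are required.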
Example
    import numpy as np
    import paddle
    from paddlenlp.transformers import ErnieDocForSequenceClassification
    from paddlenlp.transformers import ErnieDocTokenizer

    def get_related_pos(insts, seq_len, memory_len=128):
        beg = seq_len + seq_len + memory_len
        r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                      list(range(0, seq_len)) for i in range(len(insts))]
        return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

    tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
    model = ErnieDocForSequenceClassification.from_pretrained('ernie-doc-base-zh', num_classes=2)

    inputs = tokenizer("欢迎使用百度飞桨!")
    inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

    memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
    position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
    attn_mask = paddle.ones([1, 128, 1])

    inputs['memories'] = memories
    inputs['position_ids'] = position_ids
    inputs['attn_mask'] = attn_mask

    outputs = model(**inputs)

    logits = outputs[0]
    mem = outputs[1]
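The example above runs a single padded segment with zero memories. Because every forward method returns updated memories (mem here, new_mem for ErnieDocModel), a longer document can be processed segment by segment, feeding each call's memories into the next. The sketch below shows only that control flow in plain Python. `step` is a hypothetical stand-in for a model call, and padding with id 0 follows the example. Using the last segment's output as the document-level result is an assumption of this sketch, not something the documentation above specifies.

```python
from typing import Callable, List, Optional, Tuple, TypeVar

Memory = TypeVar("Memory")


def split_into_segments(ids: List[int], seq_len: int, pad_id: int = 0) -> List[List[int]]:
    """Cut a long token-id list into fixed-length segments, padding the last one."""
    segments = []
    for start in range(0, len(ids), seq_len):
        chunk = ids[start:start + seq_len]
        segments.append(chunk + [pad_id] * (seq_len - len(chunk)))
    return segments


def run_document(
    ids: List[int],
    seq_len: int,
    step: Callable[[List[int], Memory], Tuple[object, Memory]],
    init_memory: Memory,
) -> Tuple[Optional[object], Memory]:
    """Thread the memory returned by each step into the next segment's call."""
    memory = init_memory
    output = None
    for segment in split_into_segments(ids, seq_len):
        output, memory = step(segment, memory)   # e.g. logits, mem = model(...)
    return output, memory                        # output of the last segment


# Hypothetical stand-in for a model call: the "output" is the segment sum and
# the "memory" counts how many segments have been seen.
def fake_step(segment: List[int], memory: int) -> Tuple[int, int]:
    return sum(segment), memory + 1


print(split_into_segments([1, 2, 3, 4, 5], seq_len=2))  # [[1, 2], [3, 4], [5, 0]]
print(run_document([1, 2, 3, 4, 5], 2, fake_step, 0))   # (5, 3)
```

With a real model, `step` would build input_ids, position_ids and attn_mask for the segment as in the example and return (logits, mem).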
-
class ErnieDocForTokenClassification(ernie_doc, num_classes=2, dropout=0.1)
Bases: paddlenlp.transformers.ernie_doc.modeling.ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the hidden-states output layer, designed for token classification tasks like NER tasks.
- Parameters
ernie_doc (ErnieDocModel) -- An instance of ErnieDocModel.
num_classes (int, optional) -- The number of classes. Defaults to 2.
dropout (float, optional) -- The dropout ratio of the last output. Defaults to 0.1.
-
forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForTokenClassification forward method, overrides the __call__() special method.
- Parameters
input_ids (Tensor) -- See ErnieDocModel.
memories (List[Tensor]) -- See ErnieDocModel.
token_type_ids (Tensor) -- See ErnieDocModel. Defaults to None, which means no segment embeddings are added to the token embeddings.
position_ids (Tensor) -- See ErnieDocModel.
attn_mask (Tensor) -- See ErnieDocModel.
- Returns
Returns tuple (logits, mem). With the fields:
logits (Tensor): The token classification logits computed from the hidden states at the output of the last layer. Its data type is float32 and its shape is [batch_size, sequence_length, num_classes].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and shape [batch_size, memory_length, hidden_size].
- Return type
tuple
Example
    import numpy as np
    import paddle
    from paddlenlp.transformers import ErnieDocForTokenClassification
    from paddlenlp.transformers import ErnieDocTokenizer

    def get_related_pos(insts, seq_len, memory_len=128):
        beg = seq_len + seq_len + memory_len
        r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                      list(range(0, seq_len)) for i in range(len(insts))]
        return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

    tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
    model = ErnieDocForTokenClassification.from_pretrained('ernie-doc-base-zh', num_classes=2)

    inputs = tokenizer("欢迎使用百度飞桨!")
    inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

    memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
    position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
    attn_mask = paddle.ones([1, 128, 1])

    inputs['memories'] = memories
    inputs['position_ids'] = position_ids
    inputs['attn_mask'] = attn_mask

    outputs = model(**inputs)

    logits = outputs[0]
    mem = outputs[1]
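For token classification the logits have shape [batch_size, sequence_length, num_classes], and the examples pad every sequence to a fixed length. The following minimal plain-Python sketch decodes them, assuming the true (unpadded) length of each sequence is known. It takes the argmax per token and drops padded positions. `predict_tags` and the `lengths` argument are hypothetical.

```python
from typing import List


def predict_tags(logits: List[List[List[float]]], lengths: List[int]) -> List[List[int]]:
    """Per-token argmax over num_classes, ignoring positions past each true length."""
    tags = []
    for seq_logits, n in zip(logits, lengths):
        tags.append([max(range(len(tok)), key=tok.__getitem__) for tok in seq_logits[:n]])
    return tags


# batch_size = 1, sequence_length = 3 (the last position is padding), num_classes = 2
print(predict_tags([[[0.1, 0.9], [0.8, 0.2], [0.5, 0.4]]], lengths=[2]))  # [[1, 0]]
```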
-
class ErnieDocForQuestionAnswering(ernie_doc, dropout=0.1)
Bases: paddlenlp.transformers.ernie_doc.modeling.ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the hidden-states output to compute span_start_logits and span_end_logits, designed for question-answering tasks like SQuAD.
- Parameters
ernie_doc (ErnieDocModel) -- An instance of ErnieDocModel.
dropout (float, optional) -- The dropout ratio of the last output. Defaults to 0.1.
-
forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForQuestionAnswering forward method, overrides the __call__() special method.
- Parameters
input_ids (Tensor) -- See ErnieDocModel.
memories (List[Tensor]) -- See ErnieDocModel.
token_type_ids (Tensor) -- See ErnieDocModel.
position_ids (Tensor) -- See ErnieDocModel.
attn_mask (Tensor) -- See ErnieDocModel.
- Returns
Returns tuple (start_logits, end_logits, mem). With the fields:
start_logits (Tensor): Token classification logits indicating the start position of the labelled span. Its data type should be float32 and its shape is [batch_size, sequence_length].
end_logits (Tensor): Token classification logits indicating the end position of the labelled span. Its data type should be float32 and its shape is [batch_size, sequence_length].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and shape [batch_size, memory_length, hidden_size].
- Return type
tuple
Example
    import numpy as np
    import paddle
    from paddlenlp.transformers import ErnieDocForQuestionAnswering
    from paddlenlp.transformers import ErnieDocTokenizer

    def get_related_pos(insts, seq_len, memory_len=128):
        beg = seq_len + seq_len + memory_len
        r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                      list(range(0, seq_len)) for i in range(len(insts))]
        return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

    tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
    model = ErnieDocForQuestionAnswering.from_pretrained('ernie-doc-base-zh')

    inputs = tokenizer("欢迎使用百度飞桨!")
    inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

    memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
    position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
    attn_mask = paddle.ones([1, 128, 1])

    inputs['memories'] = memories
    inputs['position_ids'] = position_ids
    inputs['attn_mask'] = attn_mask

    outputs = model(**inputs)

    start_logits = outputs[0]
    end_logits = outputs[1]
    mem = outputs[2]
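start_logits and end_logits score each position independently, so a span still has to be chosen from them. A widely used heuristic, sketched here in plain Python and not taken from PaddleNLP, picks the pair (start, end) that maximizes start_logits[start] + end_logits[end], subject to end >= start and a bounded length. `best_span` and `max_answer_len` are hypothetical names.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Return (start, end), inclusive, maximizing the summed scores with start <= end."""
    n = len(start_logits)
    best, best_score = (0, 0), float("-inf")
    for i in range(n):
        # only ends at or after the start, and at most max_answer_len tokens long
        for j in range(i, min(n, i + max_answer_len)):
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best


print(best_span([0.1, 2.0, 0.3], [0.2, 0.1, 1.5]))  # (1, 2)
```

The end >= start constraint matters. With start_logits [0.0, 0.0, 3.0] and end_logits [5.0, 0.0, 1.0], the unconstrained best pair would be start 2 and end 0, which is not a valid span. The function instead returns (0, 0).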