modeling
- class ErnieDocModel(config: ErnieDocConfig)
Bases: ErnieDocPretrainedModel
- get_input_embeddings()
Get the input embedding of the model.
- Returns:
The input embedding of the model.
- Return type:
nn.Embedding
- set_input_embeddings(value)
Set a new input embedding for the model.
- Parameters:
value (Embedding) – The new input embedding of the model.
- Raises:
NotImplementedError – If the model does not implement the set_input_embeddings method.
- forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocModel forward method, which overrides the __call__() special method.
- Parameters:
input_ids (Tensor) – Indices of input sequence tokens in the vocabulary. They are numerical representations of the tokens that make up the input sequence. Its data type should be int64 and it has a shape of [batch_size, sequence_length, 1].
memories (List[Tensor]) – A list of length n_layers, where each Tensor is the pre-computed hidden state for one layer. Each Tensor has a dtype of float32 and a shape of [batch_size, memory_length, hidden_size].
token_type_ids (Tensor) – Segment token indices that indicate the first and second portions of the inputs. Indices can be either 0 or 1:
0 corresponds to a sentence A token,
1 corresponds to a sentence B token.
Its data type should be int64 and it has a shape of [batch_size, sequence_length, 1]. Defaults to None, which means no segment embeddings are added to the token embeddings.
position_ids (Tensor) – Indices of the positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1]. Shape as (batch_size, num_tokens) and dtype as int32 or int64. The examples on this page pass position ids of shape [batch_size, 2 * sequence_length + memory_length, 1], built by get_related_pos.
attn_mask (Tensor) – Mask used in multi-head attention to avoid attending to unwanted positions, usually the paddings or the subsequent positions. Its data type can be int, float or bool. When the data type is bool, the masked tokens have False values and the others have True values. When the data type is int, the masked tokens have 0 values and the others have 1 values. When the data type is float, the masked tokens have -INF values and the others have 0 values. It is a tensor with a shape broadcastable to [batch_size, num_attention_heads, sequence_length, sequence_length]. For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length] or [batch_size, num_attention_heads, sequence_length, sequence_length]. We use whole-word masking in ERNIE, so all tokens of a whole word have the same value. For example, for the word "使用", "使" and "用" have the same value. Defaults to None, which means no positions are masked.
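The three attn_mask conventions carry the same information. The NumPy sketch below, which does not use Paddle, shows that a bool or int keep/mask array converts to the additive float form, and that adding that float mask before a softmax drives the weight of masked positions to zero. The helpers bool_to_additive and masked_softmax are hypothetical names for illustration, not PaddleNLP functions.

```python
import numpy as np

# Three encodings of the same mask for a 4-token sequence whose last token is padding.
bool_mask = np.array([True, True, True, False])                    # False = masked
int_mask = np.array([1, 1, 1, 0], dtype="int64")                    # 0 = masked
float_mask = np.array([0.0, 0.0, 0.0, -np.inf], dtype="float32")    # -inf = masked


def bool_to_additive(mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map keep (True) -> 0.0 and masked (False) -> -inf."""
    return np.where(mask, 0.0, -np.inf).astype("float32")


def masked_softmax(scores: np.ndarray, additive_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: softmax over the last axis after adding an additive mask."""
    z = scores + additive_mask
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability; -inf stays -inf
    e = np.exp(z)                          # exp(-inf) == 0 for masked positions
    return e / e.sum(axis=-1, keepdims=True)


scores = np.array([0.5, 1.0, -0.3, 2.0], dtype="float32")
probs = masked_softmax(scores, bool_to_additive(bool_mask))
print(probs.round(3))  # the masked (last) position gets probability 0
```

An int mask can be handled the same way by first casting it with int_mask.astype(bool).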
- Returns:
Returns a tuple (encoder_output, pooled_output, new_mem), with the fields:
encoder_output (Tensor): Sequence of hidden states at the last layer of the model. Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
pooled_output (Tensor): The output of the first token ([CLS]) in the sequence. We "pool" the model by simply taking the hidden state corresponding to the first token. Its data type should be float32 and its shape is [batch_size, hidden_size].
new_mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and a shape of [batch_size, memory_length, hidden_size].
- Return type:
tuple
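Each entry of new_mem has the same [batch_size, memory_length, hidden_size] shape as the corresponding entry of memories. One common way to build such a memory, borrowed from Transformer-XL and assumed here rather than confirmed for ErnieDoc's internals, is to concatenate the previous memory with the current segment's hidden states and keep only the most recent memory_length positions. The sketch below does this for a single layer in NumPy; cache_mem is a hypothetical name.

```python
import numpy as np


def cache_mem(curr_out, prev_mem, mem_len):
    """Assumed Transformer-XL-style memory update for one layer (not PaddleNLP code).

    curr_out: [batch_size, sequence_length, hidden_size], current segment's hidden states
    prev_mem: [batch_size, memory_length, hidden_size], or None for the first segment
    Returns the last `mem_len` positions of concat([prev_mem, curr_out]) along axis 1.
    """
    if prev_mem is None:
        return curr_out[:, -mem_len:, :]
    return np.concatenate([prev_mem, curr_out], axis=1)[:, -mem_len:, :]


batch_size, seq_len, mem_len, hidden = 2, 128, 128, 8
prev = np.zeros((batch_size, mem_len, hidden), dtype="float32")
curr = np.ones((batch_size, seq_len, hidden), dtype="float32")
new_mem = cache_mem(curr, prev, mem_len)
print(new_mem.shape)  # (2, 128, 8)
```

With sequence_length equal to memory_length, as in the examples on this page, the new memory under this assumption is exactly the current segment's hidden states.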
Example
```python
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocModel
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds position ids of shape [len(insts), 2 * seq_len + memory_len, 1]
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                  list(range(0, seq_len)) for _ in range(len(insts))]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocModel.from_pretrained('ernie-doc-base-zh')

# Chinese input text: "Welcome to Baidu PaddlePaddle!"
inputs = tokenizer("欢迎使用百度飞桨!")
# Pad every field to 128 tokens and add a trailing axis: [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

encoder_output = outputs[0]
pooled_output = outputs[1]
new_mem = outputs[2]
```
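For a document longer than one segment, new_mem from one call is intended to be passed back as memories on the next call, so information flows across segments. The loop below sketches that pattern with NumPy only. fake_model is a hypothetical stand-in for model(**inputs) so the sketch runs without Paddle; in real code you would call the model with paddle.to_tensor inputs and also pass position_ids and attn_mask, which are omitted here for brevity. split_into_segments is likewise a hypothetical helper.

```python
import numpy as np

SEQ_LEN = 128   # tokens per segment, as in the example above
MEM_LEN = 128   # memory length, equal to SEQ_LEN in the example above
N_LAYERS = 12   # the example creates 12 memories for ernie-doc-base-zh
HIDDEN = 768


def split_into_segments(token_ids, seq_len=SEQ_LEN, pad_id=0):
    """Hypothetical helper: split token ids into fixed-length, zero-padded segments."""
    segments = []
    for start in range(0, len(token_ids), seq_len):
        chunk = token_ids[start:start + seq_len]
        segments.append(chunk + [pad_id] * (seq_len - len(chunk)))
    return segments


def fake_model(input_ids, memories):
    """Hypothetical stand-in for model(**inputs) -> (encoder_output, pooled_output, new_mem)."""
    batch = input_ids.shape[0]
    encoder_output = np.zeros((batch, SEQ_LEN, HIDDEN), dtype="float32")
    pooled_output = np.zeros((batch, HIDDEN), dtype="float32")
    new_mem = [m + 1.0 for m in memories]  # +1 just records that the memory was updated
    return encoder_output, pooled_output, new_mem


document = list(range(1, 301))  # 300 placeholder token ids -> 3 segments
memories = [np.zeros((1, MEM_LEN, HIDDEN), dtype="float32") for _ in range(N_LAYERS)]
for segment in split_into_segments(document):
    input_ids = np.array([segment], dtype="int64")[..., None]      # [1, SEQ_LEN, 1]
    _, pooled_output, memories = fake_model(input_ids, memories)   # carry new_mem forward

print(len(memories), memories[0][0, 0, 0])  # 12 3.0
```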
- class ErnieDocPretrainedModel(*args, **kwargs)
Bases: PretrainedModel
An abstract class for pretrained ErnieDoc models. It provides the ErnieDoc-related model_config_file, pretrained_init_configuration, resource_files_names, pretrained_resource_files_map and base_model_prefix for downloading and loading pretrained models. See PretrainedModel for more details.
- config_class
alias of ErnieDocConfig
- base_model_class
alias of ErnieDocModel
- class ErnieDocForSequenceClassification(config: ErnieDocConfig)
Bases: ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the output layer, designed for sequence classification/regression tasks like GLUE tasks.
- Parameters:
config (ErnieDocConfig) – An instance of ErnieDocConfig used to construct ErnieDocForSequenceClassification.
- forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForSequenceClassification forward method, which overrides the __call__() special method.
- Parameters:
input_ids (Tensor) – See ErnieDocModel.
memories (List[Tensor]) – See ErnieDocModel.
token_type_ids (Tensor) – See ErnieDocModel.
position_ids (Tensor) – See ErnieDocModel.
attn_mask (Tensor) – See ErnieDocModel.
- Returns:
Returns a tuple (logits, mem), with the fields:
logits (Tensor): The classification logits computed by the linear layer from the [CLS] representation at the output of the last layer. It has a data type of float32 and a shape of [batch_size, num_labels].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and a shape of [batch_size, memory_length, hidden_size].
- Return type:
tuple
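The logits are unnormalized scores. For single-label classification, a softmax over the last axis gives class probabilities and an argmax gives the predicted label id. A minimal NumPy sketch, assuming the logits have already been converted with logits.numpy(); logits_to_labels is a hypothetical helper name.

```python
import numpy as np


def logits_to_labels(logits):
    """Hypothetical helper: [batch_size, num_labels] logits -> (probabilities, label ids)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    return probs, probs.argmax(axis=-1)


logits = np.array([[2.0, -1.0], [0.1, 0.3]], dtype="float32")  # batch_size=2, num_labels=2
probs, labels = logits_to_labels(logits)
print(labels)  # [0 1]
```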
Example
```python
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForSequenceClassification
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds position ids of shape [len(insts), 2 * seq_len + memory_len, 1]
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                  list(range(0, seq_len)) for _ in range(len(insts))]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForSequenceClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)

# Chinese input text: "Welcome to Baidu PaddlePaddle!"
inputs = tokenizer("欢迎使用百度飞桨!")
# Pad every field to 128 tokens and add a trailing axis: [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

logits = outputs[0]
mem = outputs[1]
```
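The examples on this page pad each tokenizer field with [0] * (128 - len(v)). If the tokenized text is already longer than 128 tokens, the repeat count is negative, Python produces an empty list, and the field stays longer than the 128 positions that memories, position_ids and attn_mask assume. A small NumPy sketch of a pad-or-truncate helper, assuming pad id 0 as in the examples; pad_or_truncate is a hypothetical name, and its result can be wrapped with paddle.to_tensor. For genuinely long documents, splitting into consecutive segments preserves more text than truncation.

```python
import numpy as np


def pad_or_truncate(ids, seq_len=128, pad_id=0):
    """Hypothetical helper: return ids as an int64 array of shape [1, seq_len, 1]."""
    ids = list(ids)[:seq_len]                  # truncate if too long
    ids = ids + [pad_id] * (seq_len - len(ids))  # pad if too short
    return np.array([ids], dtype="int64")[..., None]


short_ids = pad_or_truncate([1, 2, 3, 4])
long_ids = pad_or_truncate(list(range(200)))
print(short_ids.shape, long_ids.shape)  # (1, 128, 1) (1, 128, 1)
```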
- class ErnieDocForTokenClassification(config: ErnieDocConfig)
Bases: ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the hidden-states output layer, designed for token classification tasks like NER tasks.
- Parameters:
config (ErnieDocConfig) – An instance of ErnieDocConfig used to construct ErnieDocForTokenClassification.
- forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForTokenClassification forward method, which overrides the __call__() special method.
- Parameters:
input_ids (Tensor) – See ErnieDocModel.
memories (List[Tensor]) – See ErnieDocModel.
token_type_ids (Tensor) – See ErnieDocModel. Defaults to None, which means no segment embeddings are added to the token embeddings.
position_ids (Tensor) – See ErnieDocModel.
attn_mask (Tensor) – See ErnieDocModel.
- Returns:
Returns a tuple (logits, mem), with the fields:
logits (Tensor): The per-token classification logits computed by the linear layer from the hidden states at the output of the last layer. It has a data type of float32 and a shape of [batch_size, sequence_length, num_labels].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and a shape of [batch_size, memory_length, hidden_size].
- Return type:
tuple
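To turn the [batch_size, sequence_length, num_labels] logits into one label per real token, take an argmax over the last axis and drop the padded positions. A NumPy sketch, assuming padding id 0 as in the examples on this page and input_ids of shape [batch_size, sequence_length, 1]; decode_token_labels is a hypothetical helper name.

```python
import numpy as np


def decode_token_labels(logits, input_ids, pad_id=0):
    """Hypothetical helper: per-token argmax, keeping only non-padding positions."""
    pred = logits.argmax(axis=-1)        # [batch_size, sequence_length]
    ids = input_ids.reshape(pred.shape)  # drop the trailing size-1 axis
    return [row_pred[row_ids != pad_id].tolist() for row_pred, row_ids in zip(pred, ids)]


logits = np.array([[[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]], dtype="float32")
input_ids = np.array([[[5], [6], [7], [0]]], dtype="int64")  # last position is padding
print(decode_token_labels(logits, input_ids))  # [[0, 1, 0]]
```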
Example
```python
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForTokenClassification
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds position ids of shape [len(insts), 2 * seq_len + memory_len, 1]
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                  list(range(0, seq_len)) for _ in range(len(insts))]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForTokenClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)

# Chinese input text: "Welcome to Baidu PaddlePaddle!"
inputs = tokenizer("欢迎使用百度飞桨!")
# Pad every field to 128 tokens and add a trailing axis: [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

logits = outputs[0]
mem = outputs[1]
```
- class ErnieDocForQuestionAnswering(config: ErnieDocConfig)
Bases: ErnieDocPretrainedModel
ErnieDoc Model with a linear layer on top of the hidden-states output to compute span_start_logits and span_end_logits, designed for question-answering tasks like SQuAD.
- Parameters:
config (ErnieDocConfig) – An instance of ErnieDocConfig used to construct ErnieDocForQuestionAnswering.
- forward(input_ids, memories, token_type_ids, position_ids, attn_mask)
The ErnieDocForQuestionAnswering forward method, which overrides the __call__() special method.
- Parameters:
input_ids (Tensor) – See ErnieDocModel.
memories (List[Tensor]) – See ErnieDocModel.
token_type_ids (Tensor) – See ErnieDocModel.
position_ids (Tensor) – See ErnieDocModel.
attn_mask (Tensor) – See ErnieDocModel.
- Returns:
Returns a tuple (start_logits, end_logits, mem), with the fields:
start_logits (Tensor): Per-token classification logits indicating the start position of the labelled span. Its data type should be float32 and its shape is [batch_size, sequence_length].
end_logits (Tensor): Per-token classification logits indicating the end position of the labelled span. Its data type should be float32 and its shape is [batch_size, sequence_length].
mem (List[Tensor]): A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and a shape of [batch_size, memory_length, hidden_size].
- Return type:
tuple
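A widely used way to decode an answer span from start_logits and end_logits, common in SQuAD-style pipelines but not necessarily what PaddleNLP's own ErnieDoc scripts do, is to pick the pair (start, end) that maximizes start_logits[start] + end_logits[end] subject to start <= end and a maximum answer length. The quadratic NumPy sketch below handles one example of shape [sequence_length]; best_span and max_answer_len are hypothetical names.

```python
import numpy as np


def best_span(start_logits, end_logits, max_answer_len=30):
    """Hypothetical helper: return (start, end, score) maximizing start + end logits,
    subject to start <= end < start + max_answer_len."""
    best = (0, 0, -np.inf)
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best[2]:
                best = (s, e, float(score))
    return best


start = np.array([0.1, 2.0, 0.3, -1.0, 0.0], dtype="float32")
end = np.array([0.0, 0.2, 0.1, 1.5, 3.0], dtype="float32")
print(best_span(start, end, max_answer_len=3))  # (1, 3, 3.5)
```

In a real pipeline you would also exclude positions outside the context (question tokens, padding) before searching.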
Example
```python
import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForQuestionAnswering
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds position ids of shape [len(insts), 2 * seq_len + memory_len, 1]
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) +
                  list(range(0, seq_len)) for _ in range(len(insts))]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForQuestionAnswering.from_pretrained('ernie-doc-base-zh')

# Chinese input text: "Welcome to Baidu PaddlePaddle!"
inputs = tokenizer("欢迎使用百度飞桨!")
# Pad every field to 128 tokens and add a trailing axis: [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

start_logits = outputs[0]
end_logits = outputs[1]
mem = outputs[2]
```