modeling

class ErnieDocModel(config: ErnieDocConfig)

Bases: ErnieDocPretrainedModel

get_input_embeddings()

Get the input embedding of the model.

Returns:

The embedding layer of the model.

Return type:

nn.Embedding

set_input_embeddings(value)

Set a new input embedding for the model.

Parameters:

value (Embedding) -- the new embedding of the model.

Raises:

NotImplementedError -- if the model does not implement the set_input_embeddings method.

forward(input_ids, memories, token_type_ids, position_ids, attn_mask)

The ErnieDocModel forward method. It overrides the __call__() special method.

Parameters:
  • input_ids (Tensor) -- Indices of input sequence tokens in the vocabulary. They are numerical representations of the tokens that build the input sequence. Its data type should be int64 and its shape is [batch_size, sequence_length, 1].

  • memories (List[Tensor]) -- A list of length n_layers, where each Tensor is the pre-computed hidden state of one layer. Each Tensor has dtype float32 and shape [batch_size, sequence_length, hidden_size].

  • token_type_ids (Tensor) --

    Segment token indices to indicate first and second portions of the inputs. Indices can be either 0 or 1:

    • 0 corresponds to a sentence A token,

    • 1 corresponds to a sentence B token.

    Its data type should be int64 and its shape is [batch_size, sequence_length, 1]. Defaults to None, which means no segment embeddings are added to the token embeddings.

  • position_ids (Tensor) -- Indices of the positions of each input sequence token in the position embeddings, selected in the range [0, config.max_position_embeddings - 1]. Its shape is (batch_size, num_tokens) and its dtype is int32 or int64.

  • attn_mask (Tensor) --

    Mask used in multi-head attention to prevent attention to unwanted positions, usually padding or subsequent positions. Its data type can be int, float or bool:

    • bool: masked positions are False and all others are True,

    • int: masked positions are 0 and all others are 1,

    • float: masked positions are -INF and all others are 0.

    Its shape must be broadcastable to [batch_size, num_attention_heads, sequence_length, sequence_length]; for example, [batch_size, sequence_length], [batch_size, sequence_length, sequence_length] or [batch_size, num_attention_heads, sequence_length, sequence_length]. ERNIE uses whole-word masking, so all tokens of one word share the same mask value: for the word "使用" ("use"), "使" and "用" get the same value. Defaults to None, which means no position is masked.
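
The three attn_mask conventions carry the same information. The pure-Python sketch below uses hypothetical helper names (int_mask_to_bool, int_mask_to_additive, masked_softmax are not PaddleNLP functions) to convert an integer mask into the bool and float forms, and shows why the float form uses -INF: adding it to attention scores before the softmax drives the weights of masked positions to zero.

```python
import math

# Hypothetical helpers (not part of PaddleNLP). Input: a 0/1 integer mask
# where 0 marks a masked position.
def int_mask_to_bool(mask):
    # True = attend, False = masked
    return [[v != 0 for v in row] for row in mask]

def int_mask_to_additive(mask):
    # 0.0 = attend, -inf = masked; meant to be added to attention scores
    return [[0.0 if v != 0 else -math.inf for v in row] for row in mask]

def masked_softmax(scores, additive_mask):
    # Softmax over one row of scores after adding the float mask.
    # Assumes at least one position in the row is unmasked.
    shifted = [s + m for s, m in zip(scores, additive_mask)]
    top = max(shifted)
    exps = [math.exp(s - top) for s in shifted]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# The last two positions are padding
int_mask = [[1, 1, 1, 0, 0]]
print(int_mask_to_bool(int_mask))  # [[True, True, True, False, False]]
print(masked_softmax([0.5, 1.0, 0.5, 3.0, 2.0], int_mask_to_additive(int_mask)[0]))
```

The masked positions receive zero attention weight even though their raw scores (3.0 and 2.0) are the largest in the row.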

Returns:

Returns tuple (encoder_output, pooled_output, new_mem).

With the fields:

  • encoder_output (Tensor):

    Sequence of hidden states at the last layer of the model. Its data type is float32 and its shape is [batch_size, sequence_length, hidden_size].

  • pooled_output (Tensor):

    The output of the first token ([CLS]) in the sequence. The model is "pooled" by simply taking the hidden state corresponding to the first token. Its data type is float32 and its shape is [batch_size, hidden_size].

  • new_mem (List[Tensor]):

    A list of pre-computed hidden states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and shape [batch_size, memory_length, hidden_size].
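
The usual pattern, assumed here, is to pass new_mem back as memories when the next segment of the same document is processed, so that context is carried across segments. The pure-Python sketch below shows only that loop. fake_model, split_into_segments and encode_document are hypothetical names, and fake_model is a stand-in that merely records the current segment instead of computing hidden states.

```python
def fake_model(segment, memories):
    # Hypothetical stand-in for ErnieDocModel: ignores the real math and just
    # "remembers" the current segment, so the loop runs without Paddle.
    new_mem = [list(segment) for _ in memories]
    return "encoder_output", "pooled_output", new_mem

def split_into_segments(token_ids, seq_len, pad_id=0):
    # Cut a long document into fixed-length segments, padding the last one
    segments = []
    for start in range(0, len(token_ids), seq_len):
        seg = token_ids[start:start + seq_len]
        segments.append(seg + [pad_id] * (seq_len - len(seg)))
    return segments

def encode_document(token_ids, seq_len, n_layers):
    # Start from all-zero memories, like the paddle.zeros memories in the example
    memories = [[0] * seq_len for _ in range(n_layers)]
    seen = []
    for segment in split_into_segments(token_ids, seq_len):
        seen.append(memories[0])  # memory visible to this segment (layer 0)
        _, _, new_mem = fake_model(segment, memories)
        memories = new_mem  # carry this segment's states into the next call
    return seen, memories

seen, final_mem = encode_document(list(range(1, 8)), seq_len=3, n_layers=2)
print(seen)  # [[0, 0, 0], [1, 2, 3], [4, 5, 6]]
```

Each segment sees the memory produced by the segment before it; only the first segment starts from zeros.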

Return type:

tuple

Example

import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocModel
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds relative position ids per sample: (seq_len + memory_len)
    # descending ids followed by seq_len ascending ids.
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) + list(range(0, seq_len))
                  for _ in range(len(insts))]
    # Shape: [batch_size, 2 * seq_len + memory_len, 1]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocModel.from_pretrained('ernie-doc-base-zh')

inputs = tokenizer("欢迎使用百度飞桨!")
# Pad each field to 128 tokens and add a trailing axis: shape [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

encoder_output = outputs[0]
pooled_output = outputs[1]
new_mem = outputs[2]
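
In the example, position_ids is longer than the input: with seq_len=128 and memory_len=128, get_related_pos returns 2 * 128 + 128 = 384 positions per sample. The dependency-free sketch below (related_positions is a hypothetical name) reproduces the same layout at small sizes so the pattern is easy to read.

```python
def related_positions(batch_size, seq_len, memory_len):
    # Pure-Python mirror of get_related_pos from the example above:
    # (seq_len + memory_len) descending ids, then seq_len ascending ids.
    total = 2 * seq_len + memory_len
    row = list(range(total - 1, seq_len - 1, -1)) + list(range(seq_len))
    # Nest as [batch_size, total, 1], like the reshape in get_related_pos
    return [[[p] for p in row] for _ in range(batch_size)]

pos = related_positions(batch_size=1, seq_len=3, memory_len=2)
print([p[0] for p in pos[0]])  # [7, 6, 5, 4, 3, 0, 1, 2]
```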

class ErnieDocPretrainedModel(*args, **kwargs)

Bases: PretrainedModel

An abstract class for pretrained ErnieDoc models. It provides the ErnieDoc-related model_config_file, pretrained_init_configuration, resource_files_names, pretrained_resource_files_map and base_model_prefix used for downloading and loading pretrained models. See PretrainedModel for more details.

config_class

alias of ErnieDocConfig

base_model_class

alias of ErnieDocModel

class ErnieDocForSequenceClassification(config: ErnieDocConfig)

Bases: ErnieDocPretrainedModel

ErnieDoc Model with a linear layer on top of the output layer, designed for sequence classification/regression tasks like GLUE tasks.

Parameters:

config (ErnieDocConfig) -- An instance of ErnieDocConfig used to construct ErnieDocForSequenceClassification.

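forward(input_ids, memories, token_type_ids, position_ids, attn_mask)

The ErnieDocForSequenceClassification forward method. It overrides the __call__() special method.

Parameters:

  • input_ids, memories, token_type_ids, position_ids, attn_mask -- See ErnieDocModel.forward().

Returns: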

Returns tuple (logits, mem).

With the fields:

  • logits (Tensor):

    A tensor containing the classification logits computed from the [CLS] hidden state of the last layer. Its data type is float32 and its shape is [batch_size, num_labels].

  • mem (List[Tensor]):

    A list of pre-computed hidden-states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and has a shape of [batch_size, memory_length, hidden_size].
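
logits holds unnormalized scores. A common way to turn them into predictions is a softmax per row followed by an argmax. The sketch below works on nested Python lists rather than Paddle tensors, and softmax and predict_labels are hypothetical helpers, not PaddleNLP functions.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of logits
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def predict_labels(logits):
    # logits: [batch_size, num_labels] as nested lists
    probs = [softmax(row) for row in logits]
    labels = [max(range(len(p)), key=p.__getitem__) for p in probs]
    return labels, probs

labels, probs = predict_labels([[0.2, 1.3], [2.0, -1.0]])
print(labels)  # [1, 0]
```

The argmax alone does not need the softmax; the probabilities are useful when a confidence threshold is applied.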

Return type:

tuple

Example

import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForSequenceClassification
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds relative position ids per sample: (seq_len + memory_len)
    # descending ids followed by seq_len ascending ids.
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) + list(range(0, seq_len))
                  for _ in range(len(insts))]
    # Shape: [batch_size, 2 * seq_len + memory_len, 1]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForSequenceClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)

inputs = tokenizer("欢迎使用百度飞桨!")
# Pad each field to 128 tokens and add a trailing axis: shape [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

logits = outputs[0]
mem = outputs[1]

class ErnieDocForTokenClassification(config: ErnieDocConfig)

Bases: ErnieDocPretrainedModel

ErnieDoc Model with a linear layer on top of the hidden-states output layer, designed for token classification tasks like NER tasks.

Parameters:

config (ErnieDocConfig) -- An instance of ErnieDocConfig used to construct ErnieDocForTokenClassification.

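forward(input_ids, memories, token_type_ids, position_ids, attn_mask)

The ErnieDocForTokenClassification forward method. It overrides the __call__() special method.

Parameters:

  • input_ids, memories, token_type_ids, position_ids, attn_mask -- See ErnieDocModel.forward().

Returns: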

Returns tuple (logits, mem).

With the fields:

  • logits (Tensor):

    A tensor containing the token classification logits computed from the last-layer hidden states. Its data type is float32 and its shape is [batch_size, sequence_length, num_labels].

  • mem (List[Tensor]):

    A list of pre-computed hidden-states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and has a shape of [batch_size, memory_length, hidden_size].
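
Each position in logits scores every label for one token. A minimal sketch of turning that into per-token predictions, assuming the logits have been converted to nested Python lists and that padding positions are dropped using a 0/1 mask (predict_token_labels is a hypothetical helper, not a PaddleNLP function):

```python
def predict_token_labels(logits, attention_mask):
    # logits: [batch_size, sequence_length, num_labels]
    # attention_mask: [batch_size, sequence_length], 1 = real token, 0 = padding
    predictions = []
    for seq_logits, seq_mask in zip(logits, attention_mask):
        # Argmax over labels for every real (non-padding) token
        labels = [
            max(range(len(scores)), key=scores.__getitem__)
            for scores, keep in zip(seq_logits, seq_mask)
            if keep
        ]
        predictions.append(labels)
    return predictions

logits = [[[0.1, 0.9], [0.8, 0.2], [0.5, 0.4]]]
print(predict_token_labels(logits, [[1, 1, 0]]))  # [[1, 0]]
```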

Return type:

tuple

Example

import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForTokenClassification
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds relative position ids per sample: (seq_len + memory_len)
    # descending ids followed by seq_len ascending ids.
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) + list(range(0, seq_len))
                  for _ in range(len(insts))]
    # Shape: [batch_size, 2 * seq_len + memory_len, 1]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForTokenClassification.from_pretrained('ernie-doc-base-zh', num_labels=2)

inputs = tokenizer("欢迎使用百度飞桨!")
# Pad each field to 128 tokens and add a trailing axis: shape [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

logits = outputs[0]
mem = outputs[1]

class ErnieDocForQuestionAnswering(config: ErnieDocConfig)

Bases: ErnieDocPretrainedModel

ErnieDoc Model with a linear layer on top of the hidden-states output to compute span_start_logits and span_end_logits, designed for question-answering tasks like SQuAD.

Parameters:

config (ErnieDocConfig) -- An instance of ErnieDocConfig used to construct ErnieDocForQuestionAnswering.

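forward(input_ids, memories, token_type_ids, position_ids, attn_mask)

The ErnieDocForQuestionAnswering forward method. It overrides the __call__() special method.

Parameters:

  • input_ids, memories, token_type_ids, position_ids, attn_mask -- See ErnieDocModel.forward().

Returns: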

Returns tuple (start_logits, end_logits, mem).

With the fields:

  • start_logits (Tensor):

    A tensor of per-token classification logits indicating the start position of the labelled span. Its data type is float32 and its shape is [batch_size, sequence_length].

  • end_logits (Tensor):

    A tensor of per-token classification logits indicating the end position of the labelled span. Its data type is float32 and its shape is [batch_size, sequence_length].

  • mem (List[Tensor]):

    A list of pre-computed hidden-states. The length of the list is n_layers. Each element in the list is a Tensor with dtype float32 and has a shape of [batch_size, memory_length, hidden_size].
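
start_logits and end_logits are usually decoded together. A common heuristic, assumed here rather than taken from PaddleNLP's own post-processing, picks the span (i, j) with i <= j and a bounded length that maximizes start_logits[i] + end_logits[j]. best_span and max_answer_len are hypothetical names, and the sketch handles one sequence stored as plain Python lists.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    # Score every span (i, j) with i <= j < i + max_answer_len as
    # start_logits[i] + end_logits[j] and keep the highest-scoring one.
    best = (0, 0)
    best_score = float("-inf")
    for i, s in enumerate(start_logits):
        for j in range(i, min(i + max_answer_len, len(end_logits))):
            score = s + end_logits[j]
            if score > best_score:
                best_score, best = score, (i, j)
    return best, best_score

span, score = best_span([0.1, 2.0, 0.3, 0.0], [0.0, 0.5, 1.5, 0.2])
print(span)  # (1, 2)
```

Restricting j >= i rules out spans that end before they start; max_answer_len caps the quadratic search and avoids implausibly long answers.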

Return type:

tuple

Example

import numpy as np
import paddle
from paddlenlp.transformers import ErnieDocForQuestionAnswering
from paddlenlp.transformers import ErnieDocTokenizer

def get_related_pos(insts, seq_len, memory_len=128):
    # Builds relative position ids per sample: (seq_len + memory_len)
    # descending ids followed by seq_len ascending ids.
    beg = seq_len + seq_len + memory_len
    r_position = [list(range(beg - 1, seq_len - 1, -1)) + list(range(0, seq_len))
                  for _ in range(len(insts))]
    # Shape: [batch_size, 2 * seq_len + memory_len, 1]
    return np.array(r_position).astype('int64').reshape([len(insts), beg, 1])

tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh')
model = ErnieDocForQuestionAnswering.from_pretrained('ernie-doc-base-zh')

inputs = tokenizer("欢迎使用百度飞桨!")
# Pad each field to 128 tokens and add a trailing axis: shape [1, 128, 1]
inputs = {k: paddle.to_tensor([v + [0] * (128 - len(v))]).unsqueeze(-1) for (k, v) in inputs.items()}

# One zero-initialized memory per layer (12 layers, hidden size 768)
memories = [paddle.zeros([1, 128, 768], dtype="float32") for _ in range(12)]
position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128))
attn_mask = paddle.ones([1, 128, 1])

inputs['memories'] = memories
inputs['position_ids'] = position_ids
inputs['attn_mask'] = attn_mask

outputs = model(**inputs)

start_logits = outputs[0]
end_logits = outputs[1]
mem = outputs[2]