modeling#
Modeling classes for the UnifiedTransformer model.
- class UnifiedTransformerPretrainedModel(*args, **kwargs)[source]#

An abstract class for pretrained UnifiedTransformer models. It provides UnifiedTransformer related model_config_file, resource_files_names, pretrained_resource_files_map, pretrained_init_configuration and base_model_prefix for downloading and loading pretrained models. See PretrainedModel for more details.

- config_class#

alias of UnifiedTransformerConfig
- base_model_class#
- class UnifiedTransformerModel(config: UnifiedTransformerConfig)[source]#

Bases: UnifiedTransformerPretrainedModel

The bare UnifiedTransformer Model outputting raw hidden-states.

This model inherits from PretrainedModel. Refer to the superclass documentation for the generic methods.

This model is also a paddle.nn.Layer subclass. Use it as a regular Paddle Layer and refer to the Paddle documentation for all matters related to general usage and behavior.
- set_input_embeddings(value)[source]#

Sets a new input embedding for the model.

- Parameters:
value (Embedding) -- the new input embedding of the model

- Raises:
NotImplementedError -- the model has not implemented the set_input_embeddings method
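A minimal sketch of swapping the input embedding is shown below; it assumes model.config exposes vocab_size and hidden_size attributes and that the replacement table keeps the original shape, so treat these details as illustrative rather than prescribed usage.

import paddle.nn as nn
from paddlenlp.transformers import UnifiedTransformerModel

model = UnifiedTransformerModel.from_pretrained('plato-mini')
# replace the token embedding with a freshly initialized table of the same shape
new_embedding = nn.Embedding(model.config.vocab_size, model.config.hidden_size)
model.set_input_embeddings(new_embedding)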
- forward(input_ids: Tensor | None = None, token_type_ids: Tensor | None = None, position_ids: Tensor | None = None, attention_mask: Tensor | None = None, use_cache: bool | None = None, cache: Tuple[Tensor] | None = None, role_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None)[source]#
The UnifiedTransformerModel forward method, overrides the special __call__() method.

- Parameters:
input_ids (Tensor, optional) -- Indices of input sequence tokens in the vocabulary. They are numerical representations of tokens that build the input sequence. Its data type should be int64 and it has a shape of [batch_size, sequence_length].

token_type_ids (Tensor) -- Segment token indices to indicate first and second portions of the inputs. Indices can be either 0 or 1: 0 corresponds to a sentence A token, 1 corresponds to a sentence B token. Its data type should be int64 and it has a shape of [batch_size, sequence_length].

position_ids (Tensor) -- The position indices of input sequence tokens. Its data type should be int64 and it has a shape of [batch_size, sequence_length].

attention_mask (Tensor) -- A tensor used in multi-head attention to prevent attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to [batch_size, n_head, sequence_length, sequence_length]. When the data type is bool, the unwanted positions have False values and the others have True values. When the data type is int, the unwanted positions have 0 values and the others have 1 values. When the data type is float, the unwanted positions have -INF values and the others have 0 values.

use_cache (bool, optional) -- Whether or not to use the model cache to speed up decoding. Defaults to False.

cache (list, optional) -- It is a list, and each element in the list is an incremental_cache produced by the paddle.nn.TransformerEncoderLayer.gen_cache() method. See the paddle.nn.TransformerEncoder.gen_cache() method for more details. It is only used for inference and should be None for training. Defaults to None.

role_ids (Tensor, optional) -- Indices of role ids to indicate different roles. Its data type should be int64 and it has a shape of [batch_size, sequence_length]. Defaults to None.

inputs_embeds (Tensor, optional) -- Optionally, instead of passing input_ids you can choose to directly pass an embedded representation of shape (batch_size, sequence_length, hidden_size). This is useful if you want more control over how to convert input_ids indices into associated vectors than the model's internal embedding lookup matrix. Defaults to None.

output_attentions (bool, optional) -- Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. Defaults to False.

output_hidden_states (bool, optional) -- Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail. Defaults to False.

return_dict (bool, optional) -- Whether to return a BaseModelOutputWithPastAndCrossAttentions object. If False, the output will be a tuple of tensors. Defaults to False.
- Returns:
An instance of BaseModelOutputWithPastAndCrossAttentions if return_dict=True. Otherwise it returns a tuple of tensors corresponding to the ordered and not-None (depending on the input arguments) fields of BaseModelOutputWithPastAndCrossAttentions. In particular, when return_dict=output_hidden_states=output_attentions=False and cache=None, it returns a tensor representing the output of UnifiedTransformerModel, with shape [batch_size, sequence_length, hidden_size]. The data type is float32 or float64.
Examples
from paddlenlp.transformers import UnifiedTransformerModel
from paddlenlp.transformers import UnifiedTransformerTokenizer

model = UnifiedTransformerModel.from_pretrained('plato-mini')
tokenizer = UnifiedTransformerTokenizer.from_pretrained('plato-mini')

history = '我爱祖国'
inputs = tokenizer.dialogue_encode(
    history,
    return_tensors=True,
    is_split_into_words=False)
outputs = model(**inputs)
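Continuing from the example above, a minimal sketch of requesting the structured output and per-layer hidden states is shown below; the attribute names last_hidden_state and hidden_states follow the BaseModelOutputWithPastAndCrossAttentions fields and should be treated as assumptions if your version differs.

# ask for the dict-like output and the hidden states of every layer
outputs = model(**inputs, return_dict=True, output_hidden_states=True)
last_hidden_state = outputs.last_hidden_state   # [batch_size, sequence_length, hidden_size]
all_hidden_states = outputs.hidden_states       # tuple with one tensor per layer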
- class UnifiedTransformerLMHeadModel(config: UnifiedTransformerConfig)[source]#

Bases: UnifiedTransformerPretrainedModel

The UnifiedTransformer Model with a language modeling head on top for generation tasks.

- Parameters:
unified_transformer (UnifiedTransformerModel) -- An instance of UnifiedTransformerModel.
- forward(input_ids: Tensor | None = None, token_type_ids: Tensor | None = None, position_ids: Tensor | None = None, attention_mask: Tensor | None = None, masked_positions: Tensor | None = None, use_cache: bool | None = None, cache: Tuple[Tensor] | None = None, role_ids: Tensor | None = None, labels: Tensor | None = None, inputs_embeds: Tensor | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None)[source]#
The UnifiedTransformerLMHeadModel forward method, overrides the special __call__() method.

- Parameters:
input_ids (Tensor, optional) -- See UnifiedTransformerModel.

token_type_ids (Tensor) -- See UnifiedTransformerModel.

position_ids (Tensor) -- See UnifiedTransformerModel.

attention_mask (Tensor) -- See UnifiedTransformerModel.

use_cache (bool, optional) -- See UnifiedTransformerModel.

cache (list, optional) -- See UnifiedTransformerModel.

role_ids (Tensor, optional) -- See UnifiedTransformerModel.

labels (Tensor, optional) -- Labels for computing the left-to-right language modeling loss. Indices should be in [-100, 0, ..., vocab_size] (see the input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., vocab_size].

inputs_embeds (Tensor, optional) -- See UnifiedTransformerModel.

output_attentions (bool, optional) -- See UnifiedTransformerModel.

output_hidden_states (bool, optional) -- See UnifiedTransformerModel.

return_dict (bool, optional) -- See UnifiedTransformerModel.
- Returns:
An instance of CausalLMOutputWithCrossAttentions if return_dict=True. Otherwise it returns a tuple of tensors corresponding to the ordered and not-None (depending on the input arguments) fields of CausalLMOutputWithCrossAttentions. In particular, when return_dict=output_hidden_states=output_attentions=False and cache=labels=None, it returns a tensor representing the output of UnifiedTransformerLMHeadModel, with shape [batch_size, sequence_length, vocab_size]. The data type is float32 or float64.
Examples
from paddlenlp.transformers import UnifiedTransformerLMHeadModel
from paddlenlp.transformers import UnifiedTransformerTokenizer

model = UnifiedTransformerLMHeadModel.from_pretrained('plato-mini')
tokenizer = UnifiedTransformerTokenizer.from_pretrained('plato-mini')

history = '我爱祖国'
inputs = tokenizer.dialogue_encode(
    history,
    return_tensors=True,
    is_split_into_words=False)
logits = model(**inputs)
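Continuing from the example above, a minimal sketch of producing a response with the generation API is shown below; the add_start_token_as_response flag, the decoding strategy, and its parameters (max_length, decode_strategy, top_k) are illustrative choices, not the library's prescribed settings.

# re-encode the history so a response start token is appended for generation
# (add_start_token_as_response is assumed here)
inputs = tokenizer.dialogue_encode(
    history,
    add_start_token_as_response=True,
    return_tensors=True,
    is_split_into_words=False)

# sample a response from the encoded dialogue history
ids, scores = model.generate(
    input_ids=inputs['input_ids'],
    token_type_ids=inputs['token_type_ids'],
    position_ids=inputs['position_ids'],
    attention_mask=inputs['attention_mask'],
    max_length=64,
    decode_strategy='sampling',
    top_k=5)

# ids holds the generated token ids, scores their sequence scores
response = tokenizer.convert_ids_to_string(ids[0].tolist())
print(response)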
- UnifiedTransformerForMaskedLM#