modeling
- class ProphetNetModel(config: ProphetNetConfig)
Bases: ProphetNetPretrainedModel
- get_input_embeddings()
Get the input embedding of the model.
- Returns:
The input embedding of the model.
- Return type:
nn.Embedding
- set_input_embeddings(value)
Set a new input embedding for the model.
- Parameters:
value (Embedding) – The new input embedding of the model.
- Raises:
NotImplementedError – If the model does not implement the set_input_embeddings method.
- forward(input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output: Tuple | None = None, use_cache=True, past_key_values=None)
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
*inputs (tuple) – unpacked tuple arguments
**kwargs (dict) – unpacked dict arguments
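The get_input_embeddings / set_input_embeddings pair described above follows a simple accessor contract: a base class raises NotImplementedError unless a subclass overrides the method. The following is a minimal pure-Python sketch of that contract, assuming a list-based embedding table; BaseModel and ToyModel are hypothetical stand-ins, not PaddleNLP classes, and the real methods operate on a paddle.nn.Embedding.

```python
class BaseModel:
    """Hypothetical base class mimicking the accessor contract (not PaddleNLP code)."""

    def get_input_embeddings(self):
        # Subclasses are expected to expose their embedding table.
        raise NotImplementedError("get_input_embeddings is not implemented")

    def set_input_embeddings(self, value):
        # Subclasses that support swapping embeddings must override this.
        raise NotImplementedError("set_input_embeddings is not implemented")


class ToyModel(BaseModel):
    """Hypothetical model whose 'embedding' is a list of row vectors."""

    def __init__(self, vocab_size, hidden_size):
        # Zero-initialized table of shape (vocab_size, hidden_size).
        self.word_embeddings = [[0.0] * hidden_size for _ in range(vocab_size)]

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, value):
        # Replace the table; later token lookups use `value`.
        self.word_embeddings = value


model = ToyModel(vocab_size=4, hidden_size=2)
new_table = [[float(i), float(i)] for i in range(4)]
model.set_input_embeddings(new_table)
```

After the call, get_input_embeddings() returns the very object that was passed in, which is what makes it possible to share or tie embedding weights between components.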
- class ProphetNetPretrainedModel(*args, **kwargs)
Bases: PretrainedModel
An abstract class for pretrained ProphetNet models. It provides the ProphetNet-related model_config_file, pretrained_init_configuration, resource_files_names, pretrained_resource_files_map, and base_model_prefix attributes used for downloading and loading pretrained models.
- config_class
alias of ProphetNetConfig
- base_model_class
alias of ProphetNetModel
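config_class and base_model_class are class-level aliases that let generic loading code find the matching configuration class and backbone class without hard-coding them. The sketch below shows that wiring in pure Python under stated assumptions: ToyConfig, ToyPretrainedModel, ToyModel and the build_from_dict classmethod are all hypothetical, and the sketch does not reproduce PretrainedModel's real loading logic (which also resolves and downloads resource files).

```python
class ToyConfig:
    """Hypothetical configuration object holding model hyperparameters."""

    def __init__(self, vocab_size=10, hidden_size=4):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


class ToyPretrainedModel:
    """Hypothetical abstract base: subclasses inherit these class attributes."""

    config_class = ToyConfig
    base_model_prefix = "toy"  # name a wrapper would store the backbone under (unused here)
    base_model_class = None  # filled in once the backbone class exists

    @classmethod
    def build_from_dict(cls, config_dict):
        # Generic code only needs `cls.config_class` to turn a dict into a config.
        config = cls.config_class(**config_dict)
        return cls(config)


class ToyModel(ToyPretrainedModel):
    """Hypothetical backbone model."""

    def __init__(self, config):
        self.config = config


# Mirrors "base_model_class: alias of ProphetNetModel": set after the backbone is defined.
ToyPretrainedModel.base_model_class = ToyModel

model = ToyModel.build_from_dict({"vocab_size": 32, "hidden_size": 8})
```

Assigning base_model_class after the backbone is defined avoids a forward reference, since the base class is declared before the model class that it points to.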
- class ProphetNetEncoder(word_embeddings, config: ProphetNetConfig)
Bases: ProphetNetPretrainedModel
- word_embeddings (paddle.nn.Embedding of shape (config.vocab_size, config.hidden_size), optional): The word embedding parameters. These can be used to initialize ProphetNetEncoder with pre-defined word embeddings instead of randomly initialized ones.
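Because ProphetNetEncoder and ProphetNetDecoder both take word_embeddings as a constructor argument, a caller can pass one embedding object to both so that they share weights. The pure-Python sketch below illustrates that idea with hypothetical ToyEmbedding, ToyEncoder and ToyDecoder classes; whether ProphetNetModel wires its own encoder and decoder exactly this way is an assumption, not something this page states.

```python
class ToyEmbedding:
    """Hypothetical lookup table of shape (vocab_size, hidden_size)."""

    def __init__(self, vocab_size, hidden_size):
        # Each row is a separate list, so rows can be updated independently.
        self.weight = [[0.0] * hidden_size for _ in range(vocab_size)]

    def __call__(self, input_ids):
        # Map each token id to its row vector.
        return [self.weight[i] for i in input_ids]


class ToyEncoder:
    """Hypothetical encoder that only stores the embedding it was given."""

    def __init__(self, word_embeddings):
        self.word_embeddings = word_embeddings


class ToyDecoder:
    """Hypothetical decoder that only stores the embedding it was given."""

    def __init__(self, word_embeddings):
        self.word_embeddings = word_embeddings


shared = ToyEmbedding(vocab_size=5, hidden_size=3)
encoder = ToyEncoder(shared)
decoder = ToyDecoder(shared)

# Updating the table through the encoder is visible to the decoder,
# because both hold a reference to the same object.
encoder.word_embeddings.weight[2] = [1.0, 2.0, 3.0]
```

Passing a reference rather than a copy is the design choice that matters here: one parameter table serves both sides, so a token has the same vector on the source and target side.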
- class ProphetNetDecoder(word_embeddings, config: ProphetNetConfig)
Bases: ProphetNetPretrainedModel
- forward(input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=True)
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
*inputs (tuple) – unpacked tuple arguments
**kwargs (dict) – unpacked dict arguments
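The use_cache and past_key_values arguments of the decoder's forward point to incremental decoding: keys and values computed for earlier positions are returned and fed back in, so each new step only has to process the newest token. A toy, single-head pure-Python sketch of that cache pattern follows. The function toy_attention_step and its (keys, values) tuple of lists are hypothetical; the sketch ignores ProphetNet's n-stream attention and does not reflect the real layout of past_key_values.

```python
import math


def toy_attention_step(x, past_key_values=None, use_cache=True):
    """Hypothetical single-head attention step for the newest token.

    x: list of floats, the hidden vector of the newest token.
    past_key_values: (keys, values) lists from earlier steps, or None.
    Returns (output_vector, new_cache_or_None).
    """
    keys, values = past_key_values if past_key_values is not None else ([], [])
    # Identity projections keep the toy simple: query = key = value = x.
    # New lists are built so the caller's previous cache is not mutated.
    keys = keys + [x]
    values = values + [x]
    scale = math.sqrt(len(x))
    scores = [sum(q * k for q, k in zip(x, key)) / scale for key in keys]
    # Numerically stable softmax over cached positions plus the current one.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(x))]
    return out, ((keys, values) if use_cache else None)


cache = None
for token_vec in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    out, cache = toy_attention_step(token_vec, past_key_values=cache)
```

Each step attends over all earlier positions through the cache, yet only computes the key and value of the one new token.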
- class ProphetNetForConditionalGeneration(config: ProphetNetConfig)
Bases: ProphetNetPretrainedModel
- forward(input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_output=None, labels=None, use_cache=True, past_key_values=None)
Defines the computation performed at every call. Should be overridden by all subclasses.
- Parameters:
*inputs (tuple) – unpacked tuple arguments
**kwargs (dict) – unpacked dict arguments
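The forward signature accepts a precomputed encoder_output, which suggests a generation loop can run the encoder once and reuse its result at every decoding step. The sketch below shows only that control flow, in pure Python, assuming greedy decoding; toy_encode, toy_decode_step and greedy_generate are hypothetical, the scoring rule is invented for illustration, and the real method's inputs and return values are not reproduced.

```python
calls = {"encoder": 0}


def toy_encode(input_ids):
    """Hypothetical encoder: returns one 'hidden state' per source token."""
    calls["encoder"] += 1
    return [float(t) for t in input_ids]


def toy_decode_step(decoder_input_ids, encoder_output, vocab_size):
    """Hypothetical decoder step: returns logits over the vocabulary."""
    # Invented rule: favour the source token aligned with the current position.
    pos = len(decoder_input_ids) - 1
    target = int(encoder_output[pos]) if pos < len(encoder_output) else 0
    return [1.0 if v == target else 0.0 for v in range(vocab_size)]


def greedy_generate(input_ids, bos_id, eos_id, vocab_size, max_len=10):
    encoder_output = toy_encode(input_ids)  # computed once, reused at every step
    out = [bos_id]
    for _ in range(max_len):
        logits = toy_decode_step(out, encoder_output, vocab_size)
        next_id = max(range(vocab_size), key=lambda v: logits[v])
        out.append(next_id)
        if next_id == eos_id:
            break
    return out


result = greedy_generate([3, 4, 1], bos_id=0, eos_id=1, vocab_size=5)
```

Here the encoder runs exactly once no matter how many tokens are generated; in a real model, past_key_values would additionally spare the decoder from recomputing earlier positions.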