modeling

class ErnieGenPretrainedModel(*args, **kwargs)[source]

Bases: object

An abstract base class for pretrained ErnieGen models. It provides the ErnieGen-related model_config_file, pretrained_init_configuration, resource_files_names, pretrained_resource_files_map, and base_model_prefix attributes used for downloading and loading pretrained models. See PretrainedModel for more details.
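The attributes listed above act as lookup tables that let a short model name be turned into an init configuration and a download location. The following is a minimal pure-Python sketch of that idea, under stated assumptions. The model name `toy-ernie-gen`, its config values, the `example.com` URL, and the `resolve_pretrained` helper are all hypothetical placeholders, not PaddleNLP's real entries or internals.

```python
import os

# Hypothetical placeholder tables; the model name, config values, and URL are
# illustrative only and are not PaddleNLP's real entries.
pretrained_init_configuration = {
    "toy-ernie-gen": {"vocab_size": 100, "hidden_size": 16, "num_hidden_layers": 2},
}
pretrained_resource_files_map = {
    "model_state": {
        "toy-ernie-gen": "https://example.com/toy-ernie-gen/model_state.pdparams",
    },
}
resource_files_names = {"model_state": "model_state.pdparams"}


def resolve_pretrained(name_or_path):
    """Hypothetical resolver returning (init_config, {resource_key: location})."""
    if name_or_path in pretrained_init_configuration:
        # Built-in name: copy its init config and look up one URL per resource.
        config = dict(pretrained_init_configuration[name_or_path])
        files = {
            key: url_map[name_or_path]
            for key, url_map in pretrained_resource_files_map.items()
        }
        return config, files
    # Anything else is treated as a local directory of previously saved files;
    # the config would then be read from model_config_file, so None is returned.
    files = {
        key: os.path.join(name_or_path, filename)
        for key, filename in resource_files_names.items()
    }
    return None, files
```

In this sketch, a built-in name resolves to a copied config plus URLs, while any other string is treated as a directory. For example, on POSIX systems `resolve_pretrained("./my_ckpt")` yields `(None, {"model_state": "./my_ckpt/model_state.pdparams"})`.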

save_pretrained(save_directory)[source]

Save the model configuration and related resources (model state) to files under save_directory.

Parameters
  • save_directory (str) -- Directory to save the files into.
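The save_pretrained contract writes one configuration file and one model-state file into a directory, so that the pair can later be reloaded together. The sketch below illustrates that round trip with a hypothetical `ToyPretrainedModel` class. It uses `json` and `pickle` in place of Paddle's own serialization, and the file names `model_config.json` and `model_state.pkl` are assumptions for illustration, not the exact files PaddleNLP writes.

```python
import json
import os
import pickle


class ToyPretrainedModel:
    """Hypothetical stand-in illustrating the save_pretrained contract:
    one config file plus one model-state file under save_directory."""

    # Assumed file names for this sketch only.
    model_config_file = "model_config.json"
    state_file = "model_state.pkl"

    def __init__(self, config, state):
        self.config = config  # plain dict of init arguments
        self.state = state    # plain dict of parameter name -> values

    def save_pretrained(self, save_directory):
        # The files go into a directory, so reject an existing regular file.
        if os.path.isfile(save_directory):
            raise ValueError(f"{save_directory!r} should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, self.model_config_file), "w") as f:
            json.dump(self.config, f)
        with open(os.path.join(save_directory, self.state_file), "wb") as f:
            pickle.dump(self.state, f)

    @classmethod
    def from_directory(cls, save_directory):
        # Read back both files written by save_pretrained.
        with open(os.path.join(save_directory, cls.model_config_file)) as f:
            config = json.load(f)
        with open(os.path.join(save_directory, cls.state_file), "rb") as f:
            state = pickle.load(f)
        return cls(config, state)
```

Keeping the configuration separate from the weights means the architecture can be rebuilt from the small JSON file before the (much larger) state is loaded into it.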

class ErnieForGeneration(cfg, name=None)[source]

Bases: paddlenlp.transformers.ernie_gen.modeling.ErnieModel

ERNIE model for sequence-to-sequence generation.

This model inherits from ErnieModel. Refer to the superclass documentation for the generic methods.

forward(*args, **kwargs)[source]
Parameters
  • tgt_labels (Tensor, optional) -- The ground truth target sequence ids (hard label) or distribution (soft label). Its data type should be int64, and its shape is [batch_size, sequence_length] or [batch_size, sequence_length, sequence_length].

  • tgt_pos (Tensor, optional) -- Indices of tgt_labels in src_ids. Its data type should be int64, and its shape is [n_targets, 2].

  • encode_only (bool, optional) -- Whether the model will output the logits or only encode the inputs. If encode_only is True, loss and logits_2d will not be returned.
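A tgt_pos of shape [n_targets, 2] can be read as one row per target token. The sketch below assumes that each row is a `[batch_index, position]` pair into src_ids, which is an interpretation of the shape rather than something stated above. It builds such an index from a per-token target mask and then gathers the matching entries, similar in spirit to `paddle.gather_nd` on a 2-D tensor. The helper names and the mask input are hypothetical.

```python
from typing import List, Sequence


def build_tgt_pos(target_mask: Sequence[Sequence[int]]) -> List[List[int]]:
    """Return [batch_index, position] pairs for every nonzero mask entry.

    target_mask has shape [batch_size, sequence_length]; the result has
    shape [n_targets, 2], in row-major order.
    """
    return [
        [b, t]
        for b, row in enumerate(target_mask)
        for t, flag in enumerate(row)
        if flag
    ]


def gather_nd_2d(
    matrix: Sequence[Sequence[int]], index_pairs: Sequence[Sequence[int]]
) -> List[int]:
    """Pick matrix[b][t] for each [b, t] pair, like a gather over 2-D indices."""
    return [matrix[b][t] for b, t in index_pairs]
```

For example, with `src_ids = [[5, 6, 7, 8], [9, 10, 11, 12]]` and a mask marking positions 2 and 3 in the first row and 1 and 3 in the second, `build_tgt_pos` gives four pairs and `gather_nd_2d(src_ids, pairs)` returns `[7, 8, 10, 12]`.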

Returns

A tuple of (None, None, info) if encode_only is True; (output_ids, logits, info) if tgt_labels or tgt_pos is None; otherwise (loss, logits_2d, info).

With the fields:

  • `info` (dict):

    Intermediate information, including all hidden states and k/v caches.

  • `output_ids` (Tensor):

    The output indices. Its data type should be float32 and its shape is [batch_size]. If encode_only is True, None is returned instead.

  • `logits` (Tensor):

    Logits for every target. Its data type should be float32 and its shape is [batch_size, sequence_length]. If encode_only is True, None is returned instead.

  • `loss` (Tensor):

    Cross-entropy loss averaged over all target labels. If encode_only is True, None is returned instead.

  • `logits_2d` (Tensor):

    Logits for every target, returned when both tgt_labels and tgt_pos are given. Its data type should be float32 and its shape is [batch_size, sequence_length].
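The loss field is described as cross-entropy averaged over target labels, with tgt_labels being either hard ids or soft distributions, and output_ids is an index chosen from the logits. The sketch below reproduces those two quantities with plain Python lists standing in for Paddle tensors. It is not ErnieForGeneration's actual implementation. Treating each soft label as a per-row probability distribution over the vocabulary, and taking output_ids as a greedy argmax, are assumptions made for illustration.

```python
import math
from typing import List, Sequence, Union


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def mean_cross_entropy(
    logits_2d: Sequence[Sequence[float]],
    labels: Sequence[Union[int, Sequence[float]]],
) -> float:
    """Cross-entropy averaged over target rows.

    Each label is either a class id (hard label) or a probability
    distribution over the vocabulary (soft label).
    """
    if len(logits_2d) != len(labels):
        raise ValueError("need exactly one label per row of logits")
    total = 0.0
    for row, label in zip(logits_2d, labels):
        logp = log_softmax(row)
        if isinstance(label, int):
            total += -logp[label]  # hard label: negative log-prob of the id
        else:
            total += -sum(p * lp for p, lp in zip(label, logp))  # soft label
    return total / len(labels)


def argmax_ids(logits_2d: Sequence[Sequence[float]]) -> List[int]:
    # Greedy prediction: index of the largest logit in each row (first on ties).
    return [max(range(len(row)), key=row.__getitem__) for row in logits_2d]
```

A one-hot soft label gives the same loss as the corresponding hard label, which is a quick sanity check that both branches agree.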

Return type

tuple
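The three-way return rule for forward can be summarized as a small decision function. The sketch below encodes the conditions in the order the Returns description lists them, so encode_only takes precedence over the target arguments; that precedence is a reading of the text, not a check of the source. The helper `forward_output_slots` is hypothetical and returns slot names instead of tensors.

```python
from typing import Any, Optional, Tuple


def forward_output_slots(
    encode_only: bool, tgt_labels: Optional[Any], tgt_pos: Optional[Any]
) -> Tuple[Optional[str], Optional[str], str]:
    """Name the three tuple slots forward() is documented to return."""
    if encode_only:
        # Encoding only: no logits or loss are produced.
        return (None, None, "info")
    if tgt_labels is None or tgt_pos is None:
        # Missing targets: predictions instead of a loss.
        return ("output_ids", "logits", "info")
    # Both targets given: training-style outputs.
    return ("loss", "logits_2d", "info")
```

Because info appears in every branch, a caller can always unpack three values and inspect the first slot to tell which mode ran.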