decoding

class InferTransformerDecoding(decoder, word_embedding, positional_embedding, linear, num_decoder_layers, n_head, d_model, bos_id=0, eos_id=1, decoding_strategy='beam_search', beam_size=4, topk=1, topp=0.0, max_out_len=256, diversity_rate=0.0, decoding_lib=None, use_fp16_decoding=False, rel_len=False, alpha=0.6)
forward(enc_output, memory_seq_lens, trg_word=None)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) -- unpacked tuple arguments

  • **kwargs (dict) -- unpacked dict arguments
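
The `alpha=0.6` default in the `InferTransformerDecoding` signature is not described on this page. A minimal sketch of one common reading, assuming `alpha` is the exponent of a GNMT-style length penalty applied when ranking finished beams: the helper names (`length_penalty`, `rescore`) and the sample hypotheses are hypothetical and are not part of the library.

```python
# Illustrative only: assumes alpha is a GNMT-style length-penalty exponent.
# None of these helpers belong to the library documented on this page.

def length_penalty(length: int, alpha: float = 0.6) -> float:
    """GNMT-style penalty: ((5 + length) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def rescore(log_prob: float, length: int, alpha: float = 0.6) -> float:
    """Divide the summed log-probability by the length penalty."""
    return log_prob / length_penalty(length, alpha)


# Hypothetical finished hypotheses: (token ids, summed log-probability).
hypotheses = [
    ([5, 9, 1], -2.0),
    ([5, 7, 8, 4, 6, 2, 1], -2.5),
]

# Pick the hypothesis with the highest length-normalized score.
best = max(hypotheses, key=lambda h: rescore(h[1], len(h[0])))
```

Under this assumption, `alpha=0.0` disables the normalization and the raw log-probability decides, which favors shorter outputs; larger values increasingly favor longer ones.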

class InferGptDecoding(model, decoding_lib=None, use_fp16_decoding=False)
forward(input_ids, mem_seq_len, attention_mask=None, topk=4, topp=0.0, bos_token_id=None, eos_token_id=None, pad_token_id=None, max_out_len=256, temperature=1)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) -- unpacked tuple arguments

  • **kwargs (dict) -- unpacked dict arguments
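
The `topk`, `topp`, and `temperature` arguments of `InferGptDecoding.forward` correspond to standard sampling controls. The sketch below shows, in plain Python, what these controls conventionally do to a single step's logits. It is not the library's fused implementation; the function names are hypothetical, and how the library combines `topk` and `topp` when both are set is an assumption here.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def filter_candidates(probs, topk=0, topp=0.0):
    """Keep the top-k tokens and/or the smallest prefix whose mass reaches topp.

    Returns renormalized (token_id, probability) pairs. A value of 0 turns the
    corresponding filter off (an assumption made for this sketch).
    """
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    if topk > 0:
        ranked = ranked[:topk]
    if topp > 0.0:
        kept, cumulative = [], 0.0
        for token, p in ranked:
            kept.append((token, p))
            cumulative += p
            if cumulative >= topp:
                break
        ranked = kept
    total = sum(p for _, p in ranked)
    return [(token, p / total) for token, p in ranked]


def sample(logits, topk=4, topp=0.0, temperature=1.0, rng=random):
    """Draw one token id from the filtered, renormalized distribution."""
    candidates = filter_candidates(softmax(logits, temperature), topk, topp)
    tokens = [token for token, _ in candidates]
    weights = [p for _, p in candidates]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Hypothetical logits for a 5-token vocabulary.
logits = [2.0, 1.0, 0.5, -1.0, 0.0]
next_token = sample(logits, topk=4, topp=0.0, temperature=1.0)
```

With `topk=1` the sketch reduces to greedy decoding, since only the highest-probability token survives the filter.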

class InferUnifiedDecoding(model, decoding_strategy='topk_sampling', decoding_lib=None, use_fp16_decoding=False, logits_mask=None, n_head=8, hidden_dims=512, size_per_head=64, n_layer=6, unk_id=0, mask_id=30000, normalize_before=True, hidden_act='gelu')
forward(cache_k, cache_v, memory_seq_lens, decoding_type_id, beam_size=4, topk=4, topp=0.0, max_out_len=256, bos_token_id=None, eos_token_id=None, pad_token_id=None, temperature=1.0, length_penalty=1.0, diversity_rate=0.0, pos_bias=True, rel_len=True)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) -- unpacked tuple arguments

  • **kwargs (dict) -- unpacked dict arguments

class InferBartDecoding(model, decoding_strategy='beam_search_v2', decoding_lib=None, use_fp16_decoding=False)
forward(enc_output, memory_seq_lens, beam_size=4, top_k=1, top_p=0.0, max_out_len=256, diversity_rate=0.0, rel_len=False, bos_token_id=None, eos_token_id=None, pad_token_id=None, alpha=0.6)

Defines the computation performed at every call. Should be overridden by all subclasses.

Parameters
  • *inputs (tuple) -- unpacked tuple arguments

  • **kwargs (dict) -- unpacked dict arguments