Perplexity(name='Perplexity', *args, **kwargs)
Perplexity is calculated from the cross entropy. It supports both padded and unpadded data.
If data is not padded, users should provide `seq_len` for `Metric` initialization. If data is padded, the label should include `seq_mask`, which indicates the actual length of each sample.
This Perplexity metric requires the network output to consist of the prediction, the label and, optionally, the sequence length. If this Perplexity does not meet your needs, you can override the `update` method to calculate perplexity differently.
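Perplexity is the exponential of the average per-token cross entropy. The sketch below illustrates how padded and unpadded data differ when counting tokens. It is plain Python, not the PaddleNLP implementation: nested lists stand in for tensors, and the helper name `perplexity_from_ce` is hypothetical. Without padding every position counts (so the token count is batch_size * seq_len); with padding only positions where the mask is 1 count.

```python
import math
from typing import List, Optional


def perplexity_from_ce(
    ce: List[List[float]],
    seq_mask: Optional[List[List[float]]] = None,
) -> float:
    """Hypothetical helper: exp(sum of token cross entropies / number of tokens).

    ce has shape [batch_size, seq_len], one cross-entropy value per token.
    seq_mask, if given, has the same shape: 1 for real tokens, 0 for padding.
    """
    if seq_mask is None:
        # Unpadded data: every position is a real token,
        # so the token count is batch_size * seq_len.
        total_ce = sum(sum(row) for row in ce)
        num_tokens = sum(len(row) for row in ce)
    else:
        # Padded data: zero out padding positions and count only real tokens.
        total_ce = sum(
            c * m
            for ce_row, mask_row in zip(ce, seq_mask)
            for c, m in zip(ce_row, mask_row)
        )
        num_tokens = sum(sum(row) for row in seq_mask)
    return math.exp(total_ce / num_tokens)
```

For example, a model that spreads probability uniformly over a 4-word vocabulary has a cross entropy of ln 4 per token, so its perplexity is 4.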
seq_len (int) -- Sequence length of each sample; it must be provided when data is not padded. Defaults to 20.
name (str) -- Name of the `Metric` instance. Defaults to 'Perplexity'.
compute(pred, label, seq_mask=None)
Computes cross entropy loss.
pred (Tensor) -- Prediction tensor with dtype float32 or float64 and shape [batch_size, sequence_length, vocab_size].
label (Tensor) -- Label tensor with dtype int64 and shape [batch_size, sequence_length, 1] or [batch_size, sequence_length].
seq_mask (Tensor, optional) -- Sequence mask tensor with dtype float32, float64, int32 or int64 and shape [batch_size, sequence_length]. It is used to exclude padding positions from the loss. Defaults to None.
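The per-token cross entropy that `compute` produces can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the PaddleNLP API: `pred` is treated as unnormalized logits passed through a log-softmax, `label` may use either the [batch_size, sequence_length] or the [batch_size, sequence_length, 1] layout, and the mask is applied by elementwise multiplication. The function name `token_cross_entropy` is hypothetical.

```python
import math
from typing import List, Optional, Union

# A label entry is either a plain token id ([B, T] layout)
# or a one-element list ([B, T, 1] layout).
Label = Union[int, List[int]]


def token_cross_entropy(
    pred: List[List[List[float]]],
    label: List[List[Label]],
    seq_mask: Optional[List[List[float]]] = None,
) -> List[List[float]]:
    """Hypothetical helper returning a [batch_size, sequence_length] grid of losses."""
    out = []
    for b, (logit_rows, label_row) in enumerate(zip(pred, label)):
        ce_row = []
        for t, (logits, y) in enumerate(zip(logit_rows, label_row)):
            # Accept both label layouts by unwrapping [id] to id.
            idx = y[0] if isinstance(y, list) else y
            # Numerically stable log-softmax: -log p_y = logsumexp(z) - z_y.
            m = max(logits)
            log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
            ce = log_norm - logits[idx]
            if seq_mask is not None:
                # Padding positions (mask 0) contribute nothing to the loss.
                ce *= seq_mask[b][t]
            ce_row.append(ce)
        out.append(ce_row)
    return out
```

Subtracting the maximum logit before exponentiating is a standard trick that keeps the sum from overflowing for large logits without changing the result.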
Updates metric states.
Resets all metric states.
Calculates and returns the value of perplexity.
Returns the name of the metric instance.
The name of the metric instance.
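The update, reset, accumulate and name methods above describe a stateful metric. A minimal sketch of that lifecycle is shown below, assuming the states are a running sum of cross entropy and a running token count, and that the accumulated value is exp(total_ce / total_tokens). The class name `RunningPerplexity` and the `update(ce, num_tokens=None)` arguments are hypothetical and do not reproduce the PaddleNLP signatures.

```python
import math
from typing import List, Optional


class RunningPerplexity:
    """Hypothetical stand-in for the metric's state handling (not the PaddleNLP class)."""

    def __init__(self, name: str = "Perplexity") -> None:
        self._name = name
        self.reset()

    def reset(self) -> None:
        # Clear the running totals, e.g. at the start of each evaluation pass.
        self.total_ce = 0.0
        self.total_tokens = 0.0

    def update(self, ce: List[List[float]], num_tokens: Optional[float] = None) -> None:
        # Add one batch: ce is a [batch_size, seq_len] grid of per-token losses,
        # already multiplied by the mask if the data is padded.
        self.total_ce += sum(sum(row) for row in ce)
        if num_tokens is None:
            # Unpadded batch: every position counts as a token.
            num_tokens = sum(len(row) for row in ce)
        self.total_tokens += num_tokens

    def accumulate(self) -> float:
        # Perplexity over every token seen since the last reset.
        return math.exp(self.total_ce / self.total_tokens)

    def name(self) -> str:
        return self._name
```

Accumulating totals rather than averaging per-batch perplexities weights every token equally, so batches with many padding positions do not skew the result. Calling `reset()` between evaluation passes keeps batches from different passes from mixing.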