trainer_utils

Utilities for the Trainer class.

class TrainOutput(global_step, training_loss, metrics)

Bases: NamedTuple

global_step: int

Alias for field number 0

training_loss: float

Alias for field number 1

metrics: Dict[str, float]

Alias for field number 2
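Because TrainOutput is a NamedTuple, "Alias for field number N" means each field can be read either by name or by position. The sketch uses a hypothetical stand-in class, `TrainOutputStandIn`, with the same three fields so it runs without the library; the values are made up for illustration.

```python
from typing import Dict, NamedTuple


class TrainOutputStandIn(NamedTuple):
    """Hypothetical stand-in mirroring the fields documented for TrainOutput."""
    global_step: int
    training_loss: float
    metrics: Dict[str, float]


out = TrainOutputStandIn(global_step=1000, training_loss=0.42, metrics={"train_runtime": 12.5})

# Attribute access and index access refer to the same field.
assert out.global_step == out[0]
assert out.metrics is out[2]

# Like any tuple, it can be unpacked positionally.
step, loss, metrics = out
print(step, loss)  # 1000 0.42
```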

class PredictionOutput(predictions, label_ids, metrics)

Bases: NamedTuple

predictions: ndarray | Tuple[ndarray]

Alias for field number 0

label_ids: ndarray | Tuple[ndarray] | None

Alias for field number 1

metrics: Dict[str, float] | None

Alias for field number 2
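Unlike TrainOutput, two of these fields are optional and predictions may be a tuple, so code consuming a PredictionOutput has to handle those cases. A minimal sketch, assuming a hypothetical stand-in class `PredictionOutputStandIn` and helper `describe`, with plain lists standing in for np.ndarray so it runs without NumPy. Treating `label_ids` as None for unlabeled data is an assumption about when that case arises.

```python
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

# Plain lists stand in for np.ndarray so the sketch runs without NumPy.
Array = List[float]


class PredictionOutputStandIn(NamedTuple):
    """Hypothetical stand-in with the same fields as PredictionOutput."""
    predictions: Union[Array, Tuple[Array, ...]]
    label_ids: Optional[Union[Array, Tuple[Array, ...]]]
    metrics: Optional[Dict[str, float]]


def describe(output: PredictionOutputStandIn) -> Dict[str, object]:
    # A tuple means the model produced several outputs; otherwise there is one array.
    if isinstance(output.predictions, tuple):
        num_outputs = len(output.predictions)
    else:
        num_outputs = 1
    return {
        "num_outputs": num_outputs,
        # label_ids and metrics are Optional, so check for None before using them.
        "has_labels": output.label_ids is not None,
        "metrics": dict(output.metrics or {}),
    }
```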

class EvalPrediction(predictions: ndarray | Tuple[ndarray], label_ids: ndarray | Tuple[ndarray])

Bases: NamedTuple

Evaluation output (always contains labels), to be used to compute metrics.

Parameters:
  • predictions (np.ndarray or Tuple[np.ndarray]) -- Predictions of the model.

  • label_ids (np.ndarray or Tuple[np.ndarray]) -- Targets to be matched.

predictions: ndarray | Tuple[ndarray]

Alias for field number 0

label_ids: ndarray | Tuple[ndarray]

Alias for field number 1
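An EvalPrediction is what a metrics function receives. A minimal accuracy sketch, assuming classification logits in `predictions` and integer class ids in `label_ids`. `EvalPredictionStandIn` and `compute_accuracy` are hypothetical names, and plain lists stand in for np.ndarray so the sketch runs without NumPy; with real arrays the argmax step would usually be `p.predictions.argmax(axis=-1)`. Passing such a function to the Trainer as `compute_metrics` is assumed here by analogy with the Hugging Face convention.

```python
from typing import Dict, List, NamedTuple


class EvalPredictionStandIn(NamedTuple):
    """Hypothetical stand-in with the same fields as EvalPrediction."""
    predictions: List[List[float]]  # one row of logits per example
    label_ids: List[int]            # one gold class id per example


def compute_accuracy(p: EvalPredictionStandIn) -> Dict[str, float]:
    # Argmax over each row of logits gives the predicted class id.
    pred_ids = [max(range(len(row)), key=row.__getitem__) for row in p.predictions]
    correct = sum(int(pred == label) for pred, label in zip(pred_ids, p.label_ids))
    return {"accuracy": correct / len(p.label_ids)}
```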

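class IntervalStrategy(value)

Bases: ExplicitEnum

An enumeration of interval strategies (for example, how often to evaluate, save checkpoints, or log during training).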
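ExplicitEnum is documented here only as a base class. A plausible minimal sketch, assuming it behaves like the similarly named Hugging Face helper (a string-valued Enum whose failed lookup raises a ValueError listing the valid choices) and assuming IntervalStrategy members "no", "steps", and "epoch". Both class names below are hypothetical stand-ins, not the library's classes.

```python
from enum import Enum


class ExplicitEnumStandIn(str, Enum):
    """Hypothetical stand-in: an Enum whose lookup error lists the valid values."""

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when no member matches `value`.
        raise ValueError(
            f"{value!r} is not a valid {cls.__name__}, please select one of "
            f"{[member.value for member in cls]}"
        )


class IntervalStrategyStandIn(ExplicitEnumStandIn):
    # Assumed members, mirroring the common "no" / "steps" / "epoch" options.
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"
```

With the str mixin, members compare equal to their string values, so a configuration string such as "steps" can be converted with `IntervalStrategyStandIn("steps")`, while a typo like "step" fails with a message naming the allowed values.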

class SchedulerType(value)

Bases: ExplicitEnum

An enumeration of the learning-rate scheduler types accepted by get_scheduler.

speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None)

Measure and return speed performance metrics.

Take a time snapshot, start_time, before the operation to be measured starts, and call this function immediately after that operation has completed.

Args:

  • split: Name used to prefix the metric keys (e.g., train, eval, test).

  • start_time: Start time of the operation (the snapshot taken before it began).

  • num_samples: Number of samples processed.

  • num_steps: Number of steps processed.
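The snapshot-then-measure pattern can be sketched as follows. `speed_metrics_sketch` is a hypothetical re-creation, not the library function: the key names (`{split}_runtime`, `{split}_samples_per_second`, `{split}_steps_per_second`) and rounding are assumptions modeled on the Hugging Face convention, and `seq_length` is left out because its effect is not documented here.

```python
import time
from typing import Dict, Optional


def speed_metrics_sketch(
    split: str,
    start_time: float,
    num_samples: Optional[int] = None,
    num_steps: Optional[int] = None,
) -> Dict[str, float]:
    """Hypothetical re-creation: runtime and throughput since `start_time`."""
    # Must be called right after the measured operation finishes.
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if num_samples is not None:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    if num_steps is not None:
        result[f"{split}_steps_per_second"] = round(num_steps / runtime, 3)
    return result


# Usage: take the snapshot before the operation, call the function right after.
start = time.time()
time.sleep(0.1)  # placeholder for the real work, e.g. an evaluation loop
metrics = speed_metrics_sketch("eval", start, num_samples=64, num_steps=8)
```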

get_scheduler(name: str | SchedulerType, learning_rate: float, num_warmup_steps: int | None = None, num_training_steps: int | None = None, num_cycles: float | None = 0.5, lr_end: float | None = 1e-07, power: float | None = 1.0)

Unified API to get any scheduler from its name.

Parameters:
  • name (str or SchedulerType) -- The name of the scheduler to use.

  • learning_rate (float) -- The initial learning rate, as a Python float.

  • num_warmup_steps (int, optional) -- The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it is unset and the scheduler type requires it.

  • num_training_steps (int, optional) -- The number of training steps to do. This is not required by all schedulers (hence the argument being optional); the function will raise an error if it is unset and the scheduler type requires it.

  • num_cycles (float, optional) -- The number of waves in the cosine scheduler (the default is to just decrease from the max value to 0 following a half-cosine). This is not required by all schedulers (hence the argument being optional).

  • lr_end (float, optional) -- The end learning rate in the polynomial scheduler. This is not required by all schedulers (hence the argument being optional).

  • power (float, optional) -- The power factor in the polynomial scheduler. This is not required by all schedulers (hence the argument being optional).
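The shapes that num_warmup_steps, num_training_steps, num_cycles, lr_end, and power control can be sketched in plain Python. `lr_at` is a hypothetical helper, not part of the library: it assumes the scheduler names "constant", "constant_with_warmup", "linear", "cosine", and "polynomial", uses the widely known Hugging Face-style formulas (which may differ in detail from this library's schedulers), and returns a learning-rate value rather than a scheduler object.

```python
import math
from typing import Optional


def lr_at(
    name: str,
    step: int,
    learning_rate: float,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: float = 0.5,
    lr_end: float = 1e-7,
    power: float = 1.0,
) -> float:
    """Hypothetical helper: learning rate at `step` for a few schedule shapes."""
    if name == "constant":
        return learning_rate
    # Like the documented API, raise if a required argument is unset.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`")
    # Linear warmup from 0 up to learning_rate, shared by the schedules below.
    if step < num_warmup_steps:
        return learning_rate * step / max(1, num_warmup_steps)
    if name == "constant_with_warmup":
        return learning_rate
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`")
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - num_warmup_steps) / decay_steps)
    if name == "linear":
        return learning_rate * max(0.0, 1.0 - progress)
    if name == "cosine":
        # num_cycles=0.5 traces half a cosine wave: from the max value down to 0.
        factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
        return learning_rate * max(0.0, factor)
    if name == "polynomial":
        # Decay from learning_rate to lr_end following (1 - progress) ** power.
        return (learning_rate - lr_end) * (1.0 - progress) ** power + lr_end
    raise ValueError(f"unknown scheduler name: {name}")
```

For example, with 10 warmup steps and 110 training steps, the "linear" schedule reaches half the peak rate at step 5 (during warmup) and again at step 60 (halfway through the decay).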