trainer_utils#
Utilities for the Trainer class.
- class TrainOutput(global_step, training_loss, metrics)[source]#
Bases: NamedTuple
- global_step: int#
Alias for field number 0
- training_loss: float#
Alias for field number 1
- metrics: Dict[str, float]#
Alias for field number 2
- class PredictionOutput(predictions, label_ids, metrics)[source]#
Bases: NamedTuple
- predictions: ndarray | Tuple[ndarray]#
Alias for field number 0
- label_ids: ndarray | Tuple[ndarray] | None#
Alias for field number 1
- metrics: Dict[str, float] | None#
Alias for field number 2
- class EvalPrediction(predictions: ndarray | Tuple[ndarray], label_ids: ndarray | Tuple[ndarray])[source]#
Bases: NamedTuple
Evaluation output (always contains labels), to be used to compute metrics.
- Parameters:
predictions (np.ndarray) -- Predictions of the model.
label_ids (np.ndarray) -- Targets to be matched.
- predictions: ndarray | Tuple[ndarray]#
Alias for field number 0
- label_ids: ndarray | Tuple[ndarray]#
Alias for field number 1
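Because EvalPrediction is a NamedTuple, a metrics function can read its fields by name or by position. The sketch below is illustrative only: it uses a local stand-in class with the same two fields, and plain Python lists in place of np.ndarray, so that it runs without PaddleNLP or NumPy installed. The names `EvalPredictionSketch` and `compute_accuracy` are hypothetical and not part of the library.

```python
from typing import Dict, List, NamedTuple


class EvalPredictionSketch(NamedTuple):
    """Local stand-in mirroring the two documented fields (an assumption)."""

    predictions: List[List[float]]  # per-class scores, one row per example
    label_ids: List[int]  # target class index per example


def compute_accuracy(p: EvalPredictionSketch) -> Dict[str, float]:
    # Pick the highest-scoring class in each row (an argmax over classes).
    predicted = [max(range(len(row)), key=row.__getitem__) for row in p.predictions]
    correct = sum(int(a == b) for a, b in zip(predicted, p.label_ids))
    return {"accuracy": correct / len(p.label_ids)}


example = EvalPredictionSketch(
    predictions=[[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]],
    label_ids=[1, 0, 0],
)
metrics = compute_accuracy(example)
```

Field access by name (`example.predictions`) and by index (`example[0]`) return the same object, which is what "Alias for field number 0" refers to.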
- speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None)[source]#
Measure and return speed performance metrics.
This function requires a time snapshot, start_time, taken before the operation to be measured starts, and it should be run immediately after that operation has completed.
- Parameters:
split -- Name used to prefix each metric (such as train, eval, or test).
start_time -- Operation start time.
num_samples -- Number of samples processed.
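The calling pattern described above (snapshot first, measure immediately afterwards) can be sketched as follows. This is a minimal stand-in, not the library's implementation: the function name `sketch_speed_metrics` and the exact metric key names are assumptions chosen for illustration, and `seq_length` is left out.

```python
import time
from typing import Dict, Optional


def sketch_speed_metrics(
    split: str,
    start_time: float,
    num_samples: Optional[int] = None,
    num_steps: Optional[int] = None,
) -> Dict[str, float]:
    """Illustrative only: elapsed time plus throughput, keyed by split."""
    runtime = time.time() - start_time
    result: Dict[str, float] = {f"{split}_runtime": runtime}
    if runtime > 0 and num_samples is not None:
        result[f"{split}_samples_per_second"] = num_samples / runtime
    if runtime > 0 and num_steps is not None:
        result[f"{split}_steps_per_second"] = num_steps / runtime
    return result


start = time.time()  # snapshot taken before the measured operation
time.sleep(0.05)  # placeholder for the work being timed
speed = sketch_speed_metrics("eval", start, num_samples=32, num_steps=4)
```

Calling the function any later than "immediately after" the operation would fold unrelated work into the runtime and understate throughput.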
- get_scheduler(name: str | SchedulerType, learning_rate: float, num_warmup_steps: int | None = None, num_training_steps: int | None = None, num_cycles: float | None = 0.5, lr_end: float | None = 1e-07, power: float | None = 1.0)[source]#
Unified API to get any scheduler from its name.
- Parameters:
name (str or SchedulerType) -- The name of the scheduler to use.
learning_rate (float) -- The initial learning rate, given as a Python float.
num_warmup_steps (int, optional) -- The number of warmup steps. Not all schedulers require this (hence the argument is optional); the function raises an error if it is unset and the scheduler type requires it.
num_training_steps (int, optional) -- The number of training steps. Not all schedulers require this (hence the argument is optional); the function raises an error if it is unset and the scheduler type requires it.
num_cycles (float, optional) -- The number of waves in the cosine scheduler. The default, 0.5, decreases the rate from the max value to 0 following a half-cosine. Not all schedulers require this (hence the argument is optional).
lr_end (float, optional) -- The end learning rate in the polynomial scheduler. Not all schedulers require this (hence the argument is optional).
power (float, optional) -- The power factor in the polynomial scheduler. Not all schedulers require this (hence the argument is optional).
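To make the roles of num_warmup_steps, num_training_steps, and num_cycles concrete, the sketch below computes a learning-rate multiplier using the commonly used cosine-with-warmup formulation. It is an assumption that the library's cosine scheduler matches this formula exactly; `cosine_warmup_factor` is a hypothetical helper, not a library function.

```python
import math


def cosine_warmup_factor(
    step: int,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> float:
    """Multiplier applied to the initial learning rate at a given step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup.
        return step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


learning_rate = 3e-5
schedule = [learning_rate * cosine_warmup_factor(s, 10, 100) for s in range(101)]
```

With num_cycles=0.5 the cosine argument runs from 0 to pi, so the multiplier falls from 1 to 0 over a single half-wave, which is the "half-cosine" behaviour described for the default.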