tools

static_params_to_dygraph(model, static_tensor_dict)[source]

A simple tool for converting static-graph parameters into a dygraph parameter dict.

NOTE: The model must support both static graph and dygraph mode.

Parameters:
  • model (nn.Layer) -- the neural network model.

  • static_tensor_dict (string) -- path to the parameters saved in static mode, usually loaded with paddle.static.load_program_state.

Returns:

a state dict in the same format as in dygraph mode.

Return type:

[tensor dict]

dygraph_params_to_static(model, dygraph_tensor_dict, topo=None)[source]

A simple tool for converting dygraph parameters into a static-graph parameter dict.

NOTE: The model must support both static graph and dygraph mode.

Parameters:
  • model (nn.Layer) -- the neural network model.

  • dygraph_tensor_dict (string) -- path to the parameters saved in dygraph mode.

Returns:

a state dict in the same format as in static mode.

Return type:

[tensor dict]

class TimeCostAverage[source]

Bases: object

A simple tool for calculating the average time cost during training and inference.

reset()[source]

Reset the recorder state and set the count (cnt) to zero.

record(usetime)[source]

Record the time cost of the current step and increment the count (cnt).

get_average()[source]

Return the average time cost since the start of training.
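
The three methods above describe a running-average timer. The following is a minimal stand-in sketch of that interface, not the library's source: the class name TimeCostAverageSketch, the attribute total_time, and the choice to return 0.0 before any step is recorded are all assumptions made for illustration.

```python
import time


class TimeCostAverageSketch:
    """Hypothetical stand-in mirroring the documented reset/record/get_average interface."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated time and the step counter.
        self.cnt = 0
        self.total_time = 0.0

    def record(self, usetime):
        # Accumulate one step's elapsed time and count the step.
        self.cnt += 1
        self.total_time += usetime

    def get_average(self):
        # Assumption: report 0.0 when nothing has been recorded yet.
        if self.cnt == 0:
            return 0.0
        return self.total_time / self.cnt


timer = TimeCostAverageSketch()
for _ in range(3):
    start = time.perf_counter()
    sum(range(10_000))  # placeholder for one training step
    timer.record(time.perf_counter() - start)
print(f"average step time: {timer.get_average():.6f}s")
```

Timing each step outside the class and passing the elapsed value to record keeps the helper independent of how a step is measured.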

get_env_device()[source]

Return the device name of the running environment.

compare_version(version, pair_version)[source]

Parameters:
  • version (str) -- The first version string to compare. The version string should have the format "xxx.yyy.zzz".

  • pair_version (str) -- The second version string to compare. The version string should have the format "xxx.yyy.zzz".

Returns:

The result of the comparison. 1 means version > pair_version; 0 means version = pair_version; -1 means version < pair_version.

Return type:

int

Example

>>> compare_version("2.2.1", "2.2.0")
1
>>> compare_version("2.2.0", "2.2.0")
0
>>> compare_version("2.2.0-rc0", "2.2.0")
-1
>>> compare_version("2.3.0-rc0", "2.2.0")
1
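
The ordering shown in these examples can be approximated in plain Python. This is a sketch under stated assumptions, not the library's implementation: the name compare_version_sketch is hypothetical, and it assumes that any "-suffix" (such as "-rc0") marks a pre-release that sorts before the plain release with the same numbers.

```python
def compare_version_sketch(version: str, pair_version: str) -> int:
    """Hypothetical re-implementation of the documented 1 / 0 / -1 result."""

    def parse(v: str):
        core, _, pre = v.partition("-")
        nums = tuple(int(part) for part in core.split("."))
        # Assumption: a pre-release suffix ranks below the final release (0 < 1).
        return nums, (0 if pre else 1), pre

    a, b = parse(version), parse(pair_version)
    # (a > b) - (a < b) maps the tuple comparison onto 1, 0 or -1.
    return (a > b) - (a < b)


for left, right in [("2.2.1", "2.2.0"), ("2.2.0", "2.2.0"),
                    ("2.2.0-rc0", "2.2.0"), ("2.3.0-rc0", "2.2.0")]:
    print(left, right, compare_version_sketch(left, right))
```

Comparing the numeric parts as integer tuples avoids the string-ordering pitfall where "2.10.0" would sort before "2.9.0".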
get_bool_ids_greater_than(probs, limit=0.5, return_prob=False)[source]

Get the indices along the last dimension of the probability arrays whose values are greater than a threshold.

Parameters:
  • probs (List[List[float]]) -- The input probability arrays.

  • limit (float) -- The probability threshold.

  • return_prob (bool) -- Whether to return the probability.

Returns:

The indices along the last dimension that meet the condition.

Return type:

List[List[int]]
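
The thresholding described above can be illustrated with a short pure-Python sketch. It is not the library's code: the name bool_ids_greater_than_sketch is hypothetical, it handles only two-level nested lists, and returning (index, probability) tuples when return_prob is True is an assumption based on the tuple format mentioned for get_span.

```python
def bool_ids_greater_than_sketch(probs, limit=0.5, return_prob=False):
    """Hypothetical 2-D version: keep indices whose probability exceeds limit."""
    result = []
    for row in probs:
        if return_prob:
            # Assumption: pair each kept index with its probability.
            result.append([(i, p) for i, p in enumerate(row) if p > limit])
        else:
            result.append([i for i, p in enumerate(row) if p > limit])
    return result


probs = [[0.1, 0.9, 0.6], [0.7, 0.2, 0.5]]
print(bool_ids_greater_than_sketch(probs))
print(bool_ids_greater_than_sketch(probs, return_prob=True))
```

The comparison is strict, so a value equal to the threshold (0.5 in the second row) is not selected.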

get_span(start_ids, end_ids, with_prob=False)[source]

Get the set of spans from lists of start and end positions.

Parameters:
  • start_ids (List[int]/List[tuple]) -- The list of start indices.

  • end_ids (List[int]/List[tuple]) -- The list of end indices.

  • with_prob (bool) -- If True, each element of start_ids and end_ids is a tuple of the form (index, probability).

Returns:

The set of non-overlapping spans; each id can be used only once.

Return type:

set
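
One way to pair start and end positions into non-overlapping spans is a two-pointer pass over the sorted lists. The sketch below is an illustration under assumptions, not necessarily the strategy the library uses: the name get_span_sketch is hypothetical, it covers only the plain-integer case (with_prob=False), and it pairs each end with the nearest start that does not come after it.

```python
def get_span_sketch(start_ids, end_ids):
    """Hypothetical pairing of integer start/end indices into (start, end) spans."""
    starts, ends = sorted(start_ids), sorted(end_ids)
    i = j = 0
    start_for_end = {}
    while i < len(starts) and j < len(ends):
        if starts[i] <= ends[j]:
            # Candidate start for this end; a later start that still
            # precedes the same end replaces it (nearest start wins).
            start_for_end[ends[j]] = starts[i]
            if starts[i] == ends[j]:
                j += 1  # a single-position span closes this end
            i += 1
        else:
            # This end precedes every remaining start; move to the next end.
            j += 1
    return {(s, e) for e, s in start_for_end.items()}


print(get_span_sketch([1, 5], [3, 8]))
print(get_span_sketch([1, 2], [4]))
```

Because each start is consumed once and each end is a dictionary key, no index appears in more than one returned span.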

class DataConverter(label_studio_file, negative_ratio=5, prompt_prefix='情感倾向', options=['正向', '负向'], separator='##', layout_analysis=False, expand_to_a4_size=True, schema_lang='ch', ocr_lang='en', anno_type='text')[source]

Bases: object

Converter for data exported from an annotation platform.

convert_cls_examples(raw_examples)[source]

Convert labeled data for the classification task.

convert_ext_examples(raw_examples, is_train=True)[source]

Convert labeled data for the extraction task.