paddlenlp.utils.tools 源代码

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from .log import logger


[文档]def static_params_to_dygraph(model, static_tensor_dict): """Simple tool for convert static paramters to dygraph paramters dict. **NOTE** The model must both support static graph and dygraph mode. Args: model (nn.Layer): the model of a neural network. static_tensor_dict (string): path of which locate the saved paramters in static mode. Usualy load by `paddle.static.load_program_state`. Returns: [tensor dict]: a state dict the same as the dygraph mode. """ state_dict = model.state_dict() # static_tensor_dict = paddle.static.load_program_state(static_params_path) ret_dict = dict() for n, p in state_dict.items(): if p.name not in static_tensor_dict: logger.info("%s paramter is missing from you state dict." % n) continue ret_dict[n] = static_tensor_dict[p.name] return ret_dict
[文档]def dygraph_params_to_static(model, dygraph_tensor_dict, topo=None): """Simple tool for convert dygraph paramters to static paramters dict. **NOTE** The model must both support static graph and dygraph mode. Args: model (nn.Layer): the model of a neural network. dygraph_tensor_dict (string): path of which locate the saved paramters in static mode. Returns: [tensor dict]: a state dict the same as the dygraph mode. """ state_dict = model.state_dict() ret_dict = dict() for name, parm in state_dict.items(): if name not in dygraph_tensor_dict: logger.info("%s paramter is missing from you state dict." % name) continue tensor = dygraph_tensor_dict[name] if parm.is_distributed: assert topo is not None for dim, v in enumerate(tensor.shape): if parm.shape[dim] != v: break splited = np.split( tensor, topo.mp_info.size, axis=dim)[topo.mp_info.rank] ret_dict[parm.name] = splited else: ret_dict[parm.name] = tensor return ret_dict
[文档]class TimeCostAverage(object): """ Simple tool for calcluating time average cost in the process of training and inferencing. """ def __init__(self): self.reset()
[文档] def reset(self): """ Reset the recoder state, and reset the `cnt` to zero. """ self.cnt = 0 self.total_time = 0
[文档] def record(self, usetime): """ Recoding the time cost in current step and accumulating the `cnt`. """ self.cnt += 1 self.total_time += usetime
[文档] def get_average(self): """ Returning the average time cost after the start of training. """ if self.cnt == 0: return 0 return self.total_time / self.cnt
[文档]def get_env_device(): """ Return the device name of running enviroment. """ if paddle.is_compiled_with_cuda(): return 'gpu' elif paddle.is_compiled_with_npu(): return 'npu' elif paddle.is_compiled_with_rocm(): return 'rocm' elif paddle.is_compiled_with_xpu(): return 'xpu' return 'cpu'
[文档]def compare_version(version, pair_version): """ Args: version (str): The first version string needed to be compared. The format of version string should be as follow : "xxx.yyy.zzz". pair_version (str): The second version string needed to be compared. The format of version string should be as follow : "xxx.yyy.zzz". Returns: int: The result of comparasion. 1 means version > pair_version; 0 means version = pair_version; -1 means version < pair_version. Examples: >>> compare_version("2.2.1", "2.2.0") >>> 1 >>> compare_version("2.2.0", "2.2.0") >>> 0 >>> compare_version("2.2.0-rc0", "2.2.0") >>> -1 >>> compare_version("2.3.0-rc0", "2.2.0") >>> 1 """ version = version.strip() pair_version = pair_version.strip() if version == pair_version: return 0 version_list = version.split(".") pair_version_list = pair_version.split(".") for version_code, pair_version_code in zip(version_list, pair_version_list): if not version_code.isnumeric(): return -1 if not pair_version_code.isnumeric(): return 1 if int(version_code) > int(pair_version_code): return 1 elif int(version_code) < int(pair_version_code): return -1 return 0
[文档]def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False): """ Get idx of the last dimension in probability arrays, which is greater than a limitation. Args: probs (List[List[float]]): The input probability arrays. limit (float): The limitation for probability. return_prob (bool): Whether to return the probability Returns: List[List[int]]: The index of the last dimension meet the conditions. """ probs = np.array(probs) dim_len = len(probs.shape) if dim_len > 1: result = [] for p in probs: result.append(get_bool_ids_greater_than(p, limit, return_prob)) return result else: result = [] for i, p in enumerate(probs): if p > limit: if return_prob: result.append((i, p)) else: result.append(i) return result
[文档]def get_span(start_ids, end_ids, with_prob=False): """ Get span set from position start and end list. Args: start_ids (List[int]/List[tuple]): The start index list. end_ids (List[int]/List[tuple]): The end index list. with_prob (bool): If True, each element for start_ids and end_ids is a tuple aslike: (index, probability). Returns: set: The span set without overlapping, every id can only be used once . """ if with_prob: start_ids = sorted(start_ids, key=lambda x: x[0]) end_ids = sorted(end_ids, key=lambda x: x[0]) else: start_ids = sorted(start_ids) end_ids = sorted(end_ids) start_pointer = 0 end_pointer = 0 len_start = len(start_ids) len_end = len(end_ids) couple_dict = {} while start_pointer < len_start and end_pointer < len_end: if with_prob: start_id = start_ids[start_pointer][0] end_id = end_ids[end_pointer][0] else: start_id = start_ids[start_pointer] end_id = end_ids[end_pointer] if start_id == end_id: couple_dict[end_ids[end_pointer]] = start_ids[start_pointer] start_pointer += 1 end_pointer += 1 continue if start_id < end_id: couple_dict[end_ids[end_pointer]] = start_ids[start_pointer] start_pointer += 1 continue if start_id > end_id: end_pointer += 1 continue result = [(couple_dict[end], end) for end in couple_dict] result = set(result) return result