# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import abc
import math
from abc import abstractmethod
import paddle
from paddle.dataset.common import md5file
from ..utils.env import PPNLP_HOME
from ..utils.log import logger
from .utils import download_check, static_mode_guard, dygraph_mode_guard, download_file, cut_chinese_sent


class Task(metaclass=abc.ABCMeta):
"""
    The meta class of every task in Taskflow. It declares six abstract methods,
    which every subclass needs to implement.

    Args:
        task (string): The name of the task.
        model (string): The model name used by the task.
        priority_path (string, optional): A directory fragment that takes priority when resolving the task path.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
"""
def __init__(self, model, task, priority_path=None, **kwargs):
self.model = model
self.task = task
self.kwargs = kwargs
self._priority_path = priority_path
self._usage = ""
        # The dygraph model instance
self._model = None
        # The input spec used to trace the dygraph model into a static model
self._input_spec = None
self._config = None
self._custom_model = False
self._param_updated = False
# The root directory for storing Taskflow related files, default to ~/.paddlenlp.
        self._home_path = self.kwargs.get('home_path', PPNLP_HOME)
        self._task_flag = self.kwargs.get('task_flag', self.model)
if 'task_path' in self.kwargs:
self._task_path = self.kwargs['task_path']
self._custom_model = True
elif self._priority_path:
self._task_path = os.path.join(self._home_path, "taskflow",
self._priority_path)
else:
self._task_path = os.path.join(self._home_path, "taskflow",
self.task, self.model)
download_check(self._task_flag)
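
    # A quick sketch of how the task directory resolves for a concrete subclass
    # (names and paths are illustrative; PPNLP_HOME defaults to ~/.paddlenlp):
    #
    #     SomeTask(model="my-model", task="my-task")
    #         -> ~/.paddlenlp/taskflow/my-task/my-model
    #     SomeTask(model="my-model", task="my-task", task_path="./local_model")
    #         -> ./local_model  (and self._custom_model is True)
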
@abstractmethod
def _construct_model(self, model):
"""
Construct the inference model for the predictor.
"""
@abstractmethod
def _construct_tokenizer(self, model):
"""
Construct the tokenizer for the predictor.
"""
@abstractmethod
def _preprocess(self, inputs, padding=True, add_special_tokens=True):
"""
Transform the raw text to the model inputs, two steps involved:
1) Transform the raw text to token ids.
2) Generate the other model inputs from the raw text and token ids.
"""
@abstractmethod
def _run_model(self, inputs):
"""
        Run the task model on the outputs of the `_preprocess` function.
"""
@abstractmethod
def _postprocess(self, inputs):
"""
        The model output is logits and probs; this function converts the model output to raw text.
"""
@abstractmethod
def _construct_input_spec(self):
"""
        Construct the input spec used to convert the dygraph model to a static model.
"""
def _check_task_files(self):
"""
Check files required by the task.
"""
for file_id, file_name in self.resource_files_names.items():
path = os.path.join(self._task_path, file_name)
url = self.resource_files_urls[self.model][file_id][0]
md5 = self.resource_files_urls[self.model][file_id][1]
            downloaded = True
            if not os.path.exists(path):
                downloaded = False
            elif not self._custom_model:
                # Check whether the local file is still up to date via its md5
                if md5file(path) != md5:
                    downloaded = False
                    if file_id == "model_state":
                        self._param_updated = True
if not downloaded:
download_file(self._task_path, file_name, url, md5)
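
    # The two class attributes consumed above are supplied by subclasses; a
    # sketch of their expected shape (URL and md5 values are placeholders):
    #
    #     resource_files_names = {"model_state": "model_state.pdparams"}
    #     resource_files_urls = {
    #         "my-model": {
    #             "model_state": ["https://example.com/model_state.pdparams", "<md5>"],
    #         },
    #     }
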
def _prepare_static_mode(self):
"""
        Construct the input data and predictor in the PaddlePaddle static mode.
"""
place = paddle.get_device()
if place == 'cpu':
self._config.disable_gpu()
else:
self._config.enable_use_gpu(100, self.kwargs['device_id'])
# TODO(linjieccc): enable embedding_eltwise_layernorm_fuse_pass after fixed
self._config.delete_pass("embedding_eltwise_layernorm_fuse_pass")
self._config.switch_use_feed_fetch_ops(False)
self._config.disable_glog_info()
self._config.enable_memory_optim()
self.predictor = paddle.inference.create_predictor(self._config)
self.input_handles = [
self.predictor.get_input_handle(name)
for name in self.predictor.get_input_names()
]
self.output_handle = [
self.predictor.get_output_handle(name)
for name in self.predictor.get_output_names()
]
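
    # A sketch of how the handles created above are typically used by a
    # subclass (the input name, count and dtype are illustrative):
    #
    #     self.input_handles[0].copy_from_cpu(input_ids)  # numpy int64 array
    #     self.predictor.run()
    #     logits = self.output_handle[0].copy_to_cpu()
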
def _get_inference_model(self):
"""
        Construct the inference program, inputs and outputs in static mode.
"""
inference_model_path = os.path.join(self._task_path, "static",
"inference")
if not os.path.exists(inference_model_path +
".pdiparams") or self._param_updated:
with dygraph_mode_guard():
self._construct_model(self.model)
self._construct_input_spec()
self._convert_dygraph_to_static()
model_file = inference_model_path + ".pdmodel"
params_file = inference_model_path + ".pdiparams"
self._config = paddle.inference.Config(model_file, params_file)
self._prepare_static_mode()
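
    # After the first call the exported files live under the task directory,
    # as implied by the path joins above:
    #
    #     <task_path>/static/inference.pdmodel    # static program (graph)
    #     <task_path>/static/inference.pdiparams  # persisted parameters
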
def _convert_dygraph_to_static(self):
"""
        Convert the dygraph model to a static model.
"""
assert self._model is not None, 'The dygraph model must be created before converting the dygraph model to static model.'
assert self._input_spec is not None, 'The input spec must be created before converting the dygraph model to static model.'
logger.info("Converting to the inference model cost a little time.")
static_model = paddle.jit.to_static(
self._model, input_spec=self._input_spec)
save_path = os.path.join(self._task_path, "static", "inference")
paddle.jit.save(static_model, save_path)
logger.info("The inference model save in the path:{}".format(save_path))
def _check_input_text(self, inputs):
inputs = inputs[0]
        if isinstance(inputs, str):
            if len(inputs) == 0:
                raise ValueError(
                    "Invalid inputs, input text should not be empty, "
                    "please check your input.")
            inputs = [inputs]
        elif isinstance(inputs, list):
            if not (isinstance(inputs[0], str) and len(inputs[0].strip()) > 0):
                raise TypeError(
                    "Invalid inputs, input text should be a list of str, and the "
                    "first element of the list should not be empty, but {} found!".
                    format(type(inputs[0])))
else:
raise TypeError(
"Invalid inputs, input text should be str or list of str, but type of {} found!".
format(type(inputs)))
return inputs
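
    # Worked examples of the validation above (the argument is the *args tuple
    # received by __call__):
    #
    #     self._check_input_text(("some text",))   -> ["some text"]
    #     self._check_input_text((["a", "b"],))    -> ["a", "b"]
    #     self._check_input_text(("",))            raises ValueError
    #     self._check_input_text((123,))           raises TypeError
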
def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
        '''
        Split the raw texts automatically for model inference.

        Args:
            input_texts (List[str]): input raw texts.
            max_text_len (int): the maximum length of each short text.
            split_sentence (bool): If True, sentence-level split will be performed first.

        Returns:
            short_input_texts (List[str]): the short input texts for model inference.
            input_mapping (dict): mapping between raw texts and short input texts.
        '''
input_mapping = {}
short_input_texts = []
cnt_org = 0
cnt_short = 0
for text in input_texts:
if not split_sentence:
sens = [text]
else:
sens = cut_chinese_sent(text)
for sen in sens:
lens = len(sen)
if lens <= max_text_len:
short_input_texts.append(sen)
                    if cnt_org not in input_mapping:
input_mapping[cnt_org] = [cnt_short]
else:
input_mapping[cnt_org].append(cnt_short)
cnt_short += 1
else:
temp_text_list = [
sen[i:i + max_text_len]
for i in range(0, lens, max_text_len)
]
short_input_texts.extend(temp_text_list)
short_idx = cnt_short
cnt_short += math.ceil(lens / max_text_len)
temp_text_id = [
short_idx + i for i in range(cnt_short - short_idx)
]
                    if cnt_org not in input_mapping:
input_mapping[cnt_org] = temp_text_id
else:
input_mapping[cnt_org].extend(temp_text_id)
cnt_org += 1
return short_input_texts, input_mapping
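
    # A worked example with max_text_len=5 and split_sentence=False:
    #
    #     input_texts = ["abcdefg", "hi"]
    #     -> short_input_texts = ["abcde", "fg", "hi"]
    #     -> input_mapping = {0: [0, 1], 1: [2]}
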
def _auto_joiner(self, short_results, input_mapping, is_dict=False):
        '''
        Join the short results automatically and generate the final results to match with the user inputs.

        Args:
            short_results (List[dict] / List[List[str]] / List[str]): the model results of the short input texts.
            input_mapping (dict): mapping between raw texts and short input texts.
            is_dict (bool): whether the element type is dict, default to False.

        Returns:
            concat_results (List): the final results matched with the user inputs.
        '''
concat_results = []
elem_type = {} if is_dict else []
for k, vs in input_mapping.items():
single_results = elem_type
for v in vs:
if len(single_results) == 0:
single_results = short_results[v]
elif isinstance(elem_type, list):
single_results.extend(short_results[v])
elif isinstance(elem_type, dict):
for sk in single_results.keys():
if isinstance(single_results[sk], str):
single_results[sk] += short_results[v][sk]
else:
single_results[sk].extend(short_results[v][sk])
else:
raise ValueError(
"Invalid element type, the type of results "
"for each element should be list of dict, "
"but {} received.".format(type(single_results)))
concat_results.append(single_results)
return concat_results
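
    # A worked example continuing the split above (list elements, is_dict=False):
    #
    #     short_results = [["A"], ["B"], ["C"]]
    #     input_mapping = {0: [0, 1], 1: [2]}
    #     -> concat_results = [["A", "B"], ["C"]]
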
    def help(self):
"""
Return the usage message of the current task.
"""
print("Examples:\n{}".format(self._usage))
def __call__(self, *args):
inputs = self._preprocess(*args)
outputs = self._run_model(inputs)
results = self._postprocess(outputs)
return results
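

# A sketch of the end-to-end flow from the caller's side (the subclass name is
# illustrative; in practice tasks are constructed through Taskflow):
#
#     task = SomeTask(model="my-model", task="my-task")
#     results = task("input text")  # _preprocess -> _run_model -> _postprocess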