# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
from ..data import Pad, Stack, Tuple, Vocab
from ..transformers import ErnieModel, ErnieTokenizer, is_chinese_char
from .models import ErnieForCSC
from .task import Task
from .utils import static_mode_guard
usage = r"""
from paddlenlp import Taskflow
text_correction = Taskflow("text_correction")
text_correction('遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。')
'''
[{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
'errors': [{'position': 3, 'correction': {'竟': '境'}}]}
]
'''
text_correction(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
'人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
'''
[{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
'errors': [{'position': 3, 'correction': {'竟': '境'}}]},
{'source': '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。',
'target': '人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。',
'errors': [{'position': 18, 'correction': {'拙': '茁'}}]}
]
'''
"""
TASK_MODEL_MAP = {"ernie-csc": "ernie-1.0"}
class CSCTask(Task):
"""
    The Chinese spelling correction (CSC) model to detect and correct misspelled characters in the input text.
    Args:
        task (string): The name of the task.
        model (string): The model name in the task.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
"""
resource_files_names = {"model_state": "model_state.pdparams", "pinyin_vocab": "pinyin_vocab.txt"}
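    # Each resource below pairs a download URL with its md5 checksum;
    # _check_task_files uses them to fetch and verify the files.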
resource_files_urls = {
"ernie-csc": {
"model_state": [
"https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/model_state.pdparams",
"cdc53e7e3985ffc78fedcdf8e6dca6d2",
],
"pinyin_vocab": [
"https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/pinyin_vocab.txt",
"5599a8116b6016af573d08f8e686b4b2",
],
}
}
def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._usage = usage
self._check_task_files()
self._construct_vocabs()
self._get_inference_model()
self._construct_tokenizer(model)
try:
import pypinyin
except ImportError:
            raise ImportError("Please install the pypinyin dependency first: pip install pypinyin --upgrade")
self._pypinyin = pypinyin
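        # Collate a batch of (input_ids, token_type_ids, pinyin_ids, length)
        # tuples: pad each id sequence to the longest one in the batch with the
        # matching pad id, and stack the lengths into a single array.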
self._batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype="int64"), # input
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype="int64"), # segment
Pad(
axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token], dtype="int64"
), # pinyin
Stack(axis=0, dtype="int64"), # length
): [data for data in fn(samples)]
        self._num_workers = self.kwargs.get("num_workers", 0)
        self._batch_size = self.kwargs.get("batch_size", 1)
        self._lazy_load = self.kwargs.get("lazy_load", False)
        self._max_seq_len = self.kwargs.get("max_seq_len", 128)
        self._split_sentence = self.kwargs.get("split_sentence", False)
def _construct_input_spec(self):
"""
        Construct the input spec for converting the dygraph model to a static model.
"""
self._input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pinyin_ids"),
]
def _construct_vocabs(self):
pinyin_vocab_path = os.path.join(self._task_path, "pinyin_vocab.txt")
self._pinyin_vocab = Vocab.load_vocabulary(pinyin_vocab_path, unk_token="[UNK]", pad_token="[PAD]")
def _construct_model(self, model):
"""
Construct the inference model for the predictor.
"""
ernie = ErnieModel.from_pretrained(TASK_MODEL_MAP[model])
model_instance = ErnieForCSC(
ernie,
pinyin_vocab_size=len(self._pinyin_vocab),
pad_pinyin_id=self._pinyin_vocab[self._pinyin_vocab.pad_token],
)
        # Load the trained model parameters for prediction
model_path = os.path.join(self._task_path, "model_state.pdparams")
state_dict = paddle.load(model_path)
model_instance.set_state_dict(state_dict)
self._model = model_instance
self._model.eval()
def _construct_tokenizer(self, model):
"""
Construct the tokenizer for the predictor.
"""
self._tokenizer = ErnieTokenizer.from_pretrained(TASK_MODEL_MAP[model])
def _preprocess(self, inputs, padding=True, add_special_tokens=True):
input_texts = self._check_input_text(inputs)
examples = []
texts = []
max_predict_len = self._max_seq_len - 2
short_input_texts, self.input_mapping = self._auto_splitter(
input_texts, max_predict_len, split_sentence=self._split_sentence
)
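        # Texts longer than max_predict_len are split into shorter chunks;
        # input_mapping records how the chunks map back to the original inputs
        # so _postprocess can rejoin the per-chunk results.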
for text in short_input_texts:
if not (isinstance(text, str) and len(text) > 0):
continue
example = {"source": text.strip()}
input_ids, token_type_ids, pinyin_ids, length = self._convert_example(example)
examples.append((input_ids, token_type_ids, pinyin_ids, length))
texts.append(example["source"])
batch_examples = [examples[idx : idx + self._batch_size] for idx in range(0, len(examples), self._batch_size)]
        # Use `texts` here so the batches stay aligned one-to-one with `examples`
        # (short_input_texts may contain entries that were skipped above).
        batch_texts = [texts[idx : idx + self._batch_size] for idx in range(0, len(examples), self._batch_size)]
outputs = {}
outputs["batch_examples"] = batch_examples
outputs["batch_texts"] = batch_texts
return outputs
def _run_model(self, inputs):
"""
        Run the task model from the outputs of the `_preprocess` function.
"""
results = []
with static_mode_guard():
for examples in inputs["batch_examples"]:
token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(examples)
self.input_handles[0].copy_from_cpu(token_ids)
self.input_handles[1].copy_from_cpu(pinyin_ids)
self.predictor.run()
det_preds = self.output_handle[0].copy_to_cpu()
char_preds = self.output_handle[1].copy_to_cpu()
batch_result = []
for i in range(len(lengths)):
batch_result.append((det_preds[i], char_preds[i], lengths[i]))
results.append(batch_result)
inputs["batch_results"] = results
return inputs
def _postprocess(self, inputs):
"""
        The model outputs detection and correction predictions; this function converts them back to raw text.
"""
results = []
for examples, texts, temp_results in zip(
inputs["batch_examples"], inputs["batch_texts"], inputs["batch_results"]
):
for i in range(len(examples)):
result = {}
det_pred, char_preds, length = temp_results[i]
pred_result = self._parse_decode(texts[i], char_preds, det_pred, length)
result["source"] = texts[i]
result["target"] = "".join(pred_result)
results.append(result)
results = self._auto_joiner(results, self.input_mapping, is_dict=True)
for result in results:
errors_result = []
for i, (source_token, target_token) in enumerate(zip(result["source"], result["target"])):
if source_token != target_token:
errors_result.append({"position": i, "correction": {source_token: target_token}})
result["errors"] = errors_result
return results
def _convert_example(self, example):
source = example["source"]
words = list(source)
length = len(words)
words = ["[CLS]"] + words + ["[SEP]"]
input_ids = self._tokenizer.convert_tokens_to_ids(words)
token_type_ids = [0] * len(input_ids)
        # Use the pinyin pad token to align with the [CLS] and [SEP] positions in the word sequence
pinyins = self._pypinyin.lazy_pinyin(source, style=self._pypinyin.Style.TONE3, neutral_tone_with_five=True)
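        # e.g. lazy_pinyin("逆境", style=Style.TONE3, neutral_tone_with_five=True)
        # returns ["ni4", "jing4"]; the trailing digit is the tone number, and
        # neutral tones are written with a trailing "5".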
pinyin_ids = [0]
# Align pinyin and chinese char
pinyin_offset = 0
for i, word in enumerate(words[1:-1]):
pinyin = "[UNK]" if word != "[PAD]" else "[PAD]"
if len(word) == 1 and is_chinese_char(ord(word)):
while pinyin_offset < len(pinyins):
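                    # Drop the trailing tone digit (e.g. "ni4" -> "ni") before
                    # looking the syllable up in the pinyin vocab.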
current_pinyin = pinyins[pinyin_offset][:-1]
pinyin_offset += 1
if current_pinyin in self._pinyin_vocab:
pinyin = current_pinyin
break
pinyin_ids.append(self._pinyin_vocab[pinyin])
pinyin_ids.append(0)
assert len(input_ids) == len(pinyin_ids), "length of input_ids must be equal to length of pinyin_ids"
return input_ids, token_type_ids, pinyin_ids, length
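    # Illustrative sketch of _convert_example for source == "逆竟" (actual ids
    # depend on the ERNIE and pinyin vocabularies):
    #   words      -> ["[CLS]", "逆", "竟", "[SEP]"]
    #   pinyin_ids -> [0, pinyin_vocab["ni"], pinyin_vocab["jing"], 0]  # 0 pads [CLS]/[SEP]
    #   length     -> 2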
def _parse_decode(self, words, corr_preds, det_preds, lengths):
UNK = self._tokenizer.unk_token
UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
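        # Skip index 0 ([CLS]) and truncate at the real sequence length,
        # which also drops [SEP] and any padding.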
corr_pred = corr_preds[1 : 1 + lengths].tolist()
det_pred = det_preds[1 : 1 + lengths].tolist()
words = list(words)
rest_words = []
max_seq_length = self._max_seq_len - 2
if len(words) > max_seq_length:
rest_words = words[max_seq_length:]
words = words[:max_seq_length]
pred_result = ""
for j, word in enumerate(words):
candidates = self._tokenizer.convert_ids_to_tokens(
corr_pred[j] if corr_pred[j] < self._tokenizer.vocab_size else UNK_id
)
word_icc = is_chinese_char(ord(word))
cand_icc = is_chinese_char(ord(candidates)) if len(candidates) == 1 else False
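            # Keep the original character unless the detector flags an error
            # and the candidate is a known Chinese character.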
if not word_icc or det_pred[j] == 0 or candidates in [UNK, "[PAD]"] or (word_icc and not cand_icc):
pred_result += word
else:
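                # Strip the WordPiece continuation prefix ("##") before
                # splicing the candidate back into the text.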
pred_result += candidates.lstrip("##")
pred_result += "".join(rest_words)
return pred_result