paddlenlp.taskflow.text_correction 源代码

# coding:utf-8
# Copyright (c) 2021  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import math
import os
import copy
import itertools

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..transformers import ErnieTokenizer, ErnieModel
from ..transformers import is_chinese_char
from ..datasets import load_dataset
from ..data import Stack, Pad, Tuple, Vocab
from .utils import download_file, add_docstrings, static_mode_guard
from .models import ErnieForCSC
from .task import Task

# Usage examples for the "text_correction" Taskflow. CSCTask.__init__ stores
# this text in ``self._usage``. The example outputs show one dict per input
# sentence with 'source', 'target' and 'errors' keys; each error holds the
# character 'position' and a {wrong_char: corrected_char} 'correction' mapping.
usage = r"""
           from paddlenlp import Taskflow

           text_correction = Taskflow("text_correction")
           text_correction('遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。')
           '''
           [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
             'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
             'errors': [{'position': 3, 'correction': {'竟': '境'}}]}
           ]
           '''

           text_correction(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。',
                            '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。'])
           '''
           [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。', 
             'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。', 
             'errors': [{'position': 3, 'correction': {'竟': '境'}}]}, 
            {'source': '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。', 
             'target': '人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。', 
             'errors': [{'position': 18, 'correction': {'拙': '茁'}}]}
           ]
           '''

         """

# Maps the task's model name to the pretrained name passed to
# ErnieModel.from_pretrained / ErnieTokenizer.from_pretrained in CSCTask.
TASK_MODEL_MAP = {"ernie-csc": "ernie-1.0"}


class CSCTask(Task):
    """
    The Chinese spelling correction (CSC) task, backed by ``ErnieForCSC``
    (an ERNIE encoder plus pinyin embeddings).

    Args:
        task(string): The name of task.
        model(string): The model name in the task.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
            Keys read by this class: ``batch_size`` (default 1), ``max_seq_len``
            (default 128), ``split_sentence`` (default False), ``num_workers``
            (default 0) and ``lazy_load`` (default False).
    """

    resource_files_names = {
        "model_state": "model_state.pdparams",
        "pinyin_vocab": "pinyin_vocab.txt"
    }
    resource_files_urls = {
        "ernie-csc": {
            "model_state": [
                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/model_state.pdparams",
                "cdc53e7e3985ffc78fedcdf8e6dca6d2"
            ],
            "pinyin_vocab": [
                "https://bj.bcebos.com/paddlenlp/taskflow/text_correction/ernie-csc/pinyin_vocab.txt",
                "5599a8116b6016af573d08f8e686b4b2"
            ],
        }
    }

    def __init__(self, task, model, **kwargs):
        super().__init__(task=task, model=model, **kwargs)
        self._usage = usage
        self._check_task_files()
        # The pinyin vocab must exist before the model is built, because
        # _construct_model sizes the pinyin embedding from it.
        self._construct_vocabs()
        self._get_inference_model()
        self._construct_tokenizer(model)
        try:
            import pypinyin
        except ImportError as err:
            raise ImportError(
                "Please install the dependencies first, pip install pypinyin --upgrade"
            ) from err
        self._pypinyin = pypinyin
        # Pads each field of a batch of (input_ids, token_type_ids,
        # pinyin_ids, length) examples. `fn` is bound as a default argument
        # so the Tuple is built once, not on every call.
        self._batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64'),  # input
            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype='int64'),  # segment
            Pad(axis=0,
                pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token],
                dtype='int64'),  # pinyin
            Stack(axis=0, dtype='int64'),  # length
        ): list(fn(samples))
        self._num_workers = self.kwargs.get('num_workers', 0)
        self._batch_size = self.kwargs.get('batch_size', 1)
        self._lazy_load = self.kwargs.get('lazy_load', False)
        self._max_seq_len = self.kwargs.get('max_seq_len', 128)
        self._split_sentence = self.kwargs.get('split_sentence', False)

    def _construct_input_spec(self):
        """
        Construct the input spec for the convert dygraph model to static model.

        The static model takes two int64 inputs with dynamic batch and
        sequence dimensions: ``input_ids`` and ``pinyin_ids``.
        """
        self._input_spec = [
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64", name='input_ids'),
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64", name='pinyin_ids'),
        ]

    def _construct_vocabs(self):
        """Load the pinyin vocabulary from the task directory."""
        pinyin_vocab_path = os.path.join(self._task_path, "pinyin_vocab.txt")
        self._pinyin_vocab = Vocab.load_vocabulary(
            pinyin_vocab_path, unk_token='[UNK]', pad_token='[PAD]')

    def _construct_model(self, model):
        """
        Construct the inference model for the predictor.

        Builds ``ErnieForCSC`` on top of the pretrained ERNIE backbone, loads
        ``model_state.pdparams`` from the task directory and switches to eval mode.
        """
        ernie = ErnieModel.from_pretrained(TASK_MODEL_MAP[model])
        model_instance = ErnieForCSC(
            ernie,
            pinyin_vocab_size=len(self._pinyin_vocab),
            pad_pinyin_id=self._pinyin_vocab[self._pinyin_vocab.pad_token])
        # Load the model parameter for the predict
        model_path = os.path.join(self._task_path, "model_state.pdparams")
        state_dict = paddle.load(model_path)
        model_instance.set_state_dict(state_dict)
        self._model = model_instance
        self._model.eval()

    def _construct_tokenizer(self, model):
        """Construct the tokenizer for the predictor."""
        self._tokenizer = ErnieTokenizer.from_pretrained(TASK_MODEL_MAP[model])

    def _preprocess(self, inputs, padding=True, add_special_tokens=True):
        """
        Split, encode and batch the input texts.

        ``padding`` and ``add_special_tokens`` are accepted for interface
        compatibility but are not used here.

        Returns:
            dict: ``batch_examples`` (lists of encoded examples, at most
            ``batch_size`` per batch) and ``batch_texts`` (the stripped source
            texts that were encoded, batched the same way).
        """
        input_texts = self._check_input_text(inputs)
        examples = []
        texts = []
        # Reserve two positions for [CLS] and [SEP].
        max_predict_len = self._max_seq_len - 2
        short_input_texts, self.input_mapping = self._auto_splitter(
            input_texts, max_predict_len, split_sentence=self._split_sentence)
        for text in short_input_texts:
            # NOTE(review): skipped entries are not removed from
            # self.input_mapping; presumably _auto_splitter never yields empty
            # strings — confirm.
            if not (isinstance(text, str) and len(text) > 0):
                continue
            example = {"source": text.strip()}
            examples.append(self._convert_example(example))
            # Keep exactly the text that was encoded. Decoding walks it
            # character by character against the model output, so it must
            # match what the model saw (stripped, with skipped entries removed).
            texts.append(example["source"])

        batch_examples = [
            examples[idx:idx + self._batch_size]
            for idx in range(0, len(examples), self._batch_size)
        ]
        batch_texts = [
            texts[idx:idx + self._batch_size]
            for idx in range(0, len(texts), self._batch_size)
        ]
        outputs = {}
        outputs['batch_examples'] = batch_examples
        outputs['batch_texts'] = batch_texts
        return outputs

    def _run_model(self, inputs):
        """
        Run the static predictor on every batch from ``_preprocess``.

        Adds ``inputs['batch_results']``: for each batch, a list of
        ``(det_pred, char_pred, length)`` tuples, one per example.
        """
        results = []
        with static_mode_guard():
            for examples in inputs['batch_examples']:
                token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(
                    examples)
                # The static model only takes input_ids and pinyin_ids (see
                # _construct_input_spec); token_type_ids is not fed.
                self.input_handles[0].copy_from_cpu(token_ids)
                self.input_handles[1].copy_from_cpu(pinyin_ids)
                self.predictor.run()
                det_preds = self.output_handle[0].copy_to_cpu()
                char_preds = self.output_handle[1].copy_to_cpu()
                results.append(list(zip(det_preds, char_preds, lengths)))
        inputs['batch_results'] = results
        return inputs

    def _postprocess(self, inputs):
        """
        Convert the model output to raw text.

        Each result dict has ``source``, ``target`` and ``errors``. ``errors``
        lists every position where ``source`` and ``target`` differ, with a
        ``{source_char: target_char}`` correction.
        """
        results = []
        for texts, batch_results in zip(inputs['batch_texts'],
                                        inputs['batch_results']):
            for text, (det_pred, char_preds, length) in zip(texts, batch_results):
                pred_result = self._parse_decode(text, char_preds, det_pred,
                                                 length)
                results.append({'source': text, 'target': ''.join(pred_result)})
        results = self._auto_joiner(results, self.input_mapping, is_dict=True)
        for result in results:
            result['errors'] = [{
                'position': i,
                'correction': {
                    source_token: target_token
                }
            } for i, (source_token, target_token) in enumerate(
                zip(result['source'], result['target']))
                                if source_token != target_token]
        return results

    def _convert_example(self, example):
        """
        Convert one ``{"source": text}`` example into model features.

        Each character of the source becomes one token, wrapped in
        ``[CLS] ... [SEP]``. For each Chinese character, the pinyin id is the
        next ``pypinyin`` entry (with its last character, the tone digit,
        dropped) that is found in the pinyin vocab. Other characters get the
        ``'[UNK]'`` pinyin id.

        Returns:
            tuple: ``(input_ids, token_type_ids, pinyin_ids, length)``, where
            ``length`` is the number of source characters (excluding
            ``[CLS]``/``[SEP]``).
        """
        source = example["source"]
        words = list(source)
        length = len(words)
        words = ['[CLS]'] + words + ['[SEP]']
        input_ids = self._tokenizer.convert_tokens_to_ids(words)
        token_type_ids = [0] * len(input_ids)

        # Use pad token in pinyin emb to map word emb [CLS], [SEP]
        # (id 0 is assumed to be the [PAD] pinyin id — confirm against pinyin_vocab.txt).
        pinyins = self._pypinyin.lazy_pinyin(
            source,
            style=self._pypinyin.Style.TONE3,
            neutral_tone_with_five=True)

        pinyin_ids = [0]
        # Align pinyin and chinese char
        pinyin_offset = 0
        for word in words[1:-1]:
            pinyin = '[UNK]' if word != '[PAD]' else '[PAD]'
            if len(word) == 1 and is_chinese_char(ord(word)):
                while pinyin_offset < len(pinyins):
                    current_pinyin = pinyins[pinyin_offset][:-1]
                    pinyin_offset += 1
                    if current_pinyin in self._pinyin_vocab:
                        pinyin = current_pinyin
                        break
            pinyin_ids.append(self._pinyin_vocab[pinyin])

        pinyin_ids.append(0)
        assert len(input_ids) == len(
            pinyin_ids
        ), "length of input_ids must be equal to length of pinyin_ids"
        return input_ids, token_type_ids, pinyin_ids, length

    def _parse_decode(self, words, corr_preds, det_preds, lengths):
        """
        Build the corrected string for one source text.

        A character is replaced by the model's correction only when:
        - it is a Chinese character,
        - the detection output is non-zero,
        - the candidate is a single Chinese character that is not
          ``[UNK]``/``[PAD]``.

        Characters beyond ``max_seq_len - 2`` are copied through unchanged.
        """
        UNK = self._tokenizer.unk_token
        UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
        # Drop the [CLS] position and keep one prediction per source character.
        corr_pred = corr_preds[1:1 + lengths].tolist()
        det_pred = det_preds[1:1 + lengths].tolist()
        words = list(words)
        rest_words = []
        max_seq_length = self._max_seq_len - 2
        if len(words) > max_seq_length:
            rest_words = words[max_seq_length:]
            words = words[:max_seq_length]

        pred_result = []
        for j, word in enumerate(words):
            # Out-of-vocab prediction ids fall back to [UNK].
            candidates = self._tokenizer.convert_ids_to_tokens(
                corr_pred[j] if corr_pred[j] < self._tokenizer.vocab_size else UNK_id)
            word_icc = is_chinese_char(ord(word))
            cand_icc = is_chinese_char(ord(candidates)) if len(
                candidates) == 1 else False
            if (not word_icc or det_pred[j] == 0
                    or candidates in [UNK, '[PAD]'] or not cand_icc):
                pred_result.append(word)
            else:
                pred_result.append(candidates.lstrip("##"))
        pred_result.extend(rest_words)
        return ''.join(pred_result)