paddlenlp.taskflow.pos_tagging 源代码

# coding:utf-8
# Copyright (c) 2021  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lexical_analysis import LacTask

usage = r"""
           from paddlenlp import Taskflow

           pos = Taskflow("pos_tagging")
           pos("第十四届全运会在西安举办")
           '''
           [('第十四届', 'm'), ('全运会', 'nz'), ('在', 'p'), ('西安', 'LOC'), ('举办', 'v')]
           '''

           pos(["第十四届全运会在西安举办", "三亚是一个美丽的城市"])
           '''
           [[('第十四届', 'm'), ('全运会', 'nz'), ('在', 'p'), ('西安', 'LOC'), ('举办', 'v')], [('三亚', 'LOC'), ('是', 'v'), ('一个', 'm'), ('美丽', 'a'), ('的', 'u'), ('城市', 'n')]]
           '''
         """


[文档]class POSTaggingTask(LacTask): """ Part-of-speech tagging task for the raw text. Args: task(string): The name of task. model(string): The model name in the task. kwargs (dict, optional): Additional keyword arguments passed along to the specific task. """ def __init__(self, task, model, **kwargs): super().__init__(task=task, model=model, **kwargs) def _postprocess(self, inputs): """ The model output is the tag ids, this function will convert the model output to raw text. """ lengths = inputs["lens"] preds = inputs["result"] sents = inputs["text"] final_results = [] for sent_index in range(len(lengths)): tags = [self._id2tag_dict[str(index)] for index in preds[sent_index][: lengths[sent_index]]] sent = sents[sent_index] if self._custom: self._custom.parse_customization(sent, tags) sent_out = [] tags_out = [] parital_word = "" for ind, tag in enumerate(tags): if parital_word == "": parital_word = sent[ind] tags_out.append(tag.split("-")[0]) continue if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"): sent_out.append(parital_word) tags_out.append(tag.split("-")[0]) parital_word = sent[ind] continue parital_word += sent[ind] if len(sent_out) < len(tags_out): sent_out.append(parital_word) result = list(zip(sent_out, tags_out)) final_results.append(result) final_results = self._auto_joiner(final_results, self.input_mapping) final_results = final_results if len(final_results) > 1 else final_results[0] return final_results