paddlenlp.taskflow.text_similarity 源代码

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import AutoModel, AutoTokenizer

from ..data import Pad, Tuple
from ..transformers import ErnieCrossEncoder, ErnieTokenizer
from ..utils.log import logger
from .task import Task
from .utils import static_mode_guard

usage = r"""
         from paddlenlp import Taskflow

         similarity = Taskflow("text_similarity")
         similarity([["世界上什么东西最小", "世界上什么东西最小?"]])
         '''
         [{'text1': '世界上什么东西最小', 'text2': '世界上什么东西最小?', 'similarity': 0.992725}]
         '''

         similarity = Taskflow("text_similarity", batch_size=2)
         similarity([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]])
         '''
         [{'text1': '光眼睛大就好看吗', 'text2': '眼睛好看吗?', 'similarity': 0.74502707}, {'text1': '小蝌蚪找妈妈怎么样', 'text2': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}]
         '''
         """
MATCH_TYPE = {
    "rocketqa-zh-dureader-cross-encoder": "matching",
    "rocketqa-base-cross-encoder": "matching",
    "rocketqa-medium-cross-encoder": "matching",
    "rocketqa-mini-cross-encoder": "matching",
    "rocketqa-micro-cross-encoder": "matching",
    "rocketqa-nano-cross-encoder": "matching",
    "rocketqav2-en-marco-cross-encoder": "matching_v2",
    "ernie-search-large-cross-encoder-marco-en": "matching_v3",
}


[文档]class TextSimilarityTask(Task): """ Text similarity task using SimBERT to predict the similarity of sentence pair. Args: task(string): The name of task. model(string): The model name in the task. kwargs (dict, optional): Additional keyword arguments passed along to the specific task. """ resource_files_names = { "model_state": "model_state.pdparams", "model_config": "model_config.json", } resource_files_urls = { "simbert-base-chinese": { "model_state": [ "https://bj.bcebos.com/paddlenlp/taskflow/text_similarity/simbert-base-chinese/model_state.pdparams", "27d9ef240c2e8e736bdfefea52af2542", ], "model_config": [ "https://bj.bcebos.com/paddlenlp/taskflow/text_similarity/simbert-base-chinese/model_config.json", "1254bbd7598457a9dad0afcb2e24b70c", ], }, "rocketqa-zh-dureader-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-zh-dureader-cross-encoder/model_state.pdparams", "88bc3e1a64992a1bdfe4044ecba13bc7", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-zh-dureader-cross-encoder/model_config.json", "b69083c2895e8f68e1a10467b384daab", ], }, "rocketqa-base-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-base-cross-encoder/model_state.pdparams", "6d845a492a2695e62f2be79f8017be92", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-base-cross-encoder/model_config.json", "18ce260ede18bc3cb28dcb2e7df23b1a", ], }, "rocketqa-medium-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-medium-cross-encoder/model_state.pdparams", "4b929f4fc11a1df8f59fdf2784e23fa7", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-medium-cross-encoder/model_config.json", "10997db96bc86e29cd113e1bf58989d7", ], }, "rocketqa-mini-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-mini-cross-encoder/model_state.pdparams", "c411111df990132fb88c070d8b8cf3f7", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-mini-cross-encoder/model_config.json", "271e6d779acbe8e8acdd596b1c835546", ], }, "rocketqa-micro-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-micro-cross-encoder/model_state.pdparams", "3d643ff7d6029c8ceab5653680167dc0", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-micro-cross-encoder/model_config.json", "b32d1a932d8c367fab2a6216459dd0a7", ], }, "rocketqa-nano-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-nano-cross-encoder/model_state.pdparams", "4c1d36e5e94f5af09f665fc7ad0be140", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqa-nano-cross-encoder/model_config.json", "dcff14cd671e1064be2c5d63734098bb", ], }, "rocketqav2-en-marco-cross-encoder": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqav2-en-marco-cross-encoder/model_state.pdparams", "a5afc77b6a63fc32a1beca3010f40f32", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/rocketqav2-en-marco-cross-encoder/config.json", "8f5d5c71c8a891b68d0402a13e38b6f9", ], }, "ernie-search-large-cross-encoder-marco-en": { "model_state": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/ernie-search-large-cross-encoder-marco-en/model_state.pdparams", "fdf29f7de0f7fe570740d343c96165e5", ], "model_config": [ "https://paddlenlp.bj.bcebos.com/taskflow/text_similarity/ernie-search-large-cross-encoder-marco-en/config.json", "28bad2c7b36fa148fa75a8dc5b690485", ], }, "__internal_testing__/tiny-random-bert": { "model_state": [ "https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-bert/model_state.pdparams", "8d8814d589c21bf083fdb35de6c11a57", ], "model_config": [ "https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-bert/config.json", "37e28e2359f330f64fc82beff1967a1e", ], }, } def __init__(self, task, model, batch_size=1, max_length=384, **kwargs): super().__init__(task=task, model=model, **kwargs) self._static_mode = True self._check_predictor_type() if not self.from_hf_hub: self._check_task_files() if self._static_mode: self._get_inference_model() else: self._construct_model(model) self._construct_tokenizer(model) self._batch_size = batch_size self._max_length = max_length self._usage = usage self.model_name = model def _construct_input_spec(self): """ Construct the input spec for the convert dygraph model to static model. """ self._input_spec = [ paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"), ] def _construct_model(self, model): """ Construct the inference model for the predictor. """ if "rocketqav2-en" in model or "ernie-search" in model: self._model = ErnieCrossEncoder(self._task_path, num_classes=1, reinitialize=True) elif "rocketqa" in model: self._model = ErnieCrossEncoder(self._task_path, num_classes=2) else: self._model = AutoModel.from_pretrained(self._task_path, pool_act="linear") self._model.eval() def _construct_tokenizer(self, model): """ Construct the tokenizer for the predictor. """ if "rocketqa" in model or "ernie-search" in model: self._tokenizer = ErnieTokenizer.from_pretrained(model) else: self._tokenizer = AutoTokenizer.from_pretrained(model) def _check_input_text(self, inputs): inputs = inputs[0] if not all([isinstance(i, list) and i and all(i) and len(i) == 2 for i in inputs]): raise TypeError("Invalid input format.") return inputs def _preprocess(self, inputs): """ Transform the raw text to the model inputs, two steps involved: 1) Transform the raw text to token ids. 2) Generate the other model inputs from the raw text and token ids. """ inputs = self._check_input_text(inputs) examples = [] for data in inputs: text1, text2 = data[0], data[1] if "rocketqa" in self.model_name or "ernie-search" in self.model_name: # Todo: wugaosheng, Add erine-search encoding support encoded_inputs = self._tokenizer(text=text1, text_pair=text2, max_length=self._max_length) ids = encoded_inputs["input_ids"] segment_ids = encoded_inputs["token_type_ids"] examples.append((ids, segment_ids)) else: text1_encoded_inputs = self._tokenizer(text=text1, max_length=self._max_length) text1_input_ids = text1_encoded_inputs["input_ids"] text1_token_type_ids = text1_encoded_inputs["token_type_ids"] text2_encoded_inputs = self._tokenizer(text=text2, max_length=self._max_length) text2_input_ids = text2_encoded_inputs["input_ids"] text2_token_type_ids = text2_encoded_inputs["token_type_ids"] examples.append((text1_input_ids, text1_token_type_ids, text2_input_ids, text2_token_type_ids)) batches = [examples[idx : idx + self._batch_size] for idx in range(0, len(examples), self._batch_size)] if "rocketqa" in self.model_name or "ernie-search" in self.model_name: batchify_fn = lambda samples, fn=Tuple( # noqa: E731 Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype="int64"), # input ids Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype="int64"), # token type ids ): [data for data in fn(samples)] else: batchify_fn = lambda samples, fn=Tuple( # noqa: E731 Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype="int64"), # text1_input_ids Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype="int64"), # text1_token_type_ids Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype="int64"), # text2_input_ids Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id, dtype="int64"), # text2_token_type_ids ): [data for data in fn(samples)] outputs = {} outputs["data_loader"] = batches outputs["text"] = inputs self._batchify_fn = batchify_fn return outputs def _run_model(self, inputs): """ Run the task model from the outputs of the `_tokenize` function. """ results = [] if "rocketqa" in self.model_name or "ernie-search" in self.model_name: with static_mode_guard(): for batch in inputs["data_loader"]: if self._predictor_type == "paddle-inference": input_ids, segment_ids = self._batchify_fn(batch) self.input_handles[0].copy_from_cpu(input_ids) self.input_handles[1].copy_from_cpu(segment_ids) self.predictor.run() scores = self.output_handle[0].copy_to_cpu().tolist() results.extend(scores) else: # onnx mode input_dict = {} input_ids, segment_ids = self._batchify_fn(batch) input_dict["input_ids"] = input_ids input_dict["token_type_ids"] = segment_ids scores = self.predictor.run(None, input_dict)[0].tolist() results.extend(scores) else: with static_mode_guard(): for batch in inputs["data_loader"]: text1_ids, text1_segment_ids, text2_ids, text2_segment_ids = self._batchify_fn(batch) self.input_handles[0].copy_from_cpu(text1_ids) self.input_handles[1].copy_from_cpu(text1_segment_ids) self.predictor.run() vecs_text1 = self.output_handle[1].copy_to_cpu() self.input_handles[0].copy_from_cpu(text2_ids) self.input_handles[1].copy_from_cpu(text2_segment_ids) self.predictor.run() vecs_text2 = self.output_handle[1].copy_to_cpu() vecs_text1 = vecs_text1 / (vecs_text1**2).sum(axis=1, keepdims=True) ** 0.5 vecs_text2 = vecs_text2 / (vecs_text2**2).sum(axis=1, keepdims=True) ** 0.5 similarity = (vecs_text1 * vecs_text2).sum(axis=1) results.extend(similarity) inputs["result"] = results return inputs def _postprocess(self, inputs): """ The model output is tag ids, this function will convert the model output to raw text. """ final_results = [] for text, similarity in zip(inputs["text"], inputs["result"]): result = {} result["text1"] = text[0] result["text2"] = text[1] # The numpy.float32 can not be converted to the json format if isinstance(similarity, list): result["similarity"] = float(similarity[0]) else: result["similarity"] = float(similarity) final_results.append(result) return final_results def _convert_dygraph_to_static(self): """ Convert the dygraph model to static model. """ assert ( self._model is not None ), "The dygraph model must be created before converting the dygraph model to static model." assert ( self._input_spec is not None ), "The input spec must be created before converting the dygraph model to static model." logger.info("Converting to the inference model cost a little time.") if self.model in MATCH_TYPE: if MATCH_TYPE[self.model] == "matching": static_model = paddle.jit.to_static(self._model.matching, input_spec=self._input_spec) elif MATCH_TYPE[self.model] == "matching_v2": static_model = paddle.jit.to_static(self._model.matching_v2, input_spec=self._input_spec) elif MATCH_TYPE[self.model] == "matching_v3": static_model = paddle.jit.to_static(self._model.matching_v3, input_spec=self._input_spec) else: static_model = paddle.jit.to_static(self._model, input_spec=self._input_spec) paddle.jit.save(static_model, self.inference_model_path) logger.info("The inference model save in the path:{}".format(self.inference_model_path))