Source code for paddlenlp.transformers.codegen.tokenizer

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.utils import try_import
from .. import GPTTokenizer

__all__ = ["CodeGenTokenizer"]

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


class CodeGenTokenizer(GPTTokenizer):
    resource_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    pretrained_resource_files_map = {"vocab_file": {}, "merges_file": {}}
    pretrained_init_configuration = {}

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        max_len=None,
        pad_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        unk_token="<|endoftext|>",
        eol_token="\u010a",
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            max_len=max_len,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            eol_token=eol_token,
            **kwargs,
        )
    def decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        truncate_before_pattern=None,
        **kwargs
    ):
        """
        Converts a sequence of ids into a string, using the tokenizer and vocabulary with options to remove special
        tokens and clean up tokenization spaces.

        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.

        Args:
            token_ids (`Union[int, List[int], np.ndarray, paddle.Tensor]`):
                List of tokenized input ids. Can be obtained using the `__call__` method.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
                Whether or not to clean up the tokenization spaces.
            truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
                A list of regular expression strings that will be used to truncate the returned string. This can be
                used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
                of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific decode method.

        Returns:
            `str`: The decoded sentence.
        """
        decoded_text = super()._decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
            decoded_text = self.truncate(decoded_text, truncate_before_pattern)

        return decoded_text
    def truncate(self, completion, truncate_before_pattern):
        # Return the start offset of the first match of `pattern` at or after
        # `start_pos`, or -1 if there is no match.
        def find_re(string, pattern, start_pos):
            m = pattern.search(string, start_pos)
            return m.start() if m else -1

        re = try_import("regex")

        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]

        # Keep only the text up to the second top-level `print` statement, if any.
        prints = list(re.finditer("^print", completion, re.MULTILINE))

        if len(prints) > 1:
            completion = completion[: prints[1].start()]

        # Likewise, keep only the first top-level function definition.
        defs = list(re.finditer("^def", completion, re.MULTILINE))

        if len(defs) > 1:
            completion = completion[: defs[1].start()]

        start_pos = 0

        # Cut at the earliest occurrence of any caller-supplied pattern.
        terminals_pos = [
            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
        ]

        if len(terminals_pos) > 0:
            return completion[: min(terminals_pos)]
        else:
            return completion
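
# The block below is not part of the original module; it is a minimal usage
# sketch showing how `decode` interacts with `truncate_before_pattern`. The
# checkpoint name "Salesforce/codegen-350M-mono" and the sample prompt are
# illustrative assumptions and may need to be swapped for a CodeGen checkpoint
# available in your environment.
if __name__ == "__main__":
    import re

    tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

    # Encode a snippet that ends with a blank-line gap followed by a comment.
    ids = tokenizer("def hello_world():\n    return 'hello'\n\n\n# trailing comment")["input_ids"]

    # Decoding with truncate_before_pattern cuts the text at the first match of
    # any pattern, so everything from the "\n\n\n" gap onward is dropped.
    text = tokenizer.decode(
        ids,
        skip_special_tokens=True,
        truncate_before_pattern=["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"],
    )
    print(text)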