# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils import try_import

from .. import GPTTokenizer

__all__ = ["CodeGenTokenizer"]

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}


class CodeGenTokenizer(GPTTokenizer):
    """
    Constructs a CodeGen tokenizer based on byte-level Byte-Pair-Encoding, reusing the
    vocabulary handling and BPE merges of `GPTTokenizer`.
    """

    resource_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    pretrained_resource_files_map = {"vocab_file": {}, "merges_file": {}}
    pretrained_init_configuration = {}

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        max_len=None,
        pad_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        unk_token="<|endoftext|>",
        eol_token="\u010a",
        **kwargs
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            max_len=max_len,
            pad_token=pad_token,
            eos_token=eos_token,
            unk_token=unk_token,
            eol_token=eol_token,
            **kwargs,
        )

    def decode(
        self,
        token_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
        truncate_before_pattern=None,
        **kwargs
    ):
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
tokens and clean up tokenization spaces.
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
Args:
token_ids (`Union[int, List[int], np.ndarray, paddle.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the tokenization spaces.
truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
A list of regular expression strings that will be used to truncate the returned string. This can be
used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns:
`str`: The decoded sentence.
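
        Example (an illustrative sketch rather than part of the original docstring; the
        checkpoint name and the sample prompt below are assumptions):

            .. code-block:: python

                import re

                from paddlenlp.transformers import CodeGenTokenizer

                # Load a CodeGen checkpoint (the name is assumed to be available).
                tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

                ids = tokenizer("def hello_world():")["input_ids"]

                # Decode back to text, cutting the output before any comment line,
                # end-of-text marker, docstring delimiter, or triple blank line.
                text = tokenizer.decode(
                    ids,
                    skip_special_tokens=True,
                    truncate_before_pattern=["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"],
                )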
"""
decoded_text = super()._decode(
token_ids=token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
decoded_text = self.truncate(decoded_text, truncate_before_pattern)
return decoded_text

    def truncate(self, completion, truncate_before_pattern):
        """
        Truncates `completion` at the earliest match of any regular expression in
        `truncate_before_pattern`, after first cutting the text at a second top-level
        `print` or `def` statement if one is present.
        """

        def find_re(string, pattern, start_pos):
            m = pattern.search(string, start_pos)
            return m.start() if m else -1

        re = try_import("regex")

        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]

        prints = list(re.finditer("^print", completion, re.MULTILINE))

        if len(prints) > 1:
            completion = completion[: prints[1].start()]

        defs = list(re.finditer("^def", completion, re.MULTILINE))

        if len(defs) > 1:
            completion = completion[: defs[1].start()]

        start_pos = 0

        terminals_pos = [
            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
        ]

        if len(terminals_pos) > 0:
            return completion[: min(terminals_pos)]
        else:
            return completion
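

# A minimal sketch of how `truncate` behaves (added for illustration; the sample
# vocab/merges paths, completion string, and pattern list are assumptions, not part
# of the original module):
#
#     tokenizer = CodeGenTokenizer(vocab_file="vocab.json", merges_file="merges.txt")
#     completion = "def add(a, b):\n    return a + b\n\n# test\nprint(add(1, 2))\n"
#     tokenizer.truncate(completion, ["^#"])
#     # -> "def add(a, b):\n    return a + b\n\n"
#     # Everything from the first line that starts with "#" onward is dropped.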