# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Team Authors And The HuggingFace Inc. team and & DALL·E Mini team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
import math
import random
import re
from pathlib import Path

__all__ = ["DalleBartTokenizer"]

# Maximum input length (positional-embedding size) for each released checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "dalle-mini": 64,
    "dalle-mega-v16": 64,
    "dalle-mega-v26": 64,
    "dalle-mega": 64,
}

# based on wiki word occurrence
# (phrase, occurrence count) pairs used as sampling weights for `<person>` replacement
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
temp_token = "xtokx"  # avoid repeating chars

class HashtagProcessor:
    # Wikipedia word counts plus a Zipf-style cost heuristic drive the split.
    def __init__(self, wiki_word_frequency):
        """Build the word-cost table from a frequency file (one `word [count]` per line)."""
        words = (
            line.split()[0]
            for line in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        )
        # Zipf's law: cost grows with the word's frequency rank in the file.
        self._word_cost = {str(word): math.log(float(rank + 1)) for rank, word in enumerate(words)}
        self._max_word = max(len(word) for word in self._word_cost)
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        segmented = (self._split(chunk) for chunk in self._SPLIT_RE.split(s))
        return " ".join(word for piece in segmented for word in piece)

    def _split(self, s):
        # Cheapest segmentation of the first i characters, assuming cost[0..i-1]
        # are already filled in. Returns (match_cost, match_length).
        def best_match(i):
            window = enumerate(reversed(cost[max(0, i - self._max_word): i]))
            return min(
                (c + self._word_cost.get(s[i - k - 1: i].lower(), 9e999), k + 1)
                for k, c in window
            )

        # Forward pass: build the cost array.
        cost = [0]
        for i in range(1, len(s) + 1):
            match_cost, _ = best_match(i)
            cost.append(match_cost)

        # Backward pass: recover the minimal-cost segmentation.
        out = []
        i = len(s)
        while i > 0:
            match_cost, k = best_match(i)
            assert match_cost == cost[i]
            token = s[i - k: i]
            attached = False
            if token != "'" and out:  # a lone apostrophe is never merged backwards
                # re-attach a split "'s" and keep digit runs together
                if out[-1] == "'s" or (s[i - 1].isdigit() and out[-1][0].isdigit()):
                    out[-1] = token + out[-1]
                    attached = True
            if not attached:
                out.append(token)
            i -= k

        return reversed(out)

def replace_person_token(t):
    """Used for CC12M: replace `<person>` placeholders with sampled person phrases."""
    # A run of <person> tokens joined by commas/"and" collapses to " people ".
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        phrases, weights = tuple(zip(*person_token))
        sampled = random.choices(phrases, weights)[0]
        t = t.replace("<person>", f" {sampled} ", 1)
    return t

def fix_html(t):
    """Undo (possibly double) HTML entity escaping — same trick as OpenAI CLIP."""
    once = html.unescape(t)
    return html.unescape(once)

def replace_punctuation_with_commas(t):
    """Collapse common punctuation marks into commas."""
    punctuation = r"[()[\].,|:;?!=+~\-\/{}]"
    return re.sub(punctuation, ",", t)

def simplify_quotes(t):
    """Turn any single quote, double quote or backtick into a spaced double quote."""
    any_quote = r"['\"`]"
    return re.sub(any_quote, ' " ', t)

def merge_quotes(t):
    """Collapse a run of double quotes (with surrounding whitespace) into one ' " '."""
    quote_run = r'(\s*"+\s*)+'
    return re.sub(quote_run, ' " ', t)

def remove_comma_numbers(t):
    """Strip thousands separators, e.g. ``1,234,567`` -> ``1234567``."""
    # Two passes: one sub() pass cannot rewrite overlapping matches such as
    # the second comma in "1,234,567".
    for _ in range(2):
        t = re.sub(r"(\d),(\d{3})", r"\1\2", t)
    return t

def pre_process_dot_numbers(t):
    """Protect a dot between word characters (e.g. "1.5") behind a placeholder,
    so punctuation replacement leaves it alone; undone by `post_process_dot_numbers`."""
    placeholder = rf"\1{temp_token}dot{temp_token}\2"
    return re.sub(r"(\w)\.(\w)", placeholder, t)

def post_process_dot_numbers(t):
    """Restore the dots protected by `pre_process_dot_numbers`."""
    placeholder = f"{temp_token}dot{temp_token}"
    return re.sub(placeholder, ".", t)

def pre_process_quotes(t):
    """Protect apostrophes in common contractions ('s, 't, 'd, 'm, 'll, 're, 've)
    behind a placeholder so they survive quote simplification; undone by
    `post_process_quotes`.
    """
    # NOTE: the original pattern listed the (ll) alternative twice; the duplicate
    # was redundant and has been removed — matching behavior is unchanged.
    return re.sub(r"'(?=([stdm]|(ll)|(re)|(ve))\b)", rf"{temp_token}quote{temp_token}", t)

def post_process_quotes(t):
    """Restore the apostrophes protected by `pre_process_quotes`."""
    placeholder = f"{temp_token}quote{temp_token}"
    return re.sub(placeholder, "'", t)

def pre_process_dates(t):
    """Protect a slash between digits (e.g. a date "1/2") behind a placeholder;
    undone by `post_process_dates`."""
    placeholder = rf"\1{temp_token}slash{temp_token}\2"
    return re.sub(r"(\d)/(\d)", placeholder, t)

def post_process_dates(t):
    """Restore the slashes protected by `pre_process_dates`."""
    placeholder = f"{temp_token}slash{temp_token}"
    return re.sub(placeholder, "/", t)

def merge_commas(t):
    """Collapse runs of commas (with surrounding whitespace) into a single ', '."""
    # NOTE: an unreachable second `return re.sub(",", ", ", t)` statement that
    # followed this return (dead code) has been removed; behavior is unchanged.
    return re.sub(r"(\s*,+\s*)+", ", ", t)

def handle_special_chars(t):
    """Handle special characters: split hyphenated words, pad symbols with spaces."""
    # a dash directly between two word characters becomes a space
    dashes_to_spaces = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always surround these symbols with spaces
    return re.sub(r"([%&\/\$*])", r" \1 ", dashes_to_spaces)

def expand_hashtags(t, hashtag_processor):
    """Remove '#' and split the hashtag body into words via *hashtag_processor*."""
    def _expand(match):
        return hashtag_processor(match.group(1))

    return re.sub(r"#(\w+)", _expand, t)

# Characters that carry no caption information: underscore, hash, backslash.
_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    """Replace useless characters (see `_re_ignore_chars`) with spaces."""
    cleaned = re.sub(_re_ignore_chars, " ", t)
    return cleaned

def remove_extra_spaces(t):
    """Collapse any run of whitespace (including \t and \n) into a single space."""
    whitespace_run = r"\s+"
    return re.sub(whitespace_run, " ", t)

def remove_repeating_chars(t):
    """Collapse a non-digit character repeated 4+ times into a single instance.

    Three repeats are deliberately kept so roman numerals like 'VIII' survive.
    """
    return re.sub(r"(\D)(\1{3,})", r"\1", t)

def remove_urls(t):
    """Delete URLs: any 'http'-prefixed token up to the next whitespace."""
    url_pattern = r"http\S+"
    return re.sub(url_pattern, "", t)

def remove_html_tags(t):
    """Replace HTML tags such as ``<br/>`` with a single space."""
    tag_pattern = "<[^<]+?>"
    return re.sub(tag_pattern, " ", t)

def remove_first_last_commas(t):
    """Trim whitespace and strip at most one leading and one trailing comma."""
    trimmed = t.strip()
    if trimmed.endswith(","):
        trimmed = trimmed[:-1]
    if trimmed.startswith(","):
        trimmed = trimmed[1:]
    return trimmed.strip()

def remove_wiki_ref(t):
    """Drop a wikipedia-style reference marker like '[1]' at the start or end."""
    without_leading = re.sub(r"\A\s*\[\d+\]", "", t)
    return re.sub(r"\[\d+\]\s*\Z", "", without_leading)

class TextNormalizer:
    """Full caption-normalization pipeline built on the helper functions above."""

    def __init__(self, wiki_word_frequency_file):
        self._hashtag_processor = HashtagProcessor(wiki_word_frequency_file)
        # Optional third-party dependencies, loaded via the project's try_import helper.
        self.emoji = try_import("emoji")
        self.ftfy = try_import("ftfy")
        self.unidecode = try_import("unidecode")

    def __call__(self, t):
        # Character-level repairs first: mojibake, html entities, emojis, transliteration.
        text = self.ftfy.fix_text(t)
        text = fix_html(text)
        text = self.emoji.demojize(text)  # would otherwise be stripped by unidecode
        text = self.unidecode.unidecode(text)
        text = text.lower()
        # Then the ordered token-level pipeline; order matters (pre_* protect
        # spans that the punctuation pass would destroy, post_* restore them).
        steps = (
            replace_person_token,        # <person> placeholders (CC12M)
            remove_wiki_ref,             # [n] references (WIT)
            remove_html_tags,
            remove_urls,
            remove_comma_numbers,
            pre_process_dot_numbers,
            pre_process_quotes,
            pre_process_dates,
            handle_special_chars,
            lambda s: expand_hashtags(s, self._hashtag_processor),
            ignore_chars,
            simplify_quotes,
            replace_punctuation_with_commas,
            post_process_dot_numbers,
            post_process_quotes,
            post_process_dates,
            remove_repeating_chars,
            merge_quotes,
            merge_commas,
            remove_extra_spaces,
            remove_first_last_commas,
        )
        for step in steps:
            text = step(text)
        return f" {text}"

class DalleBartTokenizer(GPTTokenizer):
    r"""
    Construct a DalleBart tokenizer based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (str):
            Path to the vocabulary file.
            The vocab file contains a mapping from vocabulary strings to indices.
        merges_file (str):
            Path to the merge file.
            The merge file is used to split the input sentence into "subword" units.
            The vocab file is then used to encode those units as indices.
        wiki_word_frequency_file (str):
            Path to the wiki word-frequency file used when we need to normalize text.
        normalize_text (bool, optional):
            Whether to run the `TextNormalizer` pipeline on inputs before encoding.
            Defaults to `True`.
        errors (str):
            Error-handling scheme passed through to the byte-level BPE decoder.
            Defaults to `'replace'`.
        max_len (int, optional):
            The maximum value of the input sequence length.
            Defaults to `None`.
        bos_token (str, optional):
            The beginning of sequence token that was used during pretraining. Can be
            used a sequence classifier token.
            Defaults to `"<s>"`.
        eos_token (str, optional):
            A special token representing the end of a sequence that was used during pretraining.
            Defaults to `"</s>"`.
        cls_token (str, optional):
            A special token used for sequence classification. It is the last token
            of the sequence when built with special tokens.
            Defaults to `"<s>"`.
        sep_token (str, optional):
            A special token separating two different sentences in the same input.
            Defaults to `"</s>"`.
        unk_token (str, optional):
            A special token representing the *unknown (out-of-vocabulary)* token.
            An unknown token is set to be `unk_token` in order to be converted to an ID.
            Defaults to `"<unk>"`.
        pad_token (str, optional):
            A special token used to make arrays of tokens the same size for batching purposes.
            Defaults to `"<pad>"`.
        mask_token (str, optional):
            A special token representing a masked token.
            Defaults to `"<mask>"`.

    Examples:
        .. code-block::

            tokenizer = DalleBartTokenizer.from_pretrained('dalle-mini')
            print(tokenizer('Donald Trump in Animal Crossing'))

            # {'input_ids': [0, 7083, 3252, 91, 2203, 7807, 2]}

    """
    resource_files_names = {
        "vocab_file": "vocab.json",
        "merges_file": "merges.txt",
        "wiki_word_frequency_file": "enwiki-words-frequency.txt",
    }
    # Download URLs are filled in by the release tooling; empty in this view.
    pretrained_resource_files_map = {
        "vocab_file": {},
        "merges_file": {},
        "wiki_word_frequency_file": {},
    }
    pretrained_init_configuration = {
        "dalle-mini": {"normalize_text": True},
        "dalle-mega-v16": {"normalize_text": True},
        "dalle-mega-v26": {"normalize_text": True},
        "dalle-mega": {"normalize_text": True},
    }
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(
self,
vocab_file,
merges_file,
wiki_word_frequency_file,
normalize_text=True,
errors="replace",
max_len=None,
bos_token="<s>",
eos_token="</s>",
cls_token="<s>",
sep_token="</s>",
unk_token="<unk>",
**kwargs
):

bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token

# Mask token behave like a normal word, i.e. include the space before it

self._build_special_tokens_map_extended(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
unk_token=unk_token,
)
self.normalize_text = normalize_text
# in order to save wiki_word_frequency_file, we need set this attr
self._wiki_word_frequency_file = wiki_word_frequency_file
if self.normalize_text:
self.text_processor = TextNormalizer(wiki_word_frequency_file)
super().__init__(vocab_file, merges_file, errors, max_len, pad_token, eos_token, unk_token, **kwargs)

def _bpe_encode(self, text):
bpe_tokens = []
re = try_import("regex")
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens

[文档]    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification
"""
_cls = [self.cls_token_id]
_sep = [self.sep_token_id]
if token_ids_1 is None:
return _cls + token_ids_0 + _sep
return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep

"""
Retrieves sequence ids from a token list that has no special tokens added. This method is
called when adding special tokens using the tokenizer ``encode`` methods.
"""
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

[文档]    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]

if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

def __call__(
self,
text,
text_pair=None,
max_length=64,  # default
stride=0,
is_split_into_words=False,
truncation=True,  # default
return_position_ids=False,
return_length=False,
return_overflowing_tokens=False,
return_dict=True,
return_offsets_mapping=False,
return_tensors=None,
verbose: bool = True,
**kwargs
):
if self.normalize_text:
is_batched = isinstance(text, (list, tuple))
if is_batched:
text = [self.text_processor(t) for t in text]
if text_pair:
text_pair = [self.text_processor(t) for t in text_pair]
else:
text = self.text_processor(text)
if text_pair:
text_pair = self.text_processor(text_pair)

return super().__call__(
text,
text_pair,
max_length,
stride,
is_split_into_words,
truncation,
return_position_ids,
return_token_type_ids,