# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021-2022 The Fairseq Authors and The Google Flax
# Team Authors And The HuggingFace Inc. team and & DALL·E Mini team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import math
import random
import re
from pathlib import Path
from paddle.utils import try_import
from ...transformers import AddedToken, GPTTokenizer
__all__ = ["DalleBartTokenizer"]

# Maximum input length (the size of the text positional-embedding table) of each
# released checkpoint; every one of them uses 64 positions.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = dict.fromkeys(
    ("dalle-mini", "dalle-mega-v16", "dalle-mega-v26", "dalle-mega"),
    64,
)

# Candidate replacements for a lone CC12M "<person>" placeholder, each paired with a
# sampling weight (based on wiki word occurrence).
person_token = [
    ("a person", 282265),
    ("someone", 121194),
    ("somebody", 12219),
]

# Placeholder fragment used to shield characters during normalization. It has no
# repeating characters, so the "collapse repeated characters" step cannot mangle it.
temp_token = "xtokx"
class HashtagProcessor:
    """Split a run-together string (such as a hashtag body) into words.

    Adapted from the ``wordninja`` library. It uses a Wikipedia word-frequency list
    plus two heuristics: a split ``'s`` is re-attached, and digit runs are kept together.
    Each known word costs ``log(rank + 1)``, where ``rank`` is its 0-based position in
    the frequency file, so earlier entries are cheaper. Substrings that are not in the
    list cost infinity. A dynamic program picks the cheapest segmentation.
    """

    def __init__(self, wiki_word_frequency):
        """
        Args:
            wiki_word_frequency (str | os.PathLike):
                Path to a UTF-8 text file with one entry per line. Only the first
                whitespace-separated field of each line is used as the word.
                Blank or whitespace-only lines are skipped; they used to raise
                ``IndexError``. Presumably the file is ordered from most to least
                frequent (earlier lines get lower costs).
        """
        lines = Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        words = [line.split()[0] for line in lines if line.strip()]
        # Zipf-style cost: the rank-th word costs log(rank + 1).
        self._word_cost = {str(word): math.log(float(rank + 1)) for rank, word in enumerate(words)}
        # The longest known word bounds how far back the DP needs to look.
        self._max_word = max(len(word) for word in self._word_cost)
        # Chunks are separated by anything other than ASCII letters, digits and apostrophes.
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Infer word boundaries in ``s`` and return the words joined by single spaces.

        ``s`` is first cut on non-alphanumeric characters. Each chunk is then segmented
        with :meth:`_split`, which uses dynamic programming.
        """
        segmented = (self._split(chunk) for chunk in self._SPLIT_RE.split(s))
        return " ".join(word for words in segmented for word in words)

    def _split(self, s):
        """Return an iterator over the minimal-cost segmentation of ``s``, left to right."""

        def best_match(i):
            # Cheapest way to end a word at position i, given that cost[0:i] is filled in.
            # Returns a pair (total_cost, word_length).
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min((c + self._word_cost.get(s[i - k - 1 : i].lower(), math.inf), k + 1) for k, c in candidates)

        # Forward pass: cost[i] is the cheapest segmentation cost of s[:i].
        cost = [0]
        for i in range(1, len(s) + 1):
            c, _ = best_match(i)
            cost.append(c)

        # Backtrack from the end of the string to recover the chosen words.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            word = s[i - k : i]
            new_token = True
            # A lone apostrophe is never merged into its neighbour.
            if word != "'" and out:
                # Re-attach a split "'s", and glue a digit onto a following digit run.
                if out[-1] == "'s" or (s[i - 1].isdigit() and out[-1][0].isdigit()):
                    out[-1] = word + out[-1]
                    new_token = False
            if new_token:
                out.append(word)
            i -= k
        return reversed(out)
def replace_person_token(t):
    """Used for CC12M: rewrite ``<person>`` placeholders as natural words.

    A chain of two or more placeholders (optionally joined by commas, spaces or
    "and") becomes " people ". Each remaining single placeholder is replaced,
    left to right, by a weighted random choice from ``person_token``.
    """
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        words, weights = zip(*person_token)
        (picked,) = random.choices(words, weights)
        t = t.replace("<person>", " " + picked + " ", 1)
    return t
def fix_html(t):
    """Unescape HTML entities twice, so double-escaped text is fixed too (from OpenAI CLIP)."""
    for _ in range(2):
        t = html.unescape(t)
    return t
def replace_punctuation_with_commas(t):
    """Replace each bracket, dot, comma, pipe, colon, semicolon, ?, !, =, +, ~, -, / or brace with a comma."""
    return re.sub(pattern=r"[()[\].,|:;?!=+~\-\/{}]", repl=",", string=t)
def simplify_quotes(t):
    """Turn every single quote, double quote or backtick into ' " ' (a spaced double quote)."""
    quote_chars = """['"`]"""
    return re.sub(quote_chars, ' " ', t)
def merge_quotes(t):
    """Collapse any run of double quotes (and surrounding whitespace) into a single ' " '."""
    quote_run = re.compile(r'(\s*"+\s*)+')
    return quote_run.sub(' " ', t)
def remove_comma_numbers(t):
    """Drop thousands separators, e.g. "1,000,000" -> "1000000".

    One non-overlapping ``re.sub`` pass only removes every other separator of a long
    number, because the digit before the next comma is already consumed. The
    substitution is therefore applied twice.
    """
    for _ in range(2):
        t = re.sub(r"(\d),(\d{3})", r"\1\2", t)
    return t
def pre_process_dot_numbers(t):
    """Shield dots between two word characters (e.g. "3.5") behind a placeholder.

    ``post_process_dot_numbers`` restores them after punctuation has been turned
    into commas. Matches are non-overlapping, so in "1.2.3" only the first dot is
    shielded (the "2" is consumed by the first match).
    """
    placeholder = temp_token + "dot" + temp_token
    return re.sub(r"(\w)\.(\w)", r"\1" + placeholder + r"\2", t)
def post_process_dot_numbers(t):
    """Turn the dot placeholder inserted by ``pre_process_dot_numbers`` back into "."."""
    placeholder = temp_token + "dot" + temp_token
    return re.sub(placeholder, ".", t)
def pre_process_quotes(t):
    """Shield the apostrophes of common English contractions behind a placeholder.

    An apostrophe is replaced by ``{temp_token}quote{temp_token}`` only when it is
    directly followed by ``s``, ``t``, ``d``, ``m``, ``ll``, ``re`` or ``ve`` and then
    a word boundary (e.g. "it's", "don't", "we'll"). ``post_process_quotes``
    restores it. The lookahead previously listed ``ll`` twice and used unneeded
    capturing groups; the set of matched apostrophes is unchanged.
    """
    return re.sub(r"'(?=(?:[stdm]|ll|re|ve)\b)", rf"{temp_token}quote{temp_token}", t)
def post_process_quotes(t):
    """Turn the apostrophe placeholder inserted by ``pre_process_quotes`` back into "'"."""
    placeholder = temp_token + "quote" + temp_token
    return re.sub(placeholder, "'", t)
def pre_process_dates(t):
    """Shield slashes between two digits (e.g. "12/05") behind a placeholder.

    ``post_process_dates`` restores them. Matches are non-overlapping, so in
    "1/2/2020" only the first slash is shielded; "12/05/2020" is fully shielded.
    """
    placeholder = temp_token + "slash" + temp_token
    return re.sub(r"(\d)/(\d)", r"\1" + placeholder + r"\2", t)
def post_process_dates(t):
    """Turn the slash placeholder inserted by ``pre_process_dates`` back into "/"."""
    placeholder = temp_token + "slash" + temp_token
    return re.sub(placeholder, "/", t)
def merge_commas(t):
    """Collapse any run of commas (and surrounding whitespace) into a single ", "."""
    return re.sub(pattern=r"(\s*,+\s*)+", repl=", ", string=t)
def add_space_after_commas(t):
    """Insert a space after every comma."""
    return t.replace(",", ", ")
def handle_special_chars(t):
    """Handle special characters.

    A hyphen squeezed between two word characters becomes a space ("well-known" ->
    "well known"). Matches are non-overlapping, so "x-y-z" -> "x y-z". Each of
    ``% & / $ *`` is then padded with spaces so it stands alone.
    """
    dehyphenated = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    return re.sub(r"([%&\/$*])", r" \1 ", dehyphenated)
def expand_hashtags(t, hashtag_processor):
    """Replace each ``#word`` with ``hashtag_processor(word)``, dropping the ``#``."""

    def _expand(match):
        return hashtag_processor(match.group(1))

    return re.sub(r"#(\w+)", _expand, t)
# Characters with no value in captions: underscore, hash and backslash.
_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    """Replace every ignorable character (underscore, ``#``, backslash) with a space."""
    return re.compile(_re_ignore_chars).sub(" ", t)
def remove_extra_spaces(t):
    r"""Collapse every whitespace run (spaces, ``\t``, ``\n``, ...) into one space; ends are kept."""
    return re.compile(r"\s+").sub(" ", t)
def remove_repeating_chars(t):
    """Collapse a non-digit character repeated 4+ times into a single instance.

    The threshold is 4, not 3, so roman numerals such as "VIII" survive.
    """
    repeated = re.compile(r"(\D)(\1{3,})")
    return repeated.sub(lambda m: m.group(1), t)
def remove_urls(t):
    """Delete every token that starts with "http" (up to the next whitespace)."""
    url = re.compile(r"http\S+")
    return url.sub("", t)
def remove_html_tags(t):
    """Replace each ``<...>`` tag with a single space."""
    return re.sub(pattern="<[^<]+?>", repl=" ", string=t)
def remove_first_last_commas(t):
    """Strip whitespace, drop at most one trailing and one leading comma, then strip again."""
    t = t.strip()
    if t.endswith(","):
        t = t[:-1]
    if t.startswith(","):
        t = t[1:]
    return t.strip()
def remove_wiki_ref(t):
    """Remove a Wikipedia-style ``[n]`` reference at the very start and/or very end of ``t``."""
    without_leading = re.sub(r"\A\s*\[\d+\]", "", t)
    return re.sub(r"\[\d+\]\s*\Z", "", without_leading)
class TextNormalizer:
    """Caption cleaner applied before DalleBart tokenization.

    Normalization is a fixed, ordered chain of string transforms. Order matters:
    for example, dots, apostrophes and date slashes are shielded with placeholders
    before punctuation is turned into commas, and restored afterwards.
    """

    def __init__(self, wiki_word_frequency_file):
        """
        Args:
            wiki_word_frequency_file (str): word-frequency list handed to
                ``HashtagProcessor`` for splitting hashtags into words.
        """
        self._hashtag_processor = HashtagProcessor(wiki_word_frequency_file)
        # Third-party helpers, loaded through paddle's `try_import`.
        self.emoji = try_import("emoji")
        self.ftfy = try_import("ftfy")
        self.unidecode = try_import("unidecode")

    def __call__(self, t):
        """Normalize one caption string and return it prefixed with a single space."""
        steps = (
            self.ftfy.fix_text,  # fix some characters
            fix_html,
            self.emoji.demojize,  # decode emojis (would be removed by unidecode)
            self.unidecode.unidecode,  # decode and simplify text
            str.lower,
            replace_person_token,  # <PERSON> placeholders (CC12M)
            remove_wiki_ref,  # wiki references (WIT)
            remove_html_tags,
            remove_urls,
            remove_comma_numbers,
            # shield dots in numbers, contraction quotes and date slashes - part 1
            pre_process_dot_numbers,
            pre_process_quotes,
            pre_process_dates,
            handle_special_chars,
            lambda s: expand_hashtags(s, self._hashtag_processor),
            ignore_chars,
            simplify_quotes,
            replace_punctuation_with_commas,  # all punctuation becomes commas
            # restore the shielded characters - part 2
            post_process_dot_numbers,
            post_process_quotes,
            post_process_dates,
            remove_repeating_chars,
            merge_quotes,
            merge_commas,
            remove_extra_spaces,
            remove_first_last_commas,
        )
        for step in steps:
            t = step(t)
        # always start with a space
        return " " + t
class DalleBartTokenizer(GPTTokenizer):
    r"""
    Construct a DalleBart tokenizer based on byte-level Byte-Pair-Encoding.

    This tokenizer inherits from :class:`~paddlenlp.transformers.gpt.tokenizer.GPTTokenizer`.
    For more information regarding those methods, please refer to this superclass.

    Args:
        vocab_file (str):
            Path to the vocabulary file.
            The vocab file contains a mapping from vocabulary strings to indices.
        merges_file (str):
            Path to the merge file.
            The merge file is used to split the input sentence into "subword" units.
            The vocab file is then used to encode those units as indices.
        wiki_word_frequency_file (str):
            Path to the word-frequency file used to split hashtags into words
            when text normalization is enabled.
        normalize_text (bool, optional):
            Whether to clean every input string with the DALL·E mini text
            normalizer before tokenizing it.
            Defaults to `True`.
        errors (str):
            Paradigm to follow when decoding bytes to UTF-8.
            Defaults to `'replace'`.
        max_len (int, optional):
            The maximum value of the input sequence length.
            Defaults to `None`.
        bos_token (str, optional):
            The beginning of sequence token that was used during pretraining. Can be
            used as a sequence classifier token.
            Defaults to `"<s>"`.
        eos_token (str, optional):
            A special token representing the end of a sequence that was used during pretraining.
            Defaults to `"</s>"`.
        cls_token (str, optional):
            A special token used for sequence classification. It is the first token
            of the sequence when built with special tokens.
            Defaults to `"<s>"`.
        sep_token (str, optional):
            A special token separating two different sentences in the same input.
            Defaults to `"</s>"`.
        unk_token (str, optional):
            A special token representing the *unknown (out-of-vocabulary)* token.
            An unknown token is set to be `unk_token` in order to be converted to an ID.
            Defaults to `"<unk>"`.
        pad_token (str, optional):
            A special token used to make arrays of tokens the same size for batching purposes.
            Defaults to `"<pad>"`.
        mask_token (str, optional):
            A special token representing a masked token. This is the token used
            in the masked language modeling task which the model tries to predict the original unmasked ones.
            Defaults to `"<mask>"`.

    Examples:
        .. code-block::

            from paddlenlp.transformers import DalleBartTokenizer

            tokenizer = DalleBartTokenizer.from_pretrained('dalle-mini')
            print(tokenizer('Donald Trump in Animal Crossing'))

            # {'input_ids': [0, 7083, 3252, 91, 2203, 7807, 2]}
    """

    resource_files_names = {
        "vocab_file": "vocab.json",
        "merges_file": "merges.txt",
        "wiki_word_frequency_file": "enwiki-words-frequency.txt",
    }
    pretrained_resource_files_map = {
        "vocab_file": {
            "dalle-mini": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mini/vocab.json",
            "dalle-mega-v16": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v16/vocab.json",
            "dalle-mega-v26": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/vocab.json",
            "dalle-mega": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/vocab.json",
        },
        "merges_file": {
            "dalle-mini": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mini/merges.txt",
            "dalle-mega-v16": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v16/merges.txt",
            "dalle-mega-v26": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/merges.txt",
            "dalle-mega": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/merges.txt",
        },
        "wiki_word_frequency_file": {
            "dalle-mini": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mini/enwiki-words-frequency.txt",
            "dalle-mega-v16": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v16/enwiki-words-frequency.txt",
            "dalle-mega-v26": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/enwiki-words-frequency.txt",
            "dalle-mega": "https://bj.bcebos.com/paddlenlp/models/transformers/dallebart/dalle-mega-v26/enwiki-words-frequency.txt",
        },
    }
    pretrained_init_configuration = {
        "dalle-mini": {"normalize_text": True},
        "dalle-mega-v16": {"normalize_text": True},
        "dalle-mega-v26": {"normalize_text": True},
        "dalle-mega": {"normalize_text": True},
    }
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        merges_file,
        wiki_word_frequency_file,
        normalize_text=True,
        errors="replace",
        max_len=None,
        bos_token="<s>",
        eos_token="</s>",
        cls_token="<s>",
        sep_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs
    ):
        # Plain strings are wrapped as AddedToken that keep surrounding whitespace.
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self._build_special_tokens_map_extended(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
        )
        self.normalize_text = normalize_text
        # in order to save wiki_word_frequency_file, we need set this attr
        self._wiki_word_frequency_file = wiki_word_frequency_file
        if self.normalize_text:
            self.text_processor = TextNormalizer(wiki_word_frequency_file)
        super().__init__(vocab_file, merges_file, errors, max_len, pad_token, eos_token, unk_token, **kwargs)

    def _bpe_encode(self, text):
        """Pre-tokenize ``text`` with ``self.pat`` and return the BPE sub-tokens of every piece."""
        # Bound to `regex` rather than `re` so the module-level stdlib `re` is not shadowed.
        # NOTE(review): the third-party `regex` module is presumably required because
        # `self.pat` (set by GPTTokenizer) uses regex-only syntax - confirm.
        regex = try_import("regex")
        bpe_tokens = []
        for token in regex.findall(self.pat, text):
            # Map every UTF-8 byte to its printable byte-level BPE symbol.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.bpe(token).split(" "))
        return bpe_tokens

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Return a list marking special-token positions with 1 and sequence tokens with 0.

        Without special tokens already present, the layout is ``cls A sep`` for one
        sequence and ``cls A sep sep B sep`` for a pair.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Return all-zero token type ids for the sequence(s) once special tokens are added.

        The length matches ``cls A sep`` (single) or ``cls A sep sep B sep`` (pair).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def __call__(
        self,
        text,
        text_pair=None,
        max_length=64,  # default
        stride=0,
        is_split_into_words=False,
        padding="max_length",  # default
        truncation=True,  # default
        return_position_ids=False,
        return_token_type_ids=False,  # don't return token_type_ids
        return_attention_mask=True,  # default
        return_length=False,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
        return_dict=True,
        return_offsets_mapping=False,
        add_special_tokens=True,
        pad_to_multiple_of=None,
        return_tensors=None,
        verbose: bool = True,
        **kwargs
    ):
        """Tokenize ``text`` (and optionally ``text_pair``) with DalleBart defaults.

        Unlike the base tokenizer's defaults, this pads to ``max_length=64``, truncates,
        returns an attention mask and omits token type ids. When ``normalize_text`` is
        enabled, every input string is first passed through ``self.text_processor``.
        A ``list``/``tuple`` ``text`` is treated as a batch and normalized element-wise.
        """
        if self.normalize_text:
            is_batched = isinstance(text, (list, tuple))
            if is_batched:
                text = [self.text_processor(t) for t in text]
                if text_pair:
                    text_pair = [self.text_processor(t) for t in text_pair]
            else:
                text = self.text_processor(text)
                if text_pair:
                    text_pair = self.text_processor(text_pair)

        # Forward by keyword so a change in the parent's positional order cannot
        # silently route a value into the wrong parameter.
        return super().__call__(
            text=text,
            text_pair=text_pair,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            padding=padding,
            truncation=truncation,
            return_position_ids=return_position_ids,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_length=return_length,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_dict=return_dict,
            return_offsets_mapping=return_offsets_mapping,
            add_special_tokens=add_special_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            verbose=verbose,
            **kwargs,
        )