diffsynth-engine 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +28 -0
- diffsynth_engine/algorithm/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
- diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
- diffsynth_engine/algorithm/sampler/__init__.py +19 -0
- diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
- diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
- diffsynth_engine/conf/models/components/vae.json +254 -0
- diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
- diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
- diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
- diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
- diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
- diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
- diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
- diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
- diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
- diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
- diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
- diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
- diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
- diffsynth_engine/kernels/__init__.py +0 -0
- diffsynth_engine/models/__init__.py +7 -0
- diffsynth_engine/models/base.py +64 -0
- diffsynth_engine/models/basic/__init__.py +0 -0
- diffsynth_engine/models/basic/attention.py +217 -0
- diffsynth_engine/models/basic/lora.py +293 -0
- diffsynth_engine/models/basic/relative_position_emb.py +56 -0
- diffsynth_engine/models/basic/timestep.py +81 -0
- diffsynth_engine/models/basic/transformer_helper.py +88 -0
- diffsynth_engine/models/basic/unet_helper.py +244 -0
- diffsynth_engine/models/components/__init__.py +0 -0
- diffsynth_engine/models/components/clip.py +56 -0
- diffsynth_engine/models/components/t5.py +222 -0
- diffsynth_engine/models/components/vae.py +392 -0
- diffsynth_engine/models/flux/__init__.py +14 -0
- diffsynth_engine/models/flux/flux_dit.py +476 -0
- diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
- diffsynth_engine/models/flux/flux_vae.py +78 -0
- diffsynth_engine/models/sd/__init__.py +12 -0
- diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
- diffsynth_engine/models/sd/sd_unet.py +293 -0
- diffsynth_engine/models/sd/sd_vae.py +38 -0
- diffsynth_engine/models/sd3/__init__.py +14 -0
- diffsynth_engine/models/sd3/sd3_dit.py +302 -0
- diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
- diffsynth_engine/models/sd3/sd3_vae.py +43 -0
- diffsynth_engine/models/sdxl/__init__.py +13 -0
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
- diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
- diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
- diffsynth_engine/models/utils.py +54 -0
- diffsynth_engine/models/wan/__init__.py +0 -0
- diffsynth_engine/models/wan/wan_dit.py +497 -0
- diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
- diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
- diffsynth_engine/models/wan/wan_vae.py +771 -0
- diffsynth_engine/pipelines/__init__.py +18 -0
- diffsynth_engine/pipelines/base.py +253 -0
- diffsynth_engine/pipelines/flux_image.py +512 -0
- diffsynth_engine/pipelines/sd_image.py +352 -0
- diffsynth_engine/pipelines/sdxl_image.py +395 -0
- diffsynth_engine/pipelines/wan_video.py +524 -0
- diffsynth_engine/tokenizers/__init__.py +6 -0
- diffsynth_engine/tokenizers/base.py +157 -0
- diffsynth_engine/tokenizers/clip.py +288 -0
- diffsynth_engine/tokenizers/t5.py +194 -0
- diffsynth_engine/tokenizers/wan.py +74 -0
- diffsynth_engine/utils/__init__.py +0 -0
- diffsynth_engine/utils/constants.py +34 -0
- diffsynth_engine/utils/download.py +135 -0
- diffsynth_engine/utils/env.py +7 -0
- diffsynth_engine/utils/flag.py +46 -0
- diffsynth_engine/utils/fp8_linear.py +64 -0
- diffsynth_engine/utils/gguf.py +415 -0
- diffsynth_engine/utils/loader.py +17 -0
- diffsynth_engine/utils/lock.py +56 -0
- diffsynth_engine/utils/logging.py +12 -0
- diffsynth_engine/utils/offload.py +44 -0
- diffsynth_engine/utils/parallel.py +390 -0
- diffsynth_engine/utils/prompt.py +9 -0
- diffsynth_engine/utils/video.py +40 -0
- diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
- diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
- diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
- diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
- diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import ftfy
|
|
4
|
+
import regex as re
|
|
5
|
+
import torch
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from typing import Dict, List, Union, Optional
|
|
8
|
+
|
|
9
|
+
from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# File names, relative to a pretrained tokenizer directory, of the CLIP BPE assets.
VOCAB_FILES_NAMES = dict(
    vocab_file="vocab.json",
    merges_file="merges.txt",
)

# Sequence length CLIPTokenizer falls back to when no `model_max_length` is set.
CLIP_DEFAULT_MAX_LENGTH = 77
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by byte-level BPE.

    Bytes that are already printable, non-whitespace characters ("!".."~", "¡".."¬",
    "®".."ÿ") map to themselves. Every other byte value is assigned, in ascending
    order, the next unused code point starting at 256, so that no byte is ever
    represented by a whitespace or control character the BPE code cannot handle.

    Returns:
        `dict` mapping each of the 256 byte values (`int`) to a one-character `str`.
        The result is memoised by `lru_cache`, so every caller receives the same dict
        object.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}

    shift = 0
    for byte in range(256):
        if byte in table:
            continue
        # Non-printable byte: move it above the single-byte range.
        table[byte] = chr(256 + shift)
        shift += 1
    return table
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).

    Args:
        word: Sequence of symbols, e.g. a tuple of strings.

    Returns:
        `set` of `(left, right)` tuples, one for every pair of adjacent symbols. The
        set is empty when the word has fewer than two symbols, including the empty
        word.
    """
    # zip() stops at the shorter operand, so empty and single-symbol words simply
    # produce no pairs instead of failing on an index lookup.
    return set(zip(word, word[1:]))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def whitespace_clean(text):
    """Collapse each run of whitespace in `text` to one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Modified from transformers.models.clip.tokenization_clip.CLIPTokenizer and open_clip.tokenizer.SimpleTokenizer
|
|
64
|
+
class CLIPTokenizer(BaseTokenizer):
    """
    Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of different lengths.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        merges_file: str,
        bos_token: Optional[str] = "<|startoftext|>",
        eos_token: Optional[str] = "<|endoftext|>",
        unk_token: Optional[str] = "<|endoftext|>",
        pad_token: Optional[str] = "<|endoftext|>",  # hack to enable padding
        **kwargs,
    ):
        """Load the vocabulary and merge table; extra `kwargs` are forwarded to the base class."""
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

        # token string -> id, and the inverse id -> token string
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        # byte value <-> printable unicode character
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            # Skip the first line of the merges file and keep at most
            # 49152 - 256 - 2 merge rules.
            bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        # merge pair -> rank; a lower rank is merged earlier
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # The two special tokens are pre-seeded so `bpe` returns them unsplit.
        self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}

        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

        # NOTE(review): `model_max_length` is presumably set by BaseTokenizer from kwargs - confirm.
        self.model_max_length = self.model_max_length if self.model_max_length else CLIP_DEFAULT_MAX_LENGTH

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], **kwargs):
        """
        Build a tokenizer from a directory holding the tokenizer config, vocabulary and merges files.

        Args:
            pretrained_model_path (`str` or `os.PathLike`):
                Directory containing `TOKENIZER_CONFIG_FILE` and the files named in `vocab_files_names`.
            **kwargs:
                Overrides applied on top of the values read from the tokenizer config.
        """
        tokenizer_config_file = os.path.join(pretrained_model_path, TOKENIZER_CONFIG_FILE)
        with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
            init_kwargs = json.load(tokenizer_config_handle)
        init_kwargs.update(**kwargs)
        vocab_file = os.path.join(pretrained_model_path, cls.vocab_files_names["vocab_file"])
        merges_file = os.path.join(pretrained_model_path, cls.vocab_files_names["merges_file"])
        return cls(vocab_file=vocab_file, merges_file=merges_file, **init_kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.encoder)

    def get_vocab(self):
        """Return the token -> id mapping (the live dict, not a copy)."""
        return self.encoder

    def bpe(self, token):
        """
        Apply the merge table to one byte-encoded token.

        Returns the resulting symbols joined by single spaces; the last symbol carries the `</w>` end-of-word
        marker. Results of multi-symbol tokens are memoised in `self.cache`.
        """
        if token in self.cache:
            return self.cache[token]
        # Start from single characters, marking the final one as end-of-word.
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
        """Convert string to tokens; a list of strings yields one token list per string."""
        if isinstance(texts, str):
            return self._tokenize(texts)

        return [self._tokenize(text) for text in texts]

    def _tokenize(self, text: str) -> List[str]:
        """Clean, lower-case and split `text`, then BPE-encode every piece."""
        bpe_tokens = []
        text = whitespace_clean(ftfy.fix_text(text)).lower()

        for token in self.pat.findall(text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(self.bpe(token).split(" "))
        return bpe_tokens

    def encode(self, texts: str) -> List[int]:
        """Tokenize `texts` and map the tokens to ids; no BOS/EOS tokens are added."""
        tokens = self.tokenize(texts)
        return self.convert_tokens_to_ids(tokens)

    def batch_encode(self, texts: List[str]) -> List[List[int]]:
        """Apply `encode` to every string in `texts`."""
        return [self.encode(text) for text in texts]

    def decode(
        self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
    ) -> str:
        """
        Convert ids back to text.

        `clean_up_tokenization_spaces=None` defers to the instance attribute of the same name.
        """
        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens)
        text = self.convert_tokens_to_string(tokens)

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            text = self.clean_up_tokenization(text)
        return text

    def batch_decode(
        self,
        ids: List[List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
    ) -> List[str]:
        """Apply `decode` to every id sequence in `ids`."""
        return [self.decode(index, skip_special_tokens, clean_up_tokenization_spaces) for index in ids]

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """Map a token or list of tokens to ids; unknown tokens map to the id of `unk_token`."""
        unk_id = self.encoder.get(self.unk_token)
        if isinstance(tokens, str):
            return self.encoder.get(tokens, unk_id)

        return [self.encoder.get(token, unk_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Map an id or list of ids to tokens.

        Ids missing from the vocabulary map to `None`. `skip_special_tokens` only applies to list input and
        drops ids found in `self.all_special_ids`.
        """
        if isinstance(ids, int):
            return self.decoder.get(ids)

        tokens = []
        for index in ids:
            if skip_special_tokens and index in self.all_special_ids:
                continue
            tokens.append(self.decoder.get(index))
        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens, undo the byte -> unicode mapping and turn `</w>` markers into spaces."""
        text = "".join(tokens)
        byte_array = bytearray([self.byte_decoder[c] for c in text])
        text = byte_array.decode("utf-8", errors="replace").replace("</w>", " ").strip()
        return text

    def __call__(
        self,
        texts: Union[str, List[str]],
        max_length: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        """
        Tokenize text and prepare for model inputs.

        Args:
            texts (`str`, `List[str]`):
                The sequence or batch of sequences to be encoded.

            max_length (`int`, *optional*):
                Each encoded sequence will be truncated or padded to max_length. Defaults to
                `self.model_max_length`.

        Returns:
            `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
        """

        if isinstance(texts, str):
            texts = [texts]

        max_length = max_length if max_length else self.model_max_length

        # Start fully padded / fully masked, then fill in each row's real tokens.
        encoded = torch.full((len(texts), max_length), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros(len(texts), max_length, dtype=torch.long)

        for i, text in enumerate(texts):
            tokens = self.tokenize(text)
            ids = [self.bos_token_id] + self.convert_tokens_to_ids(tokens) + [self.eos_token_id]
            if len(ids) > max_length:
                # Truncate but keep the sequence terminated by EOS.
                ids = ids[:max_length]
                ids[-1] = self.eos_token_id
            encoded[i, : len(ids)] = torch.tensor(ids)
            attention_mask[i, : len(ids)] = 1

        return {"input_ids": encoded, "attention_mask": attention_mask}
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Dict, List, Union, Optional
|
|
5
|
+
from tokenizers import Tokenizer as TokenizerFast
|
|
6
|
+
|
|
7
|
+
from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# File names, relative to a pretrained tokenizer directory, of the T5 tokenizer assets.
VOCAB_FILES_NAMES = dict(
    vocab_file="spiece.model",
    tokenizer_file="tokenizer.json",
)

# Sequence length T5TokenizerFast falls back to when no `model_max_length` is set.
T5_DEFAULT_MAX_LENGTH = 512
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class T5TokenizerFast(BaseTokenizer):
    """
    Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer. Accepted for interface compatibility;
            this class loads everything from `tokenizer_file`.
        tokenizer_file (`str`):
            Precompiled file for initializing a fast tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        **kwargs,
    ):
        """Load the backend tokenizer from `tokenizer_file`; extra `kwargs` go to the base class."""
        super().__init__(
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

        fast_tokenizer = TokenizerFast.from_file(tokenizer_file)
        self._tokenizer = fast_tokenizer
        # disable truncation and padding; `__call__` pads and truncates itself
        self._tokenizer.no_truncation()
        self._tokenizer.no_padding()

        # NOTE(review): `model_max_length` is presumably set by BaseTokenizer from kwargs - confirm.
        self.model_max_length = self.model_max_length if self.model_max_length else T5_DEFAULT_MAX_LENGTH

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], **kwargs):
        """
        Build a tokenizer from a directory holding the tokenizer config and tokenizer files.

        Args:
            pretrained_model_path (`str` or `os.PathLike`):
                Directory containing `TOKENIZER_CONFIG_FILE` and the files named in `vocab_files_names`.
            **kwargs:
                Overrides applied on top of the values read from the tokenizer config.
        """
        tokenizer_config_file = os.path.join(pretrained_model_path, TOKENIZER_CONFIG_FILE)
        with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
            init_kwargs = json.load(tokenizer_config_handle)
        init_kwargs.update(**kwargs)
        vocab_file = os.path.join(pretrained_model_path, cls.vocab_files_names["vocab_file"])
        tokenizer_file = os.path.join(pretrained_model_path, cls.vocab_files_names["tokenizer_file"])
        return cls(vocab_file=vocab_file, tokenizer_file=tokenizer_file, **init_kwargs)

    @property
    def vocab_size(self):
        """Size of the backend vocabulary, added tokens included."""
        return self._tokenizer.get_vocab_size(with_added_tokens=True)

    def get_vocab(self):
        """Return the backend token -> id mapping, added tokens included."""
        return self._tokenizer.get_vocab(with_added_tokens=True)

    def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
        """Convert a string to tokens; a list of strings yields one token list per string."""
        if isinstance(texts, str):
            encoding = self._tokenizer.encode(texts)
            return encoding.tokens

        encodings = self._tokenizer.encode_batch(texts)
        return [encoding.tokens for encoding in encodings]

    def encode(self, texts: str) -> List[int]:
        """Encode one string to ids, with the backend's special tokens added."""
        encoding = self._tokenizer.encode(texts, add_special_tokens=True)
        return encoding.ids

    def batch_encode(self, texts: List[str]) -> List[List[int]]:
        """Encode a batch of strings to ids, with the backend's special tokens added."""
        encodings = self._tokenizer.encode_batch(texts, add_special_tokens=True)
        return [encoding.ids for encoding in encodings]

    def decode(
        self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
    ) -> str:
        """
        Convert ids back to text.

        `clean_up_tokenization_spaces=None` defers to the instance attribute of the same name.
        """
        text = self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            text = self.clean_up_tokenization(text)
        return text

    def batch_decode(
        self,
        ids: List[List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
    ) -> List[str]:
        """Batch variant of `decode`."""
        texts = self._tokenizer.decode_batch(ids, skip_special_tokens=skip_special_tokens)

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            texts = [self.clean_up_tokenization(text) for text in texts]
        return texts

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """Map a token or list of tokens to ids; unknown tokens map to the id of `unk_token`."""
        if isinstance(tokens, str):
            index = self._tokenizer.token_to_id(tokens)
            return index if index is not None else self._tokenizer.token_to_id(self.unk_token)

        ids = [self._tokenizer.token_to_id(token) for token in tokens]
        return [index if index is not None else self._tokenizer.token_to_id(self.unk_token) for index in ids]

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Map an id or list of ids to tokens.

        `skip_special_tokens` only applies to list input and drops ids found in `self.all_special_ids`.
        """
        if isinstance(ids, int):
            return self._tokenizer.id_to_token(ids)

        tokens = []
        for index in ids:
            if skip_special_tokens and index in self.all_special_ids:
                continue
            tokens.append(self._tokenizer.id_to_token(index))
        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Join token strings back into text.

        `Tokenizer.decode` expects ids, so token strings are passed to the backend's decoder component
        instead; when the backend has no decoder the tokens are joined with spaces.
        """
        decoder = self._tokenizer.decoder
        if decoder is None:
            return " ".join(tokens)
        return decoder.decode(tokens)

    def __call__(
        self,
        texts: Union[str, List[str]],
        max_length: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        """
        Tokenize text and prepare for model inputs.

        Args:
            texts (`str`, `List[str]`):
                The sequence or batch of sequences to be encoded.

            max_length (`int`, *optional*):
                Each encoded sequence will be truncated or padded to max_length. Defaults to
                `self.model_max_length`.

        Returns:
            `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
        """

        if isinstance(texts, str):
            texts = [texts]

        max_length = max_length if max_length else self.model_max_length

        # Start fully padded / fully masked, then fill in each row's real tokens.
        encoded = torch.full((len(texts), max_length), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros(len(texts), max_length, dtype=torch.long)

        batch_ids = self.batch_encode(texts)
        for i, ids in enumerate(batch_ids):
            if len(ids) > max_length:
                # Truncate but keep the sequence terminated by EOS.
                ids = ids[:max_length]
                ids[-1] = self.eos_token_id
            encoded[i, : len(ids)] = torch.tensor(ids)
            attention_mask[i, : len(ids)] = 1

        return {"input_ids": encoded, "attention_mask": attention_mask}
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import html
|
|
2
|
+
import string
|
|
3
|
+
|
|
4
|
+
import ftfy
|
|
5
|
+
import regex as re
|
|
6
|
+
from .t5 import T5TokenizerFast
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def basic_clean(text):
    """Run `ftfy.fix_text` on `text`, HTML-unescape the result twice, and strip the ends."""
    fixed = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities are decoded as well.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def whitespace_clean(text):
    """Replace every whitespace run in `text` with a single space, then strip the ends."""
    return re.sub(r"\s+", " ", text).strip()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def canonicalize(text, keep_punctuation_exact_string=None):
    """
    Normalise `text` to lower-case words separated by single spaces, without punctuation.

    Underscores become spaces first, then every character in `string.punctuation` is removed.

    Args:
        text: The string to normalise.
        keep_punctuation_exact_string: Optional substring that is kept verbatim; punctuation
            is only removed from the text around its occurrences.

    Returns:
        The normalised, stripped string.
    """
    drop_punctuation = str.maketrans("", "", string.punctuation)
    spaced = text.replace("_", " ")

    if not keep_punctuation_exact_string:
        stripped = spaced.translate(drop_punctuation)
    else:
        # Clean each piece around the protected substring, then stitch them back together.
        pieces = spaced.split(keep_punctuation_exact_string)
        stripped = keep_punctuation_exact_string.join(piece.translate(drop_punctuation) for piece in pieces)

    return re.sub(r"\s+", " ", stripped.lower()).strip()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class WanT5Tokenizer:
    """
    Thin wrapper around `T5TokenizerFast` that optionally cleans text before tokenizing.

    Args:
        name: Path handed to `T5TokenizerFast.from_pretrained`.
        seq_len: When not `None`, passed to the tokenizer as `max_length` (together with
            `padding="max_length"` and `truncation=True`).
        clean: One of `None`, `"whitespace"`, `"lower"` or `"canonicalize"`; selects the
            text cleaning applied in `_clean`.
        **kwargs: Forwarded to `T5TokenizerFast.from_pretrained`.
    """

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # underlying tokenizer and its vocabulary size
        self.tokenizer = T5TokenizerFast.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """
        Tokenize a string or list of strings.

        `return_mask=True` (popped from `kwargs`) makes the call return
        `(input_ids, attention_mask)` instead of `input_ids` alone. Remaining keyword
        arguments override the defaults passed to the underlying tokenizer.
        """
        want_mask = kwargs.pop("return_mask", False)

        # defaults first, caller-supplied values win
        call_kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            call_kwargs["padding"] = "max_length"
            call_kwargs["truncation"] = True
            call_kwargs["max_length"] = self.seq_len
        call_kwargs = {**call_kwargs, **kwargs}

        batch = [sequence] if isinstance(sequence, str) else sequence
        if self.clean:
            batch = [self._clean(item) for item in batch]
        encoded = self.tokenizer(batch, **call_kwargs)

        if want_mask:
            return encoded["input_ids"], encoded["attention_mask"]
        return encoded["input_ids"]

    def _clean(self, text):
        """Apply the cleaning mode stored in `self.clean`; unknown or unset modes return `text` unchanged."""
        mode = self.clean
        if mode == "canonicalize":
            return canonicalize(basic_clean(text))
        if mode == "whitespace":
            return whitespace_clean(basic_clean(text))
        if mode == "lower":
            return whitespace_clean(basic_clean(text)).lower()
        return text
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# Directory two levels above this file, and its parent.
PACKAGE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
REPO_ROOT = os.path.dirname(PACKAGE_ROOT)

# conf
CONF_PATH = os.path.join(PACKAGE_ROOT, "conf")
_TOKENIZERS_PATH = os.path.join(CONF_PATH, "tokenizers")
_MODELS_PATH = os.path.join(CONF_PATH, "models")

# tokenizers (directories)
FLUX_TOKENIZER_1_CONF_PATH = os.path.join(_TOKENIZERS_PATH, "flux", "tokenizer_1")
FLUX_TOKENIZER_2_CONF_PATH = os.path.join(_TOKENIZERS_PATH, "flux", "tokenizer_2")
SDXL_TOKENIZER_CONF_PATH = os.path.join(_TOKENIZERS_PATH, "sdxl", "tokenizer")
SDXL_TOKENIZER_2_CONF_PATH = os.path.join(_TOKENIZERS_PATH, "sdxl", "tokenizer_2")
WAN_TOKENIZER_CONF_PATH = os.path.join(_TOKENIZERS_PATH, "wan", "umt5-xxl")

# models (JSON config files)
VAE_CONFIG_FILE = os.path.join(_MODELS_PATH, "components", "vae.json")
FLUX_DIT_CONFIG_FILE = os.path.join(_MODELS_PATH, "flux", "flux_dit.json")
FLUX_TEXT_ENCODER_CONFIG_FILE = os.path.join(_MODELS_PATH, "flux", "flux_text_encoder.json")
FLUX_VAE_CONFIG_FILE = os.path.join(_MODELS_PATH, "flux", "flux_vae.json")
SD_TEXT_ENCODER_CONFIG_FILE = os.path.join(_MODELS_PATH, "sd", "sd_text_encoder.json")
SD_UNET_CONFIG_FILE = os.path.join(_MODELS_PATH, "sd", "sd_unet.json")
SD3_DIT_CONFIG_FILE = os.path.join(_MODELS_PATH, "sd3", "sd3_dit.json")
SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(_MODELS_PATH, "sd3", "sd3_text_encoder.json")
SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(_MODELS_PATH, "sdxl", "sdxl_text_encoder.json")
SDXL_UNET_CONFIG_FILE = os.path.join(_MODELS_PATH, "sdxl", "sdxl_unet.json")

WAN_DIT_1_3B_T2V_CONFIG_FILE = os.path.join(_MODELS_PATH, "wan", "dit", "1.3b-t2v.json")
WAN_DIT_14B_I2V_CONFIG_FILE = os.path.join(_MODELS_PATH, "wan", "dit", "14b-i2v.json")
WAN_DIT_14B_T2V_CONFIG_FILE = os.path.join(_MODELS_PATH, "wan", "dit", "14b-t2v.json")

# data size, in bytes (binary multiples)
KB = 1 << 10
MB = KB << 10
GB = MB << 10
TB = GB << 10
|