diffsynth_engine-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. diffsynth_engine/__init__.py +28 -0
  2. diffsynth_engine/algorithm/__init__.py +0 -0
  3. diffsynth_engine/algorithm/noise_scheduler/__init__.py +21 -0
  4. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +10 -0
  5. diffsynth_engine/algorithm/noise_scheduler/flow_match/__init__.py +5 -0
  6. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_beta.py +28 -0
  7. diffsynth_engine/algorithm/noise_scheduler/flow_match/flow_ddim.py +25 -0
  8. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +50 -0
  9. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/__init__.py +0 -0
  10. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/beta.py +26 -0
  11. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/ddim.py +25 -0
  12. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/exponential.py +19 -0
  13. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/karras.py +21 -0
  14. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/linear.py +77 -0
  15. diffsynth_engine/algorithm/noise_scheduler/stable_diffusion/sgm_uniform.py +17 -0
  16. diffsynth_engine/algorithm/sampler/__init__.py +19 -0
  17. diffsynth_engine/algorithm/sampler/flow_match/__init__.py +0 -0
  18. diffsynth_engine/algorithm/sampler/flow_match/flow_match_euler.py +22 -0
  19. diffsynth_engine/algorithm/sampler/stable_diffusion/__init__.py +0 -0
  20. diffsynth_engine/algorithm/sampler/stable_diffusion/brownian_tree.py +54 -0
  21. diffsynth_engine/algorithm/sampler/stable_diffusion/ddpm.py +32 -0
  22. diffsynth_engine/algorithm/sampler/stable_diffusion/deis.py +125 -0
  23. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m.py +29 -0
  24. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_2m_sde.py +53 -0
  25. diffsynth_engine/algorithm/sampler/stable_diffusion/dpmpp_3m_sde.py +59 -0
  26. diffsynth_engine/algorithm/sampler/stable_diffusion/epsilon.py +29 -0
  27. diffsynth_engine/algorithm/sampler/stable_diffusion/euler.py +12 -0
  28. diffsynth_engine/algorithm/sampler/stable_diffusion/euler_ancestral.py +30 -0
  29. diffsynth_engine/conf/models/components/vae.json +254 -0
  30. diffsynth_engine/conf/models/flux/flux_dit.json +105 -0
  31. diffsynth_engine/conf/models/flux/flux_text_encoder.json +20 -0
  32. diffsynth_engine/conf/models/flux/flux_vae.json +250 -0
  33. diffsynth_engine/conf/models/sd/sd_text_encoder.json +220 -0
  34. diffsynth_engine/conf/models/sd/sd_unet.json +397 -0
  35. diffsynth_engine/conf/models/sd3/sd3_dit.json +908 -0
  36. diffsynth_engine/conf/models/sd3/sd3_text_encoder.json +756 -0
  37. diffsynth_engine/conf/models/sdxl/sdxl_text_encoder.json +455 -0
  38. diffsynth_engine/conf/models/sdxl/sdxl_unet.json +1056 -0
  39. diffsynth_engine/conf/models/wan/dit/1.3b-t2v.json +13 -0
  40. diffsynth_engine/conf/models/wan/dit/14b-i2v.json +13 -0
  41. diffsynth_engine/conf/models/wan/dit/14b-t2v.json +13 -0
  42. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/merges.txt +48895 -0
  43. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/special_tokens_map.json +30 -0
  44. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/tokenizer_config.json +30 -0
  45. diffsynth_engine/conf/tokenizers/flux/tokenizer_1/vocab.json +49410 -0
  46. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/special_tokens_map.json +125 -0
  47. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/spiece.model +0 -0
  48. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer.json +129428 -0
  49. diffsynth_engine/conf/tokenizers/flux/tokenizer_2/tokenizer_config.json +940 -0
  50. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/merges.txt +48895 -0
  51. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/special_tokens_map.json +24 -0
  52. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/tokenizer_config.json +30 -0
  53. diffsynth_engine/conf/tokenizers/sdxl/tokenizer/vocab.json +49410 -0
  54. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/merges.txt +40213 -0
  55. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/special_tokens_map.json +24 -0
  56. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/tokenizer_config.json +38 -0
  57. diffsynth_engine/conf/tokenizers/sdxl/tokenizer_2/vocab.json +49411 -0
  58. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/special_tokens_map.json +308 -0
  59. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/spiece.model +0 -0
  60. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json +1028026 -0
  61. diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json +2748 -0
  62. diffsynth_engine/kernels/__init__.py +0 -0
  63. diffsynth_engine/models/__init__.py +7 -0
  64. diffsynth_engine/models/base.py +64 -0
  65. diffsynth_engine/models/basic/__init__.py +0 -0
  66. diffsynth_engine/models/basic/attention.py +217 -0
  67. diffsynth_engine/models/basic/lora.py +293 -0
  68. diffsynth_engine/models/basic/relative_position_emb.py +56 -0
  69. diffsynth_engine/models/basic/timestep.py +81 -0
  70. diffsynth_engine/models/basic/transformer_helper.py +88 -0
  71. diffsynth_engine/models/basic/unet_helper.py +244 -0
  72. diffsynth_engine/models/components/__init__.py +0 -0
  73. diffsynth_engine/models/components/clip.py +56 -0
  74. diffsynth_engine/models/components/t5.py +222 -0
  75. diffsynth_engine/models/components/vae.py +392 -0
  76. diffsynth_engine/models/flux/__init__.py +14 -0
  77. diffsynth_engine/models/flux/flux_dit.py +476 -0
  78. diffsynth_engine/models/flux/flux_text_encoder.py +88 -0
  79. diffsynth_engine/models/flux/flux_vae.py +78 -0
  80. diffsynth_engine/models/sd/__init__.py +12 -0
  81. diffsynth_engine/models/sd/sd_text_encoder.py +142 -0
  82. diffsynth_engine/models/sd/sd_unet.py +293 -0
  83. diffsynth_engine/models/sd/sd_vae.py +38 -0
  84. diffsynth_engine/models/sd3/__init__.py +14 -0
  85. diffsynth_engine/models/sd3/sd3_dit.py +302 -0
  86. diffsynth_engine/models/sd3/sd3_text_encoder.py +163 -0
  87. diffsynth_engine/models/sd3/sd3_vae.py +43 -0
  88. diffsynth_engine/models/sdxl/__init__.py +13 -0
  89. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +307 -0
  90. diffsynth_engine/models/sdxl/sdxl_unet.py +306 -0
  91. diffsynth_engine/models/sdxl/sdxl_vae.py +38 -0
  92. diffsynth_engine/models/utils.py +54 -0
  93. diffsynth_engine/models/wan/__init__.py +0 -0
  94. diffsynth_engine/models/wan/wan_dit.py +497 -0
  95. diffsynth_engine/models/wan/wan_image_encoder.py +494 -0
  96. diffsynth_engine/models/wan/wan_text_encoder.py +297 -0
  97. diffsynth_engine/models/wan/wan_vae.py +771 -0
  98. diffsynth_engine/pipelines/__init__.py +18 -0
  99. diffsynth_engine/pipelines/base.py +253 -0
  100. diffsynth_engine/pipelines/flux_image.py +512 -0
  101. diffsynth_engine/pipelines/sd_image.py +352 -0
  102. diffsynth_engine/pipelines/sdxl_image.py +395 -0
  103. diffsynth_engine/pipelines/wan_video.py +524 -0
  104. diffsynth_engine/tokenizers/__init__.py +6 -0
  105. diffsynth_engine/tokenizers/base.py +157 -0
  106. diffsynth_engine/tokenizers/clip.py +288 -0
  107. diffsynth_engine/tokenizers/t5.py +194 -0
  108. diffsynth_engine/tokenizers/wan.py +74 -0
  109. diffsynth_engine/utils/__init__.py +0 -0
  110. diffsynth_engine/utils/constants.py +34 -0
  111. diffsynth_engine/utils/download.py +135 -0
  112. diffsynth_engine/utils/env.py +7 -0
  113. diffsynth_engine/utils/flag.py +46 -0
  114. diffsynth_engine/utils/fp8_linear.py +64 -0
  115. diffsynth_engine/utils/gguf.py +415 -0
  116. diffsynth_engine/utils/loader.py +17 -0
  117. diffsynth_engine/utils/lock.py +56 -0
  118. diffsynth_engine/utils/logging.py +12 -0
  119. diffsynth_engine/utils/offload.py +44 -0
  120. diffsynth_engine/utils/parallel.py +390 -0
  121. diffsynth_engine/utils/prompt.py +9 -0
  122. diffsynth_engine/utils/video.py +40 -0
  123. diffsynth_engine-0.0.0.dist-info/LICENSE +201 -0
  124. diffsynth_engine-0.0.0.dist-info/METADATA +236 -0
  125. diffsynth_engine-0.0.0.dist-info/RECORD +127 -0
  126. diffsynth_engine-0.0.0.dist-info/WHEEL +5 -0
  127. diffsynth_engine-0.0.0.dist-info/top_level.txt +1 -0
diffsynth_engine/tokenizers/clip.py
@@ -0,0 +1,288 @@
+ import os
+ import json
+ import ftfy
+ import regex as re
+ import torch
+ from functools import lru_cache
+ from typing import Dict, List, Union, Optional
+
+ from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+
+
+ VOCAB_FILES_NAMES = {
+     "vocab_file": "vocab.json",
+     "merges_file": "merges.txt",
+ }
+
+ CLIP_DEFAULT_MAX_LENGTH = 77
+
+
+ @lru_cache()
+ def bytes_to_unicode():
+     """
+     Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+     whitespace/control characters that the bpe code barfs on.
+
+     The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your
+     vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around
+     5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that,
+     we want lookup tables between utf-8 bytes and unicode strings.
+     """
+     bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+     cs = bs[:]
+     n = 0
+     for b in range(2**8):
+         if b not in bs:
+             bs.append(b)
+             cs.append(2**8 + n)
+             n += 1
+     cs = [chr(n) for n in cs]
+     return dict(zip(bs, cs))
+
+
+ def get_pairs(word):
+     """
+     Return the set of symbol pairs in a word.
+
+     Word is represented as a tuple of symbols (symbols being variable-length strings).
+     """
+     pairs = set()
+     prev_char = word[0]
+     for char in word[1:]:
+         pairs.add((prev_char, char))
+         prev_char = char
+     return pairs
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ # Modified from transformers.models.clip.tokenization_clip.CLIPTokenizer and open_clip.tokenizer.SimpleTokenizer
+ class CLIPTokenizer(BaseTokenizer):
+     """
+     Construct a CLIP tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+     Args:
+         vocab_file (`str`):
+             Path to the vocabulary file.
+         merges_file (`str`):
+             Path to the merges file.
+         bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
+             The beginning of sequence token.
+         eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+             The end of sequence token.
+         unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+             this token instead.
+         pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+             The token used for padding, for example when batching sequences of different lengths.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file: str,
+         merges_file: str,
+         bos_token: Optional[str] = "<|startoftext|>",
+         eos_token: Optional[str] = "<|endoftext|>",
+         unk_token: Optional[str] = "<|endoftext|>",
+         pad_token: Optional[str] = "<|endoftext|>",  # hack to enable padding
+         **kwargs,
+     ):
+         super().__init__(
+             unk_token=unk_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             **kwargs,
+         )
+
+         with open(vocab_file, encoding="utf-8") as vocab_handle:
+             self.encoder = json.load(vocab_handle)
+         self.decoder = {v: k for k, v in self.encoder.items()}
+         self.byte_encoder = bytes_to_unicode()
+         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+         with open(merges_file, encoding="utf-8") as merges_handle:
+             bpe_merges = merges_handle.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+         bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+         self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
+
+         self.pat = re.compile(
+             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+             re.IGNORECASE,
+         )
+
+         self.model_max_length = self.model_max_length if self.model_max_length else CLIP_DEFAULT_MAX_LENGTH
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], **kwargs):
+         tokenizer_config_file = os.path.join(pretrained_model_path, TOKENIZER_CONFIG_FILE)
+         with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+             init_kwargs = json.load(tokenizer_config_handle)
+         init_kwargs.update(**kwargs)
+         vocab_file = os.path.join(pretrained_model_path, cls.vocab_files_names["vocab_file"])
+         merges_file = os.path.join(pretrained_model_path, cls.vocab_files_names["merges_file"])
+         return cls(vocab_file=vocab_file, merges_file=merges_file, **init_kwargs)
+
+     @property
+     def vocab_size(self):
+         return len(self.encoder)
+
+     def get_vocab(self):
+         return self.encoder
+
+     def bpe(self, token):
+         if token in self.cache:
+             return self.cache[token]
+         word = tuple(token[:-1]) + (token[-1] + "</w>",)
+         pairs = get_pairs(word)
+
+         if not pairs:
+             return token + "</w>"
+
+         while True:
+             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+             if bigram not in self.bpe_ranks:
+                 break
+             first, second = bigram
+             new_word = []
+             i = 0
+             while i < len(word):
+                 try:
+                     j = word.index(first, i)
+                 except ValueError:
+                     new_word.extend(word[i:])
+                     break
+                 else:
+                     new_word.extend(word[i:j])
+                     i = j
+
+                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                     new_word.append(first + second)
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word = tuple(new_word)
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+
+     def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
+         """Convert string to tokens."""
+         if isinstance(texts, str):
+             return self._tokenize(texts)
+
+         return [self._tokenize(text) for text in texts]
+
+     def _tokenize(self, text: str) -> List[str]:
+         bpe_tokens = []
+         text = whitespace_clean(ftfy.fix_text(text)).lower()
+
+         for token in re.findall(self.pat, text):
+             token = "".join(
+                 self.byte_encoder[b] for b in token.encode("utf-8")
+             )  # map all bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+         return bpe_tokens
+
+     def encode(self, texts: str) -> List[int]:
+         tokens = self.tokenize(texts)
+         return self.convert_tokens_to_ids(tokens)
+
+     def batch_encode(self, texts: List[str]) -> List[List[int]]:
+         return [self.encode(text) for text in texts]
+
+     def decode(
+         self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
+     ) -> str:
+         tokens = self.convert_ids_to_tokens(ids, skip_special_tokens)
+         text = self.convert_tokens_to_string(tokens)
+
+         clean_up_tokenization_spaces = (
+             clean_up_tokenization_spaces
+             if clean_up_tokenization_spaces is not None
+             else self.clean_up_tokenization_spaces
+         )
+         if clean_up_tokenization_spaces:
+             text = self.clean_up_tokenization(text)
+         return text
+
+     def batch_decode(
+         self, ids: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
+     ) -> List[str]:
+         return [self.decode(index, skip_special_tokens, clean_up_tokenization_spaces) for index in ids]
+
+     def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
+         if isinstance(tokens, str):
+             return self.encoder.get(tokens, self.encoder.get(self.unk_token))
+
+         return [self.encoder.get(token, self.encoder.get(self.unk_token)) for token in tokens]
+
+     def convert_ids_to_tokens(
+         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+     ) -> Union[str, List[str]]:
+         if isinstance(ids, int):
+             return self.decoder.get(ids)
+
+         tokens = []
+         for index in ids:
+             if skip_special_tokens and index in self.all_special_ids:
+                 continue
+             tokens.append(self.decoder.get(index))
+         return tokens
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         text = "".join(tokens)
+         byte_array = bytearray([self.byte_decoder[c] for c in text])
+         text = byte_array.decode("utf-8", errors="replace").replace("</w>", " ").strip()
+         return text
+
+     def __call__(
+         self,
+         texts: Union[str, List[str]],
+         max_length: Optional[int] = None,
+         **kwargs,
+     ) -> Dict[str, "torch.Tensor"]:
+         """
+         Tokenize text and prepare model inputs.
+
+         Args:
+             texts (`str` or `List[str]`):
+                 The sequence or batch of sequences to be encoded.
+             max_length (`int`, *optional*):
+                 Each encoded sequence will be truncated or padded to max_length.
+
+         Returns:
+             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         max_length = max_length if max_length else self.model_max_length
+
+         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
+         encoded.fill_(self.pad_token_id)
+         attention_mask = torch.zeros(len(texts), max_length, dtype=torch.long)
+
+         for i, text in enumerate(texts):
+             tokens = self.tokenize(text)
+             ids = [self.bos_token_id] + self.convert_tokens_to_ids(tokens) + [self.eos_token_id]
+             if len(ids) > max_length:
+                 ids = ids[:max_length]
+                 ids[-1] = self.eos_token_id
+             encoded[i, : len(ids)] = torch.tensor(ids)
+             attention_mask[i, : len(ids)] = torch.ones((1, len(ids)))
+
+         return {"input_ids": encoded, "attention_mask": attention_mask}
diffsynth_engine/tokenizers/t5.py
@@ -0,0 +1,194 @@
+ import os
+ import json
+ import torch
+ from typing import Dict, List, Union, Optional
+ from tokenizers import Tokenizer as TokenizerFast
+
+ from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+
+
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+ T5_DEFAULT_MAX_LENGTH = 512
+
+
+ class T5TokenizerFast(BaseTokenizer):
+     """
+     Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+     [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+     Args:
+         vocab_file (`str`):
+             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+             contains the vocabulary necessary to instantiate a tokenizer.
+         tokenizer_file (`str`):
+             Precompiled file for initializing a fast tokenizer.
+         eos_token (`str`, *optional*, defaults to `"</s>"`):
+             The end of sequence token.
+
+             <Tip>
+
+             When building a sequence using special tokens, this is not the token that is used for the end of
+             sequence. The token used is the `sep_token`.
+
+             </Tip>
+
+         unk_token (`str`, *optional*, defaults to `"<unk>"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+             this token instead.
+         pad_token (`str`, *optional*, defaults to `"<pad>"`):
+             The token used for padding, for example when batching sequences of different lengths.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     prefix_tokens: List[int] = []
+
+     def __init__(
+         self,
+         vocab_file=None,
+         tokenizer_file=None,
+         eos_token="</s>",
+         unk_token="<unk>",
+         pad_token="<pad>",
+         **kwargs,
+     ):
+         super().__init__(
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             **kwargs,
+         )
+
+         fast_tokenizer = TokenizerFast.from_file(tokenizer_file)
+         self._tokenizer = fast_tokenizer
+         # disable truncation and padding
+         self._tokenizer.no_truncation()
+         self._tokenizer.no_padding()
+
+         self.model_max_length = self.model_max_length if self.model_max_length else T5_DEFAULT_MAX_LENGTH
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path: Union[str, os.PathLike], **kwargs):
+         tokenizer_config_file = os.path.join(pretrained_model_path, TOKENIZER_CONFIG_FILE)
+         with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+             init_kwargs = json.load(tokenizer_config_handle)
+         init_kwargs.update(**kwargs)
+         vocab_file = os.path.join(pretrained_model_path, cls.vocab_files_names["vocab_file"])
+         tokenizer_file = os.path.join(pretrained_model_path, cls.vocab_files_names["tokenizer_file"])
+         return cls(vocab_file=vocab_file, tokenizer_file=tokenizer_file, **init_kwargs)
+
+     @property
+     def vocab_size(self):
+         return self._tokenizer.get_vocab_size(with_added_tokens=True)
+
+     def get_vocab(self):
+         return self._tokenizer.get_vocab(with_added_tokens=True)
+
+     def tokenize(self, texts: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
+         if isinstance(texts, str):
+             encoding = self._tokenizer.encode(texts)
+             return encoding.tokens
+
+         encodings = self._tokenizer.encode_batch(texts)
+         return [encoding.tokens for encoding in encodings]
+
+     def encode(self, texts: str) -> List[int]:
+         encoding = self._tokenizer.encode(texts, add_special_tokens=True)
+         return encoding.ids
+
+     def batch_encode(self, texts: List[str]) -> List[List[int]]:
+         encodings = self._tokenizer.encode_batch(texts, add_special_tokens=True)
+         return [encoding.ids for encoding in encodings]
+
+     def decode(
+         self, ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
+     ) -> str:
+         text = self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
+
+         clean_up_tokenization_spaces = (
+             clean_up_tokenization_spaces
+             if clean_up_tokenization_spaces is not None
+             else self.clean_up_tokenization_spaces
+         )
+         if clean_up_tokenization_spaces:
+             text = self.clean_up_tokenization(text)
+         return text
+
+     def batch_decode(
+         self, ids: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: Optional[bool] = None
+     ) -> List[str]:
+         texts = self._tokenizer.decode_batch(ids, skip_special_tokens=skip_special_tokens)
+
+         clean_up_tokenization_spaces = (
+             clean_up_tokenization_spaces
+             if clean_up_tokenization_spaces is not None
+             else self.clean_up_tokenization_spaces
+         )
+         if clean_up_tokenization_spaces:
+             texts = [self.clean_up_tokenization(text) for text in texts]
+         return texts
+
+     def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
+         if isinstance(tokens, str):
+             index = self._tokenizer.token_to_id(tokens)
+             return index if index is not None else self._tokenizer.token_to_id(self.unk_token)
+
+         ids = [self._tokenizer.token_to_id(token) for token in tokens]
+         return [index if index is not None else self._tokenizer.token_to_id(self.unk_token) for index in ids]
+
+     def convert_ids_to_tokens(
+         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+     ) -> Union[str, List[str]]:
+         if isinstance(ids, int):
+             return self._tokenizer.id_to_token(ids)
+
+         tokens = []
+         for index in ids:
+             if skip_special_tokens and index in self.all_special_ids:
+                 continue
+             tokens.append(self._tokenizer.id_to_token(index))
+         return tokens
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         return self._tokenizer.decode(tokens)
+
+     def __call__(
+         self,
+         texts: Union[str, List[str]],
+         max_length: Optional[int] = None,
+         **kwargs,
+     ) -> Dict[str, "torch.Tensor"]:
+         """
+         Tokenize text and prepare model inputs.
+
+         Args:
+             texts (`str` or `List[str]`):
+                 The sequence or batch of sequences to be encoded.
+             max_length (`int`, *optional*):
+                 Each encoded sequence will be truncated or padded to max_length.
+
+         Returns:
+             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         max_length = max_length if max_length else self.model_max_length
+
+         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)
+         encoded.fill_(self.pad_token_id)
+         attention_mask = torch.zeros(len(texts), max_length, dtype=torch.long)
+
+         batch_ids = self.batch_encode(texts)
+         for i, ids in enumerate(batch_ids):
+             if len(ids) > max_length:
+                 ids = ids[:max_length]
+                 ids[-1] = self.eos_token_id
+             encoded[i, : len(ids)] = torch.tensor(ids)
+             attention_mask[i, : len(ids)] = torch.ones((1, len(ids)))
+
+         return {"input_ids": encoded, "attention_mask": attention_mask}
diffsynth_engine/tokenizers/wan.py
@@ -0,0 +1,74 @@
+ import html
+ import string
+
+ import ftfy
+ import regex as re
+ from .t5 import T5TokenizerFast
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ def canonicalize(text, keep_punctuation_exact_string=None):
+     text = text.replace("_", " ")
+     if keep_punctuation_exact_string:
+         text = keep_punctuation_exact_string.join(
+             part.translate(str.maketrans("", "", string.punctuation))
+             for part in text.split(keep_punctuation_exact_string)
+         )
+     else:
+         text = text.translate(str.maketrans("", "", string.punctuation))
+     text = text.lower()
+     text = re.sub(r"\s+", " ", text)
+     return text.strip()
+
+
+ class WanT5Tokenizer:
+     def __init__(self, name, seq_len=None, clean=None, **kwargs):
+         assert clean in (None, "whitespace", "lower", "canonicalize")
+         self.name = name
+         self.seq_len = seq_len
+         self.clean = clean
+
+         # init tokenizer
+         self.tokenizer = T5TokenizerFast.from_pretrained(name, **kwargs)
+         self.vocab_size = self.tokenizer.vocab_size
+
+     def __call__(self, sequence, **kwargs):
+         return_mask = kwargs.pop("return_mask", False)
+
+         # arguments
+         _kwargs = {"return_tensors": "pt"}
+         if self.seq_len is not None:
+             _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
+         _kwargs.update(**kwargs)
+
+         # tokenization
+         if isinstance(sequence, str):
+             sequence = [sequence]
+         if self.clean:
+             sequence = [self._clean(u) for u in sequence]
+         ids = self.tokenizer(sequence, **_kwargs)
+
+         # output
+         if return_mask:
+             return ids["input_ids"], ids["attention_mask"]
+         else:
+             return ids["input_ids"]
+
+     def _clean(self, text):
+         if self.clean == "whitespace":
+             text = whitespace_clean(basic_clean(text))
+         elif self.clean == "lower":
+             text = whitespace_clean(basic_clean(text)).lower()
+         elif self.clean == "canonicalize":
+             text = canonicalize(basic_clean(text))
+         return text
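
And a sketch of the WanT5Tokenizer wrapper, assuming the bundled umt5-xxl files under WAN_TOKENIZER_CONF_PATH; seq_len and clean are the two options it adds on top of T5TokenizerFast, and the prompt is illustrative.

from diffsynth_engine.tokenizers.wan import WanT5Tokenizer
from diffsynth_engine.utils.constants import WAN_TOKENIZER_CONF_PATH

# clean="whitespace" runs ftfy fixing, HTML unescaping and whitespace collapsing
# before tokenization; seq_len fixes the padded output length via max_length
tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=512, clean="whitespace")

ids, mask = tokenizer("a  cinematic   shot of ocean waves", return_mask=True)
print(ids.shape, mask.shape)  # torch.Size([1, 512]) torch.Size([1, 512])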
diffsynth_engine/utils/__init__.py: file without changes
diffsynth_engine/utils/constants.py
@@ -0,0 +1,34 @@
+ import os
+
+ PACKAGE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ REPO_ROOT = os.path.dirname(PACKAGE_ROOT)
+
+ # conf
+ CONF_PATH = os.path.join(PACKAGE_ROOT, "conf")
+ # tokenizers
+ FLUX_TOKENIZER_1_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_1")
+ FLUX_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_2")
+ SDXL_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokenizer")
+ SDXL_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokenizer_2")
+ WAN_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "wan", "umt5-xxl")
+ # models
+ VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json")
+ FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json")
+ FLUX_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_text_encoder.json")
+ FLUX_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_vae.json")
+ SD_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_text_encoder.json")
+ SD_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_unet.json")
+ SD3_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_dit.json")
+ SD3_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_text_encoder.json")
+ SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
+ SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
+
+ WAN_DIT_1_3B_T2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "1.3b-t2v.json")
+ WAN_DIT_14B_I2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "14b-i2v.json")
+ WAN_DIT_14B_T2V_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "14b-t2v.json")
+
+ # data size
+ KB = 1024
+ MB = 1024 * KB
+ GB = 1024 * MB
+ TB = 1024 * GB
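
Finally, a brief sketch of how these constants might be consumed; the config paths resolve inside the installed wheel, and the size constants are plain integers for byte arithmetic.

import json

from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE, GB

# model config JSONs ship inside the wheel, so the paths resolve relative to the package
with open(FLUX_DIT_CONFIG_FILE, encoding="utf-8") as f:
    flux_dit_config = json.load(f)
print(type(flux_dit_config))

print(f"example memory budget: {2 * GB} bytes")  # 2147483648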