diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,69 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.model_manager import ModelManager
3
+ from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
4
+ from transformers import BertTokenizer, AutoTokenizer
5
+ import warnings, os
6
+
7
+
8
class HunyuanDiTPrompter(BasePrompter):
    """Prompt encoder for Hunyuan-DiT built from a BERT-tokenized encoder and a T5 encoder.

    Both tokenizers are loaded in ``__init__``; the text encoders are attached later via
    :meth:`fetch_models`.
    """

    def __init__(
        self,
        tokenizer_path=None,
        tokenizer_t5_path=None
    ):
        # Default to the tokenizer configs shipped inside the package.
        package_root = os.path.dirname(os.path.dirname(__file__))
        if tokenizer_path is None:
            tokenizer_path = os.path.join(package_root, "tokenizer_configs/hunyuan_dit/tokenizer")
        if tokenizer_t5_path is None:
            tokenizer_t5_path = os.path.join(package_root, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        # Any warnings raised while loading the T5 tokenizer are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None

    def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
        """Attach the two text encoders used by :meth:`encode_prompt`."""
        self.text_encoder, self.text_encoder_t5 = text_encoder, text_encoder_t5

    def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
        """Tokenize ``prompt`` to exactly ``max_length`` tokens and encode it with ``text_encoder``.

        Returns:
            ``(prompt_embeds, attention_mask)``; the attention mask is moved to ``device``.
        """
        encoded = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        mask = encoded.attention_mask.to(device)
        embeddings = text_encoder(
            encoded.input_ids.to(device),
            attention_mask=mask,
            clip_skip=clip_skip
        )
        return embeddings, mask

    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=1,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` with the BERT-tokenized encoder and the T5 encoder.

        ``clip_skip`` applies to the first encoder, ``clip_skip_2`` to the T5 encoder; each
        tokenizer pads/truncates to its own ``model_max_length``.

        Returns:
            ``(prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5)``.
        """
        prompt = self.process_prompt(prompt, positive=positive)

        first_emb, first_mask = self.encode_prompt_using_signle_model(
            prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device
        )
        t5_emb, t5_mask = self.encode_prompt_using_signle_model(
            prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device
        )
        return first_emb, first_mask, t5_emb, t5_mask
@@ -0,0 +1,353 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.model_manager import ModelManager
3
+ import json, os, re
4
+ from typing import List, Optional, Union, Dict
5
+ from sentencepiece import SentencePieceProcessor
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers.utils import PaddingStrategy
8
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
9
+ from ..models.kolors_text_encoder import ChatGLMModel
10
+
11
+
12
class SPTokenizer:
    """Wrapper around a SentencePiece model that appends the ChatGLM special tokens.

    Special tokens get consecutive ids starting right after the SentencePiece vocabulary,
    in the order they are listed in ``__init__``.
    """

    def __init__(self, model_path: str):
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # Control-token ids taken from the SentencePiece model; padding reuses the unk id.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        first_free_id = self.n_words
        self.special_tokens = {name: first_free_id + offset for offset, name in enumerate(special_tokens)}
        self.index_special_tokens = {token_id: name for name, token_id in self.special_tokens.items()}
        self.n_words += len(special_tokens)
        # Alternation regex matching any role marker literally.
        self.role_special_token_expression = "|".join(map(re.escape, role_special_tokens))

    def tokenize(self, s: str, encode_special_tokens=False):
        """Split ``s`` into pieces.

        With ``encode_special_tokens`` set, role markers such as ``<|user|>`` are emitted as
        single pieces and only the text between them is passed to SentencePiece.
        """
        if not encode_special_tokens:
            return self.sp_model.EncodeAsPieces(s)
        pieces = []
        cursor = 0
        for match in re.finditer(self.role_special_token_expression, s):
            start, end = match.span()
            if cursor < start:
                pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:start]))
            pieces.append(match.group(0))
            cursor = end
        if cursor < len(s):
            pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:]))
        return pieces

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode ``s`` to SentencePiece ids, optionally wrapped in BOS / EOS ids."""
        assert type(s) is str
        head = [self.bos_id] if bos else []
        tail = [self.eos_id] if eos else []
        return head + self.sp_model.encode(s) + tail

    def decode(self, t: List[int]) -> str:
        """Decode ids to text; special-token ids are rendered verbatim between SentencePiece runs."""
        parts = []
        run = []
        for token_id in t:
            if token_id not in self.index_special_tokens:
                run.append(token_id)
                continue
            if run:
                parts.append(self.sp_model.decode(run))
                run = []
            parts.append(self.index_special_tokens[token_id])
        if run:
            parts.append(self.sp_model.decode(run))
        return "".join(parts)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join SentencePiece pieces back into text."""
        return self.sp_model.DecodePieces(tokens)

    def convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.special_tokens[token] if token in self.special_tokens else self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab.

        BOS/EOS/pad ids and negative ids map to the empty string.
        """
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        is_control = index in (self.eos_id, self.bos_id, self.pad_id) or index < 0
        return "" if is_control else self.sp_model.IdToPiece(index)
90
+
91
+
92
+
93
class ChatGLMTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer adapter around :class:`SPTokenizer` for the ChatGLM text encoder.

    Token/id conversion is delegated to the wrapped ``SPTokenizer``. Only left padding is
    supported (see :meth:`_pad`), so ``padding_side`` defaults to ``"left"``.
    """

    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
                 **kwargs):
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        # Generic command names resolved by get_command(); "<pad>" reuses SPTokenizer.pad_id.
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        self.encode_special_tokens = encode_special_tokens
        # NOTE(review): the attributes above are assigned before the base initializer,
        # presumably because it may call back into get_vocab()/vocab_size — confirm before reordering.
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         encode_special_tokens=encode_special_tokens,
                         **kwargs)

    def get_command(self, token):
        """Return the id of a command token such as ``"<eos>"``, ``"[gMASK]"`` or ``"<|user|>"``."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        # The pad token string is "<unk>"; its id comes from get_command("<pad>").
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        """Number of SentencePiece pieces plus the appended ChatGLM special tokens."""
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Copy the SentencePiece model file to ``save_directory``.

        Args:
            save_directory (`str`):
                A directory (the file is saved there as ``tokenizer.model``) or a full file path.
            filename_prefix (`str`, *optional*):
                Accepted for API compatibility; currently ignored.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        """Return the ids prepended to every sequence: ``[gMASK]`` followed by ``sop``."""
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as ``<|role|>`` + ``metadata\\n`` + ``message`` ids."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user"):
        """Encode ``history`` plus ``query`` as a chat prompt ending with ``<|assistant|>``.

        System turns that carry a ``"tools"`` entry get the tools appended as indented JSON.
        """
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding the ChatGLM
        special tokens:

        - single sequence: ``[gMASK] sop X``
        - pair of sequences: ``[gMASK] sop A B <eos>``

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the special tokens added.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
            padding_side: Optional[str] = None,
    ) -> dict:
        """
        Left-pad one encoded sequence (``input_ids``, ``attention_mask``, ``position_ids``).

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs for a single sequence (flat lists of ints).
            max_length: Target length. Replaced by the sequence length under ``LONGEST``.
            padding_strategy: ``LONGEST``, ``MAX_LENGTH`` or ``DO_NOT_PAD`` (the default here).
            pad_to_multiple_of: If set, ``max_length`` is rounded up to a multiple of this value.
            return_attention_mask: Accepted for API compatibility; an attention mask is always
                produced.
            padding_side: Side to pad on; defaults to ``self.padding_side``. Newer releases of
                ``transformers`` pass this keyword from ``pad()``, so it must be accepted here.
                Only ``"left"`` is supported.

        Returns:
            ``encoded_inputs`` (modified in place): ``attention_mask`` and ``position_ids`` are
            created if missing, and all three fields are left-padded when padding applies.
        """
        side = self.padding_side if padding_side is None else padding_side
        assert side == "left", f"{self.name} only supports left padding, got {side!r}"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = seq_length

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_length != max_length

        # Initialize attention mask / position ids if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            # A negative difference (sequence already longer than max_length) yields empty pads.
            difference = max_length - seq_length
            encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
303
+
304
+
305
+
306
class KolorsPrompter(BasePrompter):
    """Prompt encoder for Kolors, which conditions on a ChatGLM text encoder."""

    def __init__(
        self,
        tokenizer_path=None
    ):
        if tokenizer_path is None:
            package_root = os.path.dirname(os.path.dirname(__file__))
            tokenizer_path = os.path.join(package_root, "tokenizer_configs/kolors/tokenizer")
        super().__init__()
        self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
        self.text_encoder: ChatGLMModel = None

    def fetch_models(self, text_encoder: ChatGLMModel = None):
        """Attach the ChatGLM text encoder used by :meth:`encode_prompt`."""
        self.text_encoder = text_encoder

    def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
        """Run ``prompt`` through ChatGLM and pull out token-level and pooled embeddings.

        Returns:
            ``(prompt_emb, pooled_prompt_emb)``: ``prompt_emb`` is hidden state ``-clip_skip``
            with its first two axes swapped; ``pooled_prompt_emb`` is index ``-1`` along axis 0
            of the final hidden state. Both are cloned.
        """
        batch = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            position_ids=batch['position_ids'],
            output_hidden_states=True
        )
        hidden_states = output.hidden_states
        # NOTE(review): hidden states appear to be sequence-first (hence the permute) — confirm.
        prompt_emb = hidden_states[-clip_skip].permute(1, 0, 2).clone()
        pooled_prompt_emb = hidden_states[-1][-1, :, :].clone()
        return prompt_emb, pooled_prompt_emb

    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` for Kolors, padding/truncating to 256 tokens.

        Only ``clip_skip_2`` is used as the hidden-state offset; ``clip_skip`` is accepted
        but ignored.

        Returns:
            ``(pooled_prompt_emb, prompt_emb)`` — the pooled embedding comes first.
        """
        processed = self.process_prompt(prompt, positive=positive)
        prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(
            processed, self.text_encoder, self.tokenizer, 256, clip_skip_2, device
        )
        return pooled_prompt_emb, prompt_emb
@@ -0,0 +1,77 @@
1
+ from transformers import AutoTokenizer
2
+ from ..models.model_manager import ModelManager
3
+ import torch
4
+
5
+
6
+
7
class BeautifulPrompt(torch.nn.Module):
    """Prompt refiner that appends model-generated text to positive prompts."""

    def __init__(self, tokenizer_path=None, model=None, template=""):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = template

    @staticmethod
    def from_model_manager(model_nameger: ModelManager):
        """Build a refiner from the ``beautiful_prompt`` model held by ``model_nameger``.

        Model paths ending in ``"v2"`` use the tag-oriented v2 template; all others use the
        original instruction template.
        """
        model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
        if model_path.endswith("v2"):
            template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
        else:
            template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
        return BeautifulPrompt(tokenizer_path=model_path, model=model, template=template)

    def __call__(self, raw_prompt, positive=True, **kwargs):
        """Return ``raw_prompt`` extended with sampled model output; negative prompts pass through."""
        if not positive:
            return raw_prompt
        model_input = self.template.format(raw_prompt=raw_prompt)
        input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1
        )
        # Keep only the tokens generated after the input prompt.
        generated = outputs[:, input_ids.size(1):]
        continuation = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0].strip()
        prompt = raw_prompt + ", " + continuation
        print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
        return prompt
55
+
56
+
57
+
58
class Translator(torch.nn.Module):
    """Prompt translator wrapping a generative model (``model.generate``) and its tokenizer."""

    def __init__(self, tokenizer_path=None, model=None):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model

    @staticmethod
    def from_model_manager(model_nameger: ModelManager):
        """Build a translator from the ``translator`` model held by ``model_nameger``."""
        model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
        return Translator(tokenizer_path=model_path, model=model)

    def __call__(self, prompt, **kwargs):
        """Translate ``prompt`` with the wrapped model, print the result, and return it."""
        encoded = self.tokenizer.encode(prompt, return_tensors='pt')
        input_ids = encoded.to(self.model.device)
        output_ids = self.model.generate(input_ids)
        translated = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        print(f"Your prompt is translated: {translated}")
        return translated
@@ -0,0 +1,92 @@
1
+ from .base_prompter import BasePrompter
2
+ from ..models.model_manager import ModelManager
3
+ from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
4
+ from transformers import CLIPTokenizer, T5TokenizerFast
5
+ import os, torch
6
+
7
+
8
class SD3Prompter(BasePrompter):
    """Prompt encoder for Stable Diffusion 3: two CLIP encoders plus an optional T5 encoder."""

    def __init__(
        self,
        tokenizer_1_path=None,
        tokenizer_2_path=None,
        tokenizer_3_path=None
    ):
        # Default to the tokenizer configs shipped inside the package.
        base_path = os.path.dirname(os.path.dirname(__file__))
        if tokenizer_1_path is None:
            tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
        if tokenizer_2_path is None:
            tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
        if tokenizer_3_path is None:
            tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
        super().__init__()
        self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
        self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: SD3TextEncoder2 = None
        self.text_encoder_3: SD3TextEncoder3 = None

    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
        """Attach the three text encoders; ``text_encoder_3`` (T5) may be left as ``None``."""
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2
        self.text_encoder_3 = text_encoder_3

    def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
        """Tokenize to ``max_length`` and return ``(pooled_prompt_emb, prompt_emb)`` from a CLIP encoder."""
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        )
        pooled_prompt_emb, prompt_emb = text_encoder(tokens.input_ids.to(device))
        return pooled_prompt_emb, prompt_emb

    def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
        """Encode with T5 and merge the first two axes into one sequence under a leading batch of 1."""
        tokens = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
        )
        prompt_emb = text_encoder(tokens.input_ids.to(device))
        merged_length = prompt_emb.shape[0] * prompt_emb.shape[1]
        return prompt_emb.reshape((1, merged_length, -1))

    def encode_prompt(
        self,
        prompt,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` for SD3.

        The two CLIP embeddings (77 tokens each) are concatenated on the last axis and zero-padded
        to width 4096, then stacked along the sequence axis with the 256-token T5 embedding
        (zeros when no T5 encoder is attached).

        Returns:
            ``(prompt_emb, pooled_prompt_emb)``.
        """
        prompt = self.process_prompt(prompt, positive=positive)

        # CLIP
        pooled_1, emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
        pooled_2, emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)

        # T5
        if self.text_encoder_3 is not None:
            emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
            emb_3 = emb_3.to(emb_1.dtype)  # match the CLIP dtype (e.g. float32 -> float16)
        else:
            emb_3 = torch.zeros((emb_1.shape[0], 256, 4096), dtype=emb_1.dtype, device=device)

        # Merge
        clip_emb = torch.cat([emb_1, emb_2], dim=-1)
        clip_emb = torch.nn.functional.pad(clip_emb, (0, 4096 - 768 - 1280))
        prompt_emb = torch.cat([clip_emb, emb_3], dim=-2)
        pooled_prompt_emb = torch.cat([pooled_1, pooled_2], dim=-1)
        return prompt_emb, pooled_prompt_emb
@@ -0,0 +1,73 @@
1
+ from .base_prompter import BasePrompter, tokenize_long_prompt
2
+ from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
3
+ from ..models import SDTextEncoder
4
+ from transformers import CLIPTokenizer
5
+ import torch, os
6
+
7
+
8
+
9
class SDPrompter(BasePrompter):
    """Prompt encoder for Stable Diffusion 1.x with textual-inversion support."""

    def __init__(self, tokenizer_path=None):
        if tokenizer_path is None:
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.text_encoder: SDTextEncoder = None
        # keyword -> (token strings, embedding tensor) for every textual inversion registered so far.
        self.textual_inversion_dict = {}
        # keyword -> space-padded replacement text built from that keyword's tokens.
        self.keyword_dict = {}

    def fetch_models(self, text_encoder: SDTextEncoder = None):
        """Attach the text encoder used by :meth:`encode_prompt`."""
        self.text_encoder = text_encoder

    def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
        """Append each inversion's embedding rows to ``text_encoder.token_embedding``.

        The embedding module is rebuilt as a new ``torch.nn.Embedding`` holding the existing
        rows followed by the new ones (in dict order), cast to the dtype of the encoder's first
        parameter and placed on the existing embedding's device.
        """
        dtype = next(iter(text_encoder.parameters())).dtype
        state_dict = text_encoder.token_embedding.state_dict()
        token_embeddings = [state_dict["weight"]]
        for keyword in textual_inversion_dict:
            _, embeddings = textual_inversion_dict[keyword]
            token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
        token_embeddings = torch.concat(token_embeddings, dim=0)
        state_dict["weight"] = token_embeddings
        text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
        text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
        text_encoder.token_embedding.load_state_dict(state_dict)

    def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
        """Register each inversion's tokens with ``tokenizer`` and record the keyword replacement text."""
        additional_tokens = []
        for keyword in textual_inversion_dict:
            tokens, _ = textual_inversion_dict[keyword]
            additional_tokens += tokens
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        tokenizer.add_tokens(additional_tokens)

    def load_textual_inversions(self, model_paths):
        """Load textual-inversion files and register them with the text encoder and tokenizer.

        The keyword is the file name without extension. Only 2-D embeddings with 768 columns
        are used; one token ``{keyword}_{i}`` is created per row.

        Only keywords that have not been registered before are added. Embedding rows and
        tokenizer ids are appended in the same order, so they stay aligned across repeated
        calls. Re-registering an existing keyword would append embedding rows that the
        tokenizer never points at, because it skips tokens it already knows.
        """
        new_inversions = {}
        for model_path in model_paths:
            keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
            if keyword in self.textual_inversion_dict:
                # Already registered: its tokens keep pointing at the originally loaded rows.
                continue
            state_dict = load_state_dict(model_path)

            # Search for embeddings
            for embeddings in search_for_embeddings(state_dict):
                if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
                    tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
                    new_inversions[keyword] = (tokens, embeddings)

        if not new_inversions:
            return
        self.add_textual_inversions_to_model(new_inversions, self.text_encoder)
        self.add_textual_inversions_to_tokenizer(new_inversions, self.tokenizer)
        self.textual_inversion_dict.update(new_inversions)

    def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
        """Encode ``prompt`` (with textual-inversion keywords expanded).

        The encoder output's first two axes are merged into one sequence under a leading
        batch of 1.
        """
        prompt = self.process_prompt(prompt, positive=positive)
        for keyword in self.keyword_dict:
            if keyword in prompt:
                print(f"Textual inversion {keyword} is enabled.")
                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

        return prompt_emb