diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/qwen_image.py

@@ -2,12 +2,17 @@ import json
 import torch
 import torch.distributed as dist
 import math
-from typing import Callable, List, Tuple, Optional, Union
+from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
 from PIL import Image

-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    QwenImagePipelineConfig,
+    QwenImageStateDicts,
+    QwenImageControlNetParams,
+    QwenImageControlType,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.models.qwen_image import (
     QwenImageDiT,

@@ -19,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -71,6 +76,39 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             lora_args["alpha"] = alpha

             key = key.replace(f".{lora_a_suffix}", "")
+            key = key.replace("base_model.model.", "")
+
+            if key.startswith("transformer") and "attn.to_out.0" in key:
+                key = key.replace("attn.to_out.0", "attn.to_out")
+            dit_dict[key] = lora_args
+        return {"dit": dit_dict}
+
+    def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            origin_key = key
+            lora_a_suffix = None
+            if "lora_A.weight" in key:
+                lora_a_suffix = "lora_A.weight"
+                lora_b_suffix = "lora_B.weight"
+
+            if lora_a_suffix is None:
+                continue
+
+            lora_args = {}
+            lora_args["down"] = param
+            lora_args["up"] = lora_state_dict[origin_key.replace(lora_a_suffix, lora_b_suffix)]
+            lora_args["rank"] = lora_args["up"].shape[1]
+            alpha_key = origin_key.replace(lora_a_suffix, "alpha")
+
+            if alpha_key in lora_state_dict:
+                alpha = lora_state_dict[alpha_key]
+            else:
+                alpha = lora_args["rank"]
+            lora_args["alpha"] = alpha
+
+            key = key.replace(f".{lora_a_suffix}", "")
+            key = key.replace("diffusion_model.", "")

             if key.startswith("transformer") and "attn.to_out.0" in key:
                 key = key.replace("attn.to_out.0", "attn.to_out")

@@ -78,7 +116,11 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
         return {"dit": dit_dict}

     def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-
+        key = list(lora_state_dict.keys())[0]
+        if key.startswith("diffusion_model."):
+            return self._from_diffusers(lora_state_dict)
+        else:
+            return self._from_diffsynth(lora_state_dict)


 class QwenImagePipeline(BasePipeline):
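The two converter hunks above add a diffusers-style parser alongside the existing one. As a rough illustrative sketch (the keys below are invented for the example; only the prefix check mirrors the added convert() dispatch):

# Illustrative sketch only, not part of the wheel: how convert() routes a LoRA
# state dict to one of the two parsers added/extended above.
import torch

diffusers_style = {
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "diffusion_model.transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}
diffsynth_style = {
    "base_model.model.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "base_model.model.transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}

def pick_branch(lora_state_dict):
    # Mirrors the added convert(): peek at the first key and route accordingly.
    key = list(lora_state_dict.keys())[0]
    return "_from_diffusers" if key.startswith("diffusion_model.") else "_from_diffsynth"

print(pick_branch(diffusers_style))   # _from_diffusers (strips the "diffusion_model." prefix)
print(pick_branch(diffsynth_style))   # _from_diffsynth (strips "base_model.model.")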
@@ -101,11 +143,25 @@ class QwenImagePipeline(BasePipeline):
             dtype=config.model_dtype,
         )
         self.config = config
+        # qwen image
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
+        # qwen image edit
+        self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
+        # qwen image edit plus
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )

-        self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
         self.edit_prompt_template_encode_start_idx = 64
+
         # sampler
         self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
         self.sampler = FlowMatchEulerSampler()
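The three templates above differ only in their user block, and the *_start_idx values (34 and 64) are the number of leading template tokens later sliced off the encoder output (the hidden_states[:, drop_idx:] lines in the encode hunks further down). An illustrative sketch of how the edit-plus template is filled, with the system prompt abbreviated:

# Illustrative only: the full system prompt string is shown in the hunk above.
edit_system_prompt = "Describe the key features of the input image (...)"
edit_plus_template = (
    "<|im_start|>system\n"
    + edit_system_prompt
    + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)

# One "Picture N" placeholder per conditioning image, prepended to the user text,
# matching encode_prompt_with_image() in a later hunk.
img_prompt = "".join(
    "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1) for i in range(2)
)
print(edit_plus_template.format(img_prompt + "replace the sky with a sunset"))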
@@ -138,6 +194,7 @@ class QwenImagePipeline(BasePipeline):
         logger.info(f"loading state dict from {config.vae_path} ...")
         vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

+        encoder_state_dict = None
         if config.encoder_path is None:
             config.encoder_path = fetch_model(
                 "MusePublic/Qwen-image",

@@ -149,8 +206,11 @@ class QwenImagePipeline(BasePipeline):
                     "text_encoder/model-00004-of-00004.safetensors",
                 ],
             )
-
-
+        if config.load_encoder:
+            logger.info(f"loading state dict from {config.encoder_path} ...")
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )

         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
@@ -177,50 +237,44 @@ class QwenImagePipeline(BasePipeline):
     @classmethod
     def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
         init_device = "cpu" if config.offload_mode is not None else config.device
-        tokenizer =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        tokenizer, processor, encoder = None, None, None
+        if config.load_encoder:
+            tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
+            processor = Qwen2VLProcessor.from_pretrained(
+                tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
+                image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
+            )
+            with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
+                vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
+            with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
+                text_config = Qwen2_5_VLConfig(**json.load(f))
+            encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
+                state_dicts.encoder,
+                vision_config=vision_config,
+                config=text_config,
+                device=("cpu" if config.use_fsdp else init_device),
+                dtype=config.encoder_dtype,
+            )
+
+        with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
             vae_config = json.load(f)
         vae = QwenImageVAE.from_state_dict(
             state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype
         )

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = QwenImageDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -254,8 +308,13 @@ class QwenImagePipeline(BasePipeline):
             pipe.compile()
         return pipe

+    def update_weights(self, state_dicts: QwenImageStateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(self.encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype)
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
     def compile(self):
-        self.dit.compile_repeated_blocks(
+        self.dit.compile_repeated_blocks()

     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (

@@ -270,6 +329,10 @@ class QwenImagePipeline(BasePipeline):

     def unload_loras(self):
         self.dit.unload_loras()
+        self.noise_scheduler.restore_config()
+
+    def apply_scheduler_config(self, scheduler_config: Dict):
+        self.noise_scheduler.update_config(scheduler_config)

     def prepare_latents(
         self,
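A hypothetical usage sketch of the new scheduler hooks: update_config()/restore_config() appear above, but the accepted keys are not shown in this diff, so the "shift" key is an assumption taken from the RecifitedFlowScheduler(shift=3.0, ...) call in __init__, and `pipe` is assumed to be an already constructed QwenImagePipeline.

# Hypothetical sketch, not confirmed by this diff beyond the method names shown above.
pipe.apply_scheduler_config({"shift": 7.0})                 # assumed key; override scheduler settings
pipe.load_loras([("path/to/style_lora.safetensors", 0.8)])  # (path, scale) pairs, per the signature above
pipe(prompt="a watercolor fox", seed=42)
pipe.unload_loras()  # also calls noise_scheduler.restore_config(), undoing the override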
@@ -307,32 +370,43 @@ class QwenImagePipeline(BasePipeline):
         input_ids, attention_mask = outputs["input_ids"].to(self.device), outputs["attention_mask"].to(self.device)
         outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         hidden_states = outputs["hidden_states"]
-
-
-        seq_len =
+        prompt_emb = hidden_states[:, drop_idx:]
+        prompt_emb_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_emb.shape[1]

         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-
-
+        prompt_emb = prompt_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_emb = prompt_emb.view(batch_size * num_images_per_prompt, seq_len, -1)

-
-
+        prompt_emb_mask = prompt_emb_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_emb_mask = prompt_emb_mask.view(batch_size * num_images_per_prompt, seq_len)

-        return
+        return prompt_emb, prompt_emb_mask

     def encode_prompt_with_image(
         self,
         prompt: Union[str, List[str]],
-
+        vae_image: List[torch.Tensor],
+        condition_image: List[torch.Tensor],  # edit plus
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
+        is_edit_plus: bool = True,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt

         batch_size = len(prompt)
         template = self.edit_prompt_template_encode
         drop_idx = self.edit_prompt_template_encode_start_idx
-
+        if not is_edit_plus:
+            template = self.edit_prompt_template_encode
+            texts = [template.format(txt) for txt in prompt]
+            image = vae_image
+        else:
+            template = self.edit_plus_prompt_template_encode
+            img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+            img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(condition_image))])
+            texts = [template.format(img_prompt + e) for e in prompt]
+            image = condition_image

         model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx)
         input_ids, attention_mask, pixel_values, image_grid_thw = (

@@ -348,18 +422,18 @@ class QwenImagePipeline(BasePipeline):
             image_grid_thw=image_grid_thw,
         )
         hidden_states = outputs["hidden_states"]
-
-
-        seq_len =
+        prompt_emb = hidden_states[:, drop_idx:]
+        prompt_emb_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_emb.shape[1]

         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-
-
+        prompt_emb = prompt_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_emb = prompt_emb.view(batch_size * num_images_per_prompt, seq_len, -1)

-
-
+        prompt_emb_mask = prompt_emb_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_emb_mask = prompt_emb_mask.view(batch_size * num_images_per_prompt, seq_len)

-        return
+        return prompt_emb, prompt_emb_mask

     def predict_noise_with_cfg(
         self,
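Both encode paths now slice off the template preamble and duplicate the result per prompt with the repeat/view trick noted in the existing comment. A standalone shape check of that arithmetic (all sizes illustrative):

# Illustrative shape check; 34 corresponds to prompt_template_encode_start_idx above.
import torch

batch_size, num_images_per_prompt, drop_idx, hidden = 2, 3, 34, 3584
hidden_states = torch.randn(batch_size, 100, hidden)
attention_mask = torch.ones(batch_size, 100, dtype=torch.long)

prompt_emb = hidden_states[:, drop_idx:]            # (2, 66, 3584): template preamble dropped
prompt_emb_mask = attention_mask[:, drop_idx:]      # (2, 66)
seq_len = prompt_emb.shape[1]

prompt_emb = prompt_emb.repeat(1, num_images_per_prompt, 1)
prompt_emb = prompt_emb.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_emb_mask = prompt_emb_mask.repeat(1, num_images_per_prompt, 1)
prompt_emb_mask = prompt_emb_mask.view(batch_size * num_images_per_prompt, seq_len)

print(prompt_emb.shape, prompt_emb_mask.shape)      # torch.Size([6, 66, 3584]) torch.Size([6, 66])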
@@ -368,9 +442,17 @@ class QwenImagePipeline(BasePipeline):
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         negative_prompt_emb: torch.Tensor,
-
-
-
+        prompt_emb_mask: torch.Tensor,
+        negative_prompt_emb_mask: torch.Tensor,
+        # in_context
+        context_latents: torch.Tensor = None,
+        # eligen
+        entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        negative_entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        negative_entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
+        cfg_scale: float = 1.0,
         batch_cfg: bool = False,
     ):
         if cfg_scale <= 1.0 or negative_prompt_emb is None:

@@ -379,7 +461,11 @@ class QwenImagePipeline(BasePipeline):
                 image_latents,
                 timestep,
                 prompt_emb,
-
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
         if not batch_cfg:
             # cfg by predict noise one by one

@@ -389,14 +475,22 @@ class QwenImagePipeline(BasePipeline):
                 image_latents,
                 timestep,
                 prompt_emb,
-
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
             negative_noise_pred = self.predict_noise(
                 latents,
                 image_latents,
                 timestep,
                 negative_prompt_emb,
-
+                negative_prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=negative_entity_prompt_embs,
+                entity_prompt_emb_masks=negative_entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
             comb_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
             cond_norm = torch.norm(self.dit.patchify(positive_noise_pred), dim=-1, keepdim=True)
@@ -406,18 +500,32 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
            bs, _, h, w = latents.shape
-            prompt_emb =
-
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
+            if entity_prompt_embs is not None:
+                entity_prompt_embs = [
+                    torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
+                ]
+                entity_prompt_emb_masks = [
+                    torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_emb_masks, negative_entity_prompt_emb_masks)
+                ]
+                entity_masks = [torch.cat([mask, mask], dim=0) for mask in entity_masks]
             latents = torch.cat([latents, latents], dim=0)
             if image_latents is not None:
-                image_latents = torch.cat([
+                image_latents = [torch.cat([image_latent, image_latent], dim=0) for image_latent in image_latents]
+            if context_latents is not None:
+                context_latents = torch.cat([context_latents, context_latents], dim=0)
             timestep = torch.cat([timestep, timestep], dim=0)
             noise_pred = self.predict_noise(
                 latents,
                 image_latents,
                 timestep,
                 prompt_emb,
-
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
             positive_noise_pred, negative_noise_pred = noise_pred[:bs], noise_pred[bs:]
             comb_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
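In the batched branch, pad_and_concat (added in diffsynth_engine/pipelines/utils.py, see the last hunk) aligns positive and negative prompt lengths so a single forward pass covers both; the prediction is then split and combined with the usual true-CFG formula. A minimal sketch of that combine (shapes illustrative; the norm rescaling that follows in the real code is only partially visible in this hunk):

# Minimal sketch of the batched true-CFG combine above.
import torch

bs, cfg_scale = 1, 4.0
noise_pred = torch.randn(2 * bs, 16, 128, 128)   # one forward pass over [positive; negative]
positive_noise_pred, negative_noise_pred = noise_pred[:bs], noise_pred[bs:]
comb_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
print(comb_pred.shape)                           # torch.Size([1, 16, 128, 128])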
@@ -432,15 +540,27 @@ class QwenImagePipeline(BasePipeline):
         image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
-
+        prompt_emb_mask: torch.Tensor,
+        # in_context
+        context_latents: torch.Tensor = None,
+        # eligen
+        entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
-            text=prompt_emb,
             timestep=timestep,
-
+            text=prompt_emb,
+            text_seq_lens=prompt_emb_mask.sum(dim=1),
+            context_latents=context_latents,
+            entity_text=entity_prompt_embs,
+            entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
+            entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

@@ -457,6 +577,20 @@ class QwenImagePipeline(BasePipeline):
         image_latents = image_latents.squeeze(2).to(device=self.device)
         return image_latents

+    def prepare_eligen(self, entity_prompts, entity_masks, width, height):
+        entity_masks = [mask.resize((width // 8, height // 8), resample=Image.NEAREST) for mask in entity_masks]
+        entity_masks = [self.preprocess_image(mask).mean(dim=1, keepdim=True) > 0 for mask in entity_masks]
+        entity_masks = [mask.to(device=self.device, dtype=self.dtype) for mask in entity_masks]
+        prompt_embs, prompt_emb_masks = [], []
+        negative_prompt_embs, negative_prompt_emb_masks = [], []
+        for entity_prompt in entity_prompts:
+            prompt_emb, prompt_emb_mask = self.encode_prompt(entity_prompt, 1, 512)
+            prompt_embs.append(prompt_emb)
+            prompt_emb_masks.append(prompt_emb_mask)
+            negative_prompt_embs.append(torch.zeros_like(prompt_emb))
+            negative_prompt_emb_masks.append(torch.zeros_like(prompt_emb_mask))
+        return prompt_embs, prompt_emb_masks, negative_prompt_embs, negative_prompt_emb_masks, entity_masks
+
     def calculate_dimensions(self, target_area, ratio):
         width = math.sqrt(target_area * ratio)
         height = width / ratio
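The two visible lines of calculate_dimensions pick a width/height pair that matches the requested aspect ratio and target pixel area; any rounding to valid sizes happens outside this hunk. A worked example of the visible arithmetic:

# Worked example of the visible arithmetic only (rounding, if any, is not shown here).
import math

def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return width, height

print(calculate_dimensions(1024 * 1024, 16 / 9))   # (~1365.3, 768.0)
print(calculate_dimensions(384 * 384, 1.0))        # (384.0, 384.0)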
@@ -469,21 +603,51 @@ class QwenImagePipeline(BasePipeline):
         self,
         prompt: str,
         negative_prompt: str = "",
-
+        # single image for edit, list for edit plus(QwenImageEdit2509)
+        input_image: List[Image.Image] | Image.Image | None = None,
         cfg_scale: float = 4.0,  # true cfg
-        height: int =
-        width: int =
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 50,
         seed: int | None = None,
+        controlnet_params: List[QwenImageControlNetParams] | QwenImageControlNetParams = [],
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
+        # eligen
+        entity_prompts: Optional[List[str]] = None,
+        entity_masks: Optional[List[Image.Image]] = None,
     ):
-
-
-            width, height = self.calculate_dimensions(1024 * 1024, width / height)
-            input_image = input_image.resize((width, height), Image.LANCZOS)
+        assert (height is None) == (width is None), "height and width should be set together"
+        is_edit_plus = isinstance(input_image, list)

+        if input_image is not None:
+            if not isinstance(input_image, list):
+                input_image = [input_image]
+            condition_images = []
+            vae_images = []
+            for img in input_image:
+                img_width, img_height = img.size
+                condition_width, condition_height = self.calculate_dimensions(384 * 384, img_width / img_height)
+                vae_width, vae_height = self.calculate_dimensions(1024 * 1024, img_width / img_height)
+                condition_images.append(img.resize((condition_width, condition_height), Image.LANCZOS))
+                vae_images.append(img.resize((vae_width, vae_height), Image.LANCZOS))
+            if width is None and height is None:
+                width, height = vae_images[-1].size
+
+        if width is None and height is None:
+            width, height = 1328, 1328
         self.validate_image_size(height, width, minimum=64, multiple_of=16)

+        if not isinstance(controlnet_params, list):
+            controlnet_params = [controlnet_params]
+
+        context_latents = None
+        for param in controlnet_params:
+            self.load_lora(param.model, param.scale, fused=False, save_original_weight=False)
+            if param.control_type == QwenImageControlType.in_context:
+                width, height = param.image.size
+                self.validate_image_size(height, width, minimum=64, multiple_of=16)
+                context_latents = self.prepare_image_latents(param.image.resize((width, height), Image.LANCZOS))
+
         noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
             device=self.device
         )
@@ -496,39 +660,60 @@ class QwenImagePipeline(BasePipeline):

         self.load_models_to_device(["vae"])
         if input_image:
-            image_latents = self.prepare_image_latents(
+            image_latents = [self.prepare_image_latents(img) for img in vae_images]
         else:
             image_latents = None

         self.load_models_to_device(["encoder"])
         if image_latents is not None:
-
+            prompt_emb, prompt_emb_mask = self.encode_prompt_with_image(
+                prompt, vae_images, condition_images, 1, 4096, is_edit_plus
+            )
             if cfg_scale > 1.0 and negative_prompt != "":
-
-                    negative_prompt,
+                negative_prompt_emb, negative_prompt_emb_mask = self.encode_prompt_with_image(
+                    negative_prompt, vae_images, condition_images, 1, 4096, is_edit_plus
                 )
             else:
-
+                negative_prompt_emb, negative_prompt_emb_mask = None, None
         else:
-
+            prompt_emb, prompt_emb_mask = self.encode_prompt(prompt, 1, 4096)
             if cfg_scale > 1.0 and negative_prompt != "":
-
+                negative_prompt_emb, negative_prompt_emb_mask = self.encode_prompt(negative_prompt, 1, 4096)
             else:
-
+                negative_prompt_emb, negative_prompt_emb_mask = None, None
+
+        entity_prompt_embs, entity_prompt_emb_masks = None, None
+        negative_entity_prompt_embs, negative_entity_prompt_emb_masks = None, None
+        if entity_prompts is not None and entity_masks is not None:
+            assert len(entity_prompts) == len(entity_masks), "entity_prompts and entity_masks must have the same length"
+            (
+                entity_prompt_embs,
+                entity_prompt_emb_masks,
+                negative_entity_prompt_embs,
+                negative_entity_prompt_emb_masks,
+                entity_masks,
+            ) = self.prepare_eligen(entity_prompts, entity_masks, width, height)
+
         self.model_lifecycle_finish(["encoder"])

+        self.load_models_to_device(["dit"])
         hide_progress = dist.is_initialized() and dist.get_rank() != 0
-
         for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
             timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
             noise_pred = self.predict_noise_with_cfg(
                 latents=latents,
                 image_latents=image_latents,
                 timestep=timestep,
-                prompt_emb=
-                negative_prompt_emb=
-
-
+                prompt_emb=prompt_emb,
+                negative_prompt_emb=negative_prompt_emb,
+                prompt_emb_mask=prompt_emb_mask,
+                negative_prompt_emb_mask=negative_prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                negative_entity_prompt_embs=negative_entity_prompt_embs,
+                negative_entity_prompt_emb_masks=negative_entity_prompt_emb_masks,
+                entity_masks=entity_masks,
                 cfg_scale=cfg_scale,
                 batch_cfg=self.config.batch_cfg,
             )
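Taken together, the __call__ changes above cover single-image edit, multi-image edit plus, and EliGen entity control. A hypothetical usage sketch, assuming `pipe` is an already constructed QwenImagePipeline and that the call returns the generated image (neither the constructor nor the return type appears in this diff; file names and argument values are illustrative):

# Hypothetical usage sketch based on the __call__ signature shown above.
from PIL import Image

ref_a = Image.open("sofa.png")
ref_b = Image.open("lamp.png")

# Edit-plus mode: a list of input images triggers the "Picture N" prompt template
# and the 384*384-area conditioning resize shown in the hunks above.
result = pipe(
    prompt="place the lamp next to the sofa",
    input_image=[ref_a, ref_b],
    num_inference_steps=30,
    cfg_scale=4.0,
    seed=42,
)

# EliGen mode: one prompt and one mask per entity; masks are downsampled by 8x
# in prepare_eligen() before being passed to the DiT.
mask = Image.open("subject_mask.png")
result = pipe(
    prompt="a cozy living room",
    entity_prompts=["a red velvet sofa"],
    entity_masks=[mask],
    width=1328,
    height=1328,
)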
diffsynth_engine/pipelines/sdxl_image.py

@@ -181,7 +181,7 @@ class SDXLImagePipeline(BasePipeline):

     @classmethod
     def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
-        init_device = "cpu" if config.offload_mode else config.device
+        init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
diffsynth_engine/pipelines/utils.py

@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
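A usage sketch for the new helper: the shorter sequence is zero-padded before the concat on the batch axis, which is what lets the batched-CFG path in qwen_image.py combine positive and negative prompt embeddings (and their masks) of different lengths:

# Usage sketch (shapes illustrative); requires the 0.6.x wheel.
import torch
from diffsynth_engine.pipelines.utils import pad_and_concat

pos = torch.randn(1, 52, 3584)   # positive prompt embedding, 52 tokens
neg = torch.randn(1, 9, 3584)    # negative prompt embedding, 9 tokens

batched = pad_and_concat(pos, neg)                           # (2, 52, 3584); neg zero-padded to 52 tokens
masks = pad_and_concat(torch.ones(1, 52), torch.ones(1, 9))  # (2, 52); padded positions stay 0
print(batched.shape, masks.sum(dim=1))                       # torch.Size([2, 52, 3584]) tensor([52., 9.])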