diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/qwen_image.py

@@ -2,12 +2,17 @@ import json
 import torch
 import torch.distributed as dist
 import math
-from typing import Callable, List, Tuple, Optional, Union, Dict
+from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
 from PIL import Image
 
-from diffsynth_engine.configs import QwenImagePipelineConfig, QwenImageStateDicts
+from diffsynth_engine.configs import (
+    QwenImagePipelineConfig,
+    QwenImageStateDicts,
+    QwenImageControlNetParams,
+    QwenImageControlType,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.models.qwen_image import (
     QwenImageDiT,
@@ -19,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -71,6 +76,39 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             lora_args["alpha"] = alpha
 
             key = key.replace(f".{lora_a_suffix}", "")
+            key = key.replace("base_model.model.", "")
+
+            if key.startswith("transformer") and "attn.to_out.0" in key:
+                key = key.replace("attn.to_out.0", "attn.to_out")
+            dit_dict[key] = lora_args
+        return {"dit": dit_dict}
+
+    def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            origin_key = key
+            lora_a_suffix = None
+            if "lora_A.weight" in key:
+                lora_a_suffix = "lora_A.weight"
+                lora_b_suffix = "lora_B.weight"
+
+            if lora_a_suffix is None:
+                continue
+
+            lora_args = {}
+            lora_args["down"] = param
+            lora_args["up"] = lora_state_dict[origin_key.replace(lora_a_suffix, lora_b_suffix)]
+            lora_args["rank"] = lora_args["up"].shape[1]
+            alpha_key = origin_key.replace(lora_a_suffix, "alpha")
+
+            if alpha_key in lora_state_dict:
+                alpha = lora_state_dict[alpha_key]
+            else:
+                alpha = lora_args["rank"]
+            lora_args["alpha"] = alpha
+
+            key = key.replace(f".{lora_a_suffix}", "")
+            key = key.replace("diffusion_model.", "")
 
             if key.startswith("transformer") and "attn.to_out.0" in key:
                 key = key.replace("attn.to_out.0", "attn.to_out")
@@ -78,7 +116,11 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
         return {"dit": dit_dict}
 
     def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-        return self._from_diffsynth(lora_state_dict)
+        key = list(lora_state_dict.keys())[0]
+        if key.startswith("diffusion_model."):
+            return self._from_diffusers(lora_state_dict)
+        else:
+            return self._from_diffsynth(lora_state_dict)
 
 
 class QwenImagePipeline(BasePipeline):
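The `convert` dispatch above picks a converter from the key layout: diffusers-style LoRA checkpoints prefix DiT keys with `diffusion_model.` and use `lora_A.weight` / `lora_B.weight` suffixes, while diffsynth-style checkpoints do not. The sketch below, using a hypothetical key name, mirrors the normalization that `_from_diffusers` performs; it is illustrative only, not code from the package.

```python
# Illustrative sketch of the diffusers-style key normalization shown above,
# applied to a hypothetical LoRA key.
def normalize_diffusers_key(key: str) -> str:
    # strip the LoRA suffix and the diffusers prefix, as _from_diffusers does
    key = key.replace(".lora_A.weight", "").replace("diffusion_model.", "")
    # diffusers' attn.to_out.0 maps to attn.to_out in this codebase
    if key.startswith("transformer") and "attn.to_out.0" in key:
        key = key.replace("attn.to_out.0", "attn.to_out")
    return key


print(normalize_diffusers_key("diffusion_model.transformer_blocks.0.attn.to_out.0.lora_A.weight"))
# -> transformer_blocks.0.attn.to_out
```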
@@ -101,11 +143,25 @@ class QwenImagePipeline(BasePipeline):
             dtype=config.model_dtype,
         )
         self.config = config
+        # qwen image
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
+        # qwen image edit
+        self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
+        # qwen image edit plus
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )
 
-        self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
         self.edit_prompt_template_encode_start_idx = 64
+
         # sampler
         self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
         self.sampler = FlowMatchEulerSampler()
@@ -138,6 +194,7 @@ class QwenImagePipeline(BasePipeline):
         logger.info(f"loading state dict from {config.vae_path} ...")
         vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
 
+        encoder_state_dict = None
         if config.encoder_path is None:
             config.encoder_path = fetch_model(
                 "MusePublic/Qwen-image",
@@ -149,8 +206,11 @@ class QwenImagePipeline(BasePipeline):
                     "text_encoder/model-00004-of-00004.safetensors",
                 ],
             )
-        logger.info(f"loading state dict from {config.encoder_path} ...")
-        encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
+        if config.load_encoder:
+            logger.info(f"loading state dict from {config.encoder_path} ...")
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )
 
         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
@@ -177,50 +237,44 @@ class QwenImagePipeline(BasePipeline):
     @classmethod
     def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
         init_device = "cpu" if config.offload_mode is not None else config.device
-        tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
-        processor = Qwen2VLProcessor.from_pretrained(
-            tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
-            image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
-        )
-        with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r") as f:
-            vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
-        with open(QWEN_IMAGE_CONFIG_FILE, "r") as f:
-            text_config = Qwen2_5_VLConfig(**json.load(f))
-        encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
-            state_dicts.encoder,
-            vision_config=vision_config,
-            config=text_config,
-            device=init_device,
-            dtype=config.encoder_dtype,
-        )
-        with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r") as f:
+        tokenizer, processor, encoder = None, None, None
+        if config.load_encoder:
+            tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
+            processor = Qwen2VLProcessor.from_pretrained(
+                tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
+                image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
+            )
+            with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
+                vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
+            with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
+                text_config = Qwen2_5_VLConfig(**json.load(f))
+            encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
+                state_dicts.encoder,
+                vision_config=vision_config,
+                config=text_config,
+                device=("cpu" if config.use_fsdp else init_device),
+                dtype=config.encoder_dtype,
+            )
+
+        with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
             vae_config = json.load(f)
         vae = QwenImageVAE.from_state_dict(
             state_dicts.vae, config=vae_config, device=init_device, dtype=config.vae_dtype
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = QwenImageDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -254,8 +308,13 @@ class QwenImagePipeline(BasePipeline):
             pipe.compile()
         return pipe
 
+    def update_weights(self, state_dicts: QwenImageStateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(self.encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype)
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
     def compile(self):
-        self.dit.compile_repeated_blocks(dynamic=True)
+        self.dit.compile_repeated_blocks()
 
     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -270,6 +329,10 @@ class QwenImagePipeline(BasePipeline):
 
     def unload_loras(self):
         self.dit.unload_loras()
+        self.noise_scheduler.restore_config()
+
+    def apply_scheduler_config(self, scheduler_config: Dict):
+        self.noise_scheduler.update_config(scheduler_config)
 
     def prepare_latents(
         self,
@@ -307,32 +370,43 @@ class QwenImagePipeline(BasePipeline):
         input_ids, attention_mask = outputs["input_ids"].to(self.device), outputs["attention_mask"].to(self.device)
         outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         hidden_states = outputs["hidden_states"]
-        prompt_embeds = hidden_states[:, drop_idx:]
-        prompt_embeds_mask = attention_mask[:, drop_idx:]
-        seq_len = prompt_embeds.shape[1]
+        prompt_emb = hidden_states[:, drop_idx:]
+        prompt_emb_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_emb.shape[1]
 
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_emb = prompt_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_emb = prompt_emb.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_emb_mask = prompt_emb_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_emb_mask = prompt_emb_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        return prompt_embeds, prompt_embeds_mask
+        return prompt_emb, prompt_emb_mask
 
     def encode_prompt_with_image(
         self,
         prompt: Union[str, List[str]],
-        image: torch.Tensor,
+        vae_image: List[torch.Tensor],
+        condition_image: List[torch.Tensor],  # edit plus
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 1024,
+        is_edit_plus: bool = True,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
 
         batch_size = len(prompt)
         template = self.edit_prompt_template_encode
         drop_idx = self.edit_prompt_template_encode_start_idx
-        texts = [template.format(txt) for txt in prompt]
+        if not is_edit_plus:
+            template = self.edit_prompt_template_encode
+            texts = [template.format(txt) for txt in prompt]
+            image = vae_image
+        else:
+            template = self.edit_plus_prompt_template_encode
+            img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+            img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(condition_image))])
+            texts = [template.format(img_prompt + e) for e in prompt]
+            image = condition_image
 
         model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx)
         input_ids, attention_mask, pixel_values, image_grid_thw = (
@@ -348,18 +422,18 @@ class QwenImagePipeline(BasePipeline):
             image_grid_thw=image_grid_thw,
         )
         hidden_states = outputs["hidden_states"]
-        prompt_embeds = hidden_states[:, drop_idx:]
-        prompt_embeds_mask = attention_mask[:, drop_idx:]
-        seq_len = prompt_embeds.shape[1]
+        prompt_emb = hidden_states[:, drop_idx:]
+        prompt_emb_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_emb.shape[1]
 
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_emb = prompt_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_emb = prompt_emb.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+        prompt_emb_mask = prompt_emb_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_emb_mask = prompt_emb_mask.view(batch_size * num_images_per_prompt, seq_len)
 
-        return prompt_embeds, prompt_embeds_mask
+        return prompt_emb, prompt_emb_mask
 
     def predict_noise_with_cfg(
         self,
@@ -368,9 +442,17 @@ class QwenImagePipeline(BasePipeline):
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         negative_prompt_emb: torch.Tensor,
-        prompt_embeds_mask: torch.Tensor,
-        negative_prompt_embeds_mask: torch.Tensor,
-        cfg_scale: float,
+        prompt_emb_mask: torch.Tensor,
+        negative_prompt_emb_mask: torch.Tensor,
+        # in_context
+        context_latents: torch.Tensor = None,
+        # eligen
+        entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        negative_entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        negative_entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
+        cfg_scale: float = 1.0,
         batch_cfg: bool = False,
     ):
         if cfg_scale <= 1.0 or negative_prompt_emb is None:
@@ -379,7 +461,11 @@ class QwenImagePipeline(BasePipeline):
                 image_latents,
                 timestep,
                 prompt_emb,
-                prompt_embeds_mask,
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
         if not batch_cfg:
             # cfg by predict noise one by one
@@ -389,14 +475,22 @@ class QwenImagePipeline(BasePipeline):
                 image_latents,
                 timestep,
                 prompt_emb,
-                prompt_embeds_mask,
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
             negative_noise_pred = self.predict_noise(
                 latents,
                 image_latents,
                 timestep,
                 negative_prompt_emb,
-                negative_prompt_embeds_mask,
+                negative_prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=negative_entity_prompt_embs,
+                entity_prompt_emb_masks=negative_entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
             comb_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
             cond_norm = torch.norm(self.dit.patchify(positive_noise_pred), dim=-1, keepdim=True)
@@ -406,18 +500,32 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_embeds_mask = torch.cat([prompt_embeds_mask, negative_prompt_embeds_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
+            if entity_prompt_embs is not None:
+                entity_prompt_embs = [
+                    torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
+                ]
+                entity_prompt_emb_masks = [
+                    torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_emb_masks, negative_entity_prompt_emb_masks)
+                ]
+                entity_masks = [torch.cat([mask, mask], dim=0) for mask in entity_masks]
             latents = torch.cat([latents, latents], dim=0)
             if image_latents is not None:
-                image_latents = torch.cat([image_latents, image_latents], dim=0)
+                image_latents = [torch.cat([image_latent, image_latent], dim=0) for image_latent in image_latents]
+            if context_latents is not None:
+                context_latents = torch.cat([context_latents, context_latents], dim=0)
             timestep = torch.cat([timestep, timestep], dim=0)
             noise_pred = self.predict_noise(
                 latents,
                 image_latents,
                 timestep,
                 prompt_emb,
-                prompt_embeds_mask,
+                prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                entity_masks=entity_masks,
             )
            positive_noise_pred, negative_noise_pred = noise_pred[:bs], noise_pred[bs:]
            comb_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)
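In the batched-CFG branch above, the positive and negative prompt embeddings may have different sequence lengths, which is why a plain `torch.cat` is no longer sufficient. The snippet below is a minimal illustration with hypothetical shapes, not code from the package.

```python
import torch

# Hypothetical shapes: the two prompt embeddings differ along the sequence dim,
# so concatenating them along the batch dimension alone raises an error.
bs = 1
prompt_emb = torch.randn(bs, 77, 64)           # (b, s1, d)
negative_prompt_emb = torch.randn(bs, 60, 64)  # (b, s2, d), shorter sequence

try:
    torch.cat([prompt_emb, negative_prompt_emb], dim=0)
except RuntimeError as e:
    print("plain cat fails:", e)

# pad_and_concat (defined in diffsynth_engine/pipelines/utils.py, shown later in
# this diff) zero-pads the shorter sequence to length 77 and returns a (2, 77, 64)
# batch; the model output is then split back with noise_pred[:bs] / noise_pred[bs:].
```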
@@ -432,15 +540,27 @@ class QwenImagePipeline(BasePipeline):
         image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
-        prompt_embeds_mask: torch.Tensor,
+        prompt_emb_mask: torch.Tensor,
+        # in_context
+        context_latents: torch.Tensor = None,
+        # eligen
+        entity_prompt_embs: Optional[List[torch.Tensor]] = None,
+        entity_prompt_emb_masks: Optional[List[torch.Tensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
-            text=prompt_emb,
             timestep=timestep,
-            txt_seq_lens=prompt_embeds_mask.sum(dim=1),
+            text=prompt_emb,
+            text_seq_lens=prompt_emb_mask.sum(dim=1),
+            context_latents=context_latents,
+            entity_text=entity_prompt_embs,
+            entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
+            entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -457,6 +577,20 @@ class QwenImagePipeline(BasePipeline):
         image_latents = image_latents.squeeze(2).to(device=self.device)
         return image_latents
 
+    def prepare_eligen(self, entity_prompts, entity_masks, width, height):
+        entity_masks = [mask.resize((width // 8, height // 8), resample=Image.NEAREST) for mask in entity_masks]
+        entity_masks = [self.preprocess_image(mask).mean(dim=1, keepdim=True) > 0 for mask in entity_masks]
+        entity_masks = [mask.to(device=self.device, dtype=self.dtype) for mask in entity_masks]
+        prompt_embs, prompt_emb_masks = [], []
+        negative_prompt_embs, negative_prompt_emb_masks = [], []
+        for entity_prompt in entity_prompts:
+            prompt_emb, prompt_emb_mask = self.encode_prompt(entity_prompt, 1, 512)
+            prompt_embs.append(prompt_emb)
+            prompt_emb_masks.append(prompt_emb_mask)
+            negative_prompt_embs.append(torch.zeros_like(prompt_emb))
+            negative_prompt_emb_masks.append(torch.zeros_like(prompt_emb_mask))
+        return prompt_embs, prompt_emb_masks, negative_prompt_embs, negative_prompt_emb_masks, entity_masks
+
     def calculate_dimensions(self, target_area, ratio):
         width = math.sqrt(target_area * ratio)
         height = width / ratio
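From the visible part of calculate_dimensions, width and height are chosen so that width * height equals target_area and width / height equals ratio; any rounding the full function performs is outside this hunk. A quick numeric check with assumed inputs:

```python
import math

# Worked example for the visible formula in calculate_dimensions (illustrative
# inputs; rounding/snapping in the rest of the function is not shown in the hunk).
target_area = 1024 * 1024
ratio = 16 / 9

width = math.sqrt(target_area * ratio)   # ~1365.33
height = width / ratio                   # ~768.0

print(round(width * height))       # 1048576, i.e. target_area
print(round(width / height, 4))    # 1.7778, i.e. ratio
```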
@@ -469,21 +603,51 @@ class QwenImagePipeline(BasePipeline):
         self,
         prompt: str,
         negative_prompt: str = "",
-        input_image: Image.Image | None = None,  # use for img2img
+        # single image for edit, list for edit plus(QwenImageEdit2509)
+        input_image: List[Image.Image] | Image.Image | None = None,
         cfg_scale: float = 4.0,  # true cfg
-        height: int = 1328,
-        width: int = 1328,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 50,
         seed: int | None = None,
+        controlnet_params: List[QwenImageControlNetParams] | QwenImageControlNetParams = [],
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
+        # eligen
+        entity_prompts: Optional[List[str]] = None,
+        entity_masks: Optional[List[Image.Image]] = None,
     ):
-        if input_image is not None:
-            width, height = input_image.size
-            width, height = self.calculate_dimensions(1024 * 1024, width / height)
-            input_image = input_image.resize((width, height), Image.LANCZOS)
+        assert (height is None) == (width is None), "height and width should be set together"
+        is_edit_plus = isinstance(input_image, list)
 
+        if input_image is not None:
+            if not isinstance(input_image, list):
+                input_image = [input_image]
+            condition_images = []
+            vae_images = []
+            for img in input_image:
+                img_width, img_height = img.size
+                condition_width, condition_height = self.calculate_dimensions(384 * 384, img_width / img_height)
+                vae_width, vae_height = self.calculate_dimensions(1024 * 1024, img_width / img_height)
+                condition_images.append(img.resize((condition_width, condition_height), Image.LANCZOS))
+                vae_images.append(img.resize((vae_width, vae_height), Image.LANCZOS))
+            if width is None and height is None:
+                width, height = vae_images[-1].size
+
+        if width is None and height is None:
+            width, height = 1328, 1328
         self.validate_image_size(height, width, minimum=64, multiple_of=16)
 
+        if not isinstance(controlnet_params, list):
+            controlnet_params = [controlnet_params]
+
+        context_latents = None
+        for param in controlnet_params:
+            self.load_lora(param.model, param.scale, fused=False, save_original_weight=False)
+            if param.control_type == QwenImageControlType.in_context:
+                width, height = param.image.size
+                self.validate_image_size(height, width, minimum=64, multiple_of=16)
+                context_latents = self.prepare_image_latents(param.image.resize((width, height), Image.LANCZOS))
+
         noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
             device=self.device
         )
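A usage sketch for the `__call__` signature shown above. It assumes `pipe` is an already constructed QwenImagePipeline and that "reference.png" exists; both are placeholders, not part of this diff.

```python
from PIL import Image

# Assumes `pipe` is a constructed QwenImagePipeline; "reference.png" is a placeholder.
ref = Image.open("reference.png")

# Text-to-image: height/width may be omitted and default to 1328x1328.
image = pipe("a cat wearing a spacesuit", num_inference_steps=50, seed=42)

# A single PIL image selects the edit path; a list selects the edit-plus
# (QwenImageEdit2509) path, where each image is resized to an aspect-preserving
# ~384*384-area copy for the VL encoder and a ~1024*1024-area copy for the VAE.
edited = pipe("make the sky purple", input_image=ref, seed=42)
combined = pipe("place both subjects in one scene", input_image=[ref, ref], seed=42)
```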
@@ -496,39 +660,60 @@ class QwenImagePipeline(BasePipeline):
 
         self.load_models_to_device(["vae"])
         if input_image:
-            image_latents = self.prepare_image_latents(input_image)
+            image_latents = [self.prepare_image_latents(img) for img in vae_images]
         else:
             image_latents = None
 
         self.load_models_to_device(["encoder"])
         if image_latents is not None:
-            prompt_embeds, prompt_embeds_mask = self.encode_prompt_with_image(prompt, input_image, 1, 4096)
+            prompt_emb, prompt_emb_mask = self.encode_prompt_with_image(
+                prompt, vae_images, condition_images, 1, 4096, is_edit_plus
+            )
             if cfg_scale > 1.0 and negative_prompt != "":
-                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_with_image(
-                    negative_prompt, input_image, 1, 4096
+                negative_prompt_emb, negative_prompt_emb_mask = self.encode_prompt_with_image(
+                    negative_prompt, vae_images, condition_images, 1, 4096, is_edit_plus
                 )
             else:
-                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
+                negative_prompt_emb, negative_prompt_emb_mask = None, None
         else:
-            prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096)
+            prompt_emb, prompt_emb_mask = self.encode_prompt(prompt, 1, 4096)
             if cfg_scale > 1.0 and negative_prompt != "":
-                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
+                negative_prompt_emb, negative_prompt_emb_mask = self.encode_prompt(negative_prompt, 1, 4096)
             else:
-                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
+                negative_prompt_emb, negative_prompt_emb_mask = None, None
+
+        entity_prompt_embs, entity_prompt_emb_masks = None, None
+        negative_entity_prompt_embs, negative_entity_prompt_emb_masks = None, None
+        if entity_prompts is not None and entity_masks is not None:
+            assert len(entity_prompts) == len(entity_masks), "entity_prompts and entity_masks must have the same length"
+            (
+                entity_prompt_embs,
+                entity_prompt_emb_masks,
+                negative_entity_prompt_embs,
+                negative_entity_prompt_emb_masks,
+                entity_masks,
+            ) = self.prepare_eligen(entity_prompts, entity_masks, width, height)
+
         self.model_lifecycle_finish(["encoder"])
 
+        self.load_models_to_device(["dit"])
         hide_progress = dist.is_initialized() and dist.get_rank() != 0
-
         for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
             timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
             noise_pred = self.predict_noise_with_cfg(
                 latents=latents,
                 image_latents=image_latents,
                 timestep=timestep,
-                prompt_emb=prompt_embeds,
-                negative_prompt_emb=negative_prompt_embeds,
-                prompt_embeds_mask=prompt_embeds_mask,
-                negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+                prompt_emb=prompt_emb,
+                negative_prompt_emb=negative_prompt_emb,
+                prompt_emb_mask=prompt_emb_mask,
+                negative_prompt_emb_mask=negative_prompt_emb_mask,
+                context_latents=context_latents,
+                entity_prompt_embs=entity_prompt_embs,
+                entity_prompt_emb_masks=entity_prompt_emb_masks,
+                negative_entity_prompt_embs=negative_entity_prompt_embs,
+                negative_entity_prompt_emb_masks=negative_entity_prompt_emb_masks,
+                entity_masks=entity_masks,
                 cfg_scale=cfg_scale,
                 batch_cfg=self.config.batch_cfg,
             )
diffsynth_engine/pipelines/sdxl_image.py

@@ -181,7 +181,7 @@ class SDXLImagePipeline(BasePipeline):
 
     @classmethod
     def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
-        init_device = "cpu" if config.offload_mode else config.device
+        init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
diffsynth_engine/pipelines/utils.py

@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
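A quick usage check of pad_and_concat as it is applied in the batched-CFG path earlier in this diff. The shapes are illustrative; the snippet assumes a 0.6.x wheel is installed so the new helper is importable.

```python
import torch

from diffsynth_engine.pipelines.utils import pad_and_concat

# Illustrative positive/negative embedding pair with different sequence lengths,
# plus their attention masks, mirroring the batched-CFG usage.
prompt_emb = torch.randn(1, 77, 8)
negative_prompt_emb = torch.randn(1, 60, 8)
prompt_emb_mask = torch.ones(1, 77, dtype=torch.long)
negative_prompt_emb_mask = torch.ones(1, 60, dtype=torch.long)

emb = pad_and_concat(prompt_emb, negative_prompt_emb)             # (2, 77, 8)
mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)  # (2, 77)

print(emb.shape, mask.shape)
print(mask[1].sum())  # 60: the zero-padded positions stay masked out
```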