diffsynth-engine 0.6.1.dev14__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. diffsynth_engine/__init__.py +6 -2
  2. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  3. diffsynth_engine/configs/__init__.py +10 -6
  4. diffsynth_engine/configs/pipeline.py +17 -10
  5. diffsynth_engine/models/base.py +1 -1
  6. diffsynth_engine/models/basic/attention.py +59 -20
  7. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  8. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  9. diffsynth_engine/models/flux/flux_controlnet.py +7 -19
  10. diffsynth_engine/models/flux/flux_dit.py +27 -38
  11. diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
  12. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  13. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  14. diffsynth_engine/models/qwen_image/qwen_image_dit.py +28 -34
  15. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  16. diffsynth_engine/models/wan/wan_audio_encoder.py +0 -1
  17. diffsynth_engine/models/wan/wan_dit.py +64 -27
  18. diffsynth_engine/pipelines/base.py +36 -4
  19. diffsynth_engine/pipelines/flux_image.py +19 -17
  20. diffsynth_engine/pipelines/qwen_image.py +45 -36
  21. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  22. diffsynth_engine/pipelines/utils.py +52 -0
  23. diffsynth_engine/pipelines/wan_s2v.py +4 -9
  24. diffsynth_engine/pipelines/wan_video.py +43 -19
  25. diffsynth_engine/tokenizers/base.py +6 -0
  26. diffsynth_engine/tokenizers/qwen2.py +12 -4
  27. diffsynth_engine/utils/constants.py +13 -12
  28. diffsynth_engine/utils/flag.py +6 -0
  29. diffsynth_engine/utils/parallel.py +62 -29
  30. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  31. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +45 -43
  32. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  33. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  34. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  35. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  36. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  37. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  38. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  39. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  40. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  41. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  42. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  43. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  44. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  45. {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py

@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2

+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+

 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None

-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)

-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)

@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)

-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention  mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
@@ -249,7 +261,26 @@ class Head(nn.Module):


 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict


@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):

         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
             (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
             x = self.unpatchify(x, (f, h, w))
@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model

@@ -499,8 +539,5 @@ class WanDiT(PreTrainedModel):
         for block in self.blocks:
             block.compile(*args, **kwargs)

-        for block in self.single_blocks:
-            block.compile(*args, **kwargs)
-
-    def get_fsdp_modules(self):
-        return ["blocks"]
+    def get_fsdp_module_cls(self):
+        return {DiTBlock}
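Net effect of the wan_dit.py hunks: the removed attn_kwargs constructor argument becomes a per-call dictionary that WanDiT.forward threads through every DiTBlock into SelfAttention and CrossAttention (cross-attention downgrades "vsa" to "sdpa", since sparse video attention only applies to the spatio-temporal self-attention). A rough call sketch follows; the positional names x and context are taken from the block-level call above, not from the full forward signature, which this diff does not show, so treat everything except timestep and attn_kwargs as an assumption:

    # Sketch only, assuming a constructed WanDiT instance `dit` and prepared inputs.
    # Only timestep and attn_kwargs are keywords confirmed by the new forward signature.
    attn_kwargs = {"attn_impl": "sdpa"}  # or the sparge / VSA dict a pipeline builds per step
    noise_pred = dit(x, context, timestep=timestep, attn_kwargs=attn_kwargs)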
diffsynth_engine/pipelines/base.py

@@ -2,10 +2,18 @@ import os
 import torch
 import numpy as np
 from einops import rearrange
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image

-from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ class BasePipeline:
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,7 +57,7 @@ class BasePipeline:
         raise NotImplementedError()

     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()

     def update_weights(self, state_dicts: BaseStateDicts) -> None:
@@ -70,7 +79,11 @@ class BasePipeline:
         lora_list: List[Tuple[str, Union[float, LoraConfig]]],
         fused: bool = True,
         save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
     ):
+        if not lora_converter:
+            lora_converter = self.lora_converter
+
         for lora_path, lora_item in lora_list:
             if isinstance(lora_item, float):
                 lora_scale = lora_item
@@ -86,7 +99,7 @@ class BasePipeline:
                 self.apply_scheduler_config(scheduler_config)
                 logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")

-            lora_state_dict = self.lora_converter.convert(state_dict)
+            lora_state_dict = lora_converter.convert(state_dict)
             for model_name, state_dict in lora_state_dict.items():
                 model = getattr(self, model_name)
                 lora_args = []
@@ -256,6 +269,25 @@ class BasePipeline:
         )
         return init_latents, latents, sigmas, timesteps

+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
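The new BasePipeline.get_attn_kwargs replaces the per-model attn_kwargs constructor arguments removed elsewhere in this diff. A hedged usage sketch of how a pipeline config would select the VSA path: only the field names dit_attn_impl, attn_params, and sparsity appear in the hunk above; the constructor call below is an assumption, and the sparsity value is illustrative.

    from diffsynth_engine.configs import AttnImpl, VideoSparseAttentionParams

    # Assumed configuration style; the real config classes may take these fields differently.
    config.dit_attn_impl = AttnImpl.VSA
    config.attn_params = VideoSparseAttentionParams(sparsity=0.9)

    # During denoising the pipeline then calls, once per step:
    #   attn_kwargs = self.get_attn_kwargs(latents)
    # and forwards the resulting dict into the DiT, e.g. dit(..., attn_kwargs=attn_kwargs).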
diffsynth_engine/pipelines/flux_image.py

@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import FluxPipelineConfig, FluxStateDicts, ControlType, ControlNetParams
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -143,7 +148,7 @@ class FluxLoRAConverter(LoRAStateDictConverter):
             layer_id, layer_type = name.split("_", 1)
             layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
             rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
-
+
             lora_args = {}
             lora_args["alpha"] = param
             lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
@@ -507,29 +512,20 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = FluxDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -573,7 +569,7 @@ class FluxImagePipeline(BasePipeline):
         self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)

     def compile(self):
-        self.dit.compile_repeated_blocks(dynamic=True)
+        self.dit.compile_repeated_blocks()

     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])

+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)
@@ -830,7 +828,7 @@ class FluxImagePipeline(BasePipeline):
             masked_image = image.clone()
             masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
             latent = self.encode_image(masked_image)
-            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
             mask = 1 - mask
             latent = torch.cat([latent, mask], dim=1)
         elif self.config.control_type == ControlType.bfl_fill:
@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
             if self.offload_mode is not None:
                 empty_cache()
                 param.model.to(self.device)
+
+            attn_kwargs = self.get_attn_kwargs(latents)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,
@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
                 image_ids=image_ids,
                 text_ids=text_ids,
                 guidance=guidance,
+                attn_kwargs=attn_kwargs,
             )
             if self.offload_mode is not None:
                 param.model.to("cpu")
@@ -983,8 +984,9 @@ class FluxImagePipeline(BasePipeline):
         elif self.ip_adapter is not None:
             image_emb = self.ip_adapter.encode_image(ref_image)
         elif self.redux is not None:
-            image_prompt_embeds = self.redux(ref_image)
-            positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1)
+            ref_prompt_embeds = self.redux(ref_image)
+            flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1))
+            positive_prompt_emb = torch.cat([positive_prompt_emb, flattened_ref_emb], dim=1)

         # Extra input
         image_ids, text_ids, guidance = self.prepare_extra_input(
diffsynth_engine/pipelines/qwen_image.py

@@ -24,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue

@@ -147,9 +147,18 @@ class QwenImagePipeline(BasePipeline):
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
-        self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )

         self.edit_prompt_template_encode_start_idx = 64

@@ -185,6 +194,7 @@ class QwenImagePipeline(BasePipeline):
         logger.info(f"loading state dict from {config.vae_path} ...")
         vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

+        encoder_state_dict = None
         if config.encoder_path is None:
             config.encoder_path = fetch_model(
                 "MusePublic/Qwen-image",
@@ -196,8 +206,11 @@ class QwenImagePipeline(BasePipeline):
                     "text_encoder/model-00004-of-00004.safetensors",
                 ],
             )
-        logger.info(f"loading state dict from {config.encoder_path} ...")
-        encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
+        if config.load_encoder:
+            logger.info(f"loading state dict from {config.encoder_path} ...")
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )

         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
@@ -224,22 +237,25 @@ class QwenImagePipeline(BasePipeline):
     @classmethod
     def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
         init_device = "cpu" if config.offload_mode is not None else config.device
-        tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
-        processor = Qwen2VLProcessor.from_pretrained(
-            tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
-            image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
-        )
-        with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
-            vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
-        with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
-            text_config = Qwen2_5_VLConfig(**json.load(f))
-        encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
-            state_dicts.encoder,
-            vision_config=vision_config,
-            config=text_config,
-            device=init_device,
-            dtype=config.encoder_dtype,
-        )
+        tokenizer, processor, encoder = None, None, None
+        if config.load_encoder:
+            tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
+            processor = Qwen2VLProcessor.from_pretrained(
+                tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
+                image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
+            )
+            with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
+                vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
+            with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
+                text_config = Qwen2_5_VLConfig(**json.load(f))
+            encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
+                state_dicts.encoder,
+                vision_config=vision_config,
+                config=text_config,
+                device=("cpu" if config.use_fsdp else init_device),
+                dtype=config.encoder_dtype,
+            )
+
         with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
             vae_config = json.load(f)
         vae = QwenImageVAE.from_state_dict(
@@ -247,27 +263,18 @@ class QwenImagePipeline(BasePipeline):
         )

         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = QwenImageDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -307,7 +314,7 @@ class QwenImagePipeline(BasePipeline):
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)

     def compile(self):
-        self.dit.compile_repeated_blocks(dynamic=True)
+        self.dit.compile_repeated_blocks()

     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -493,8 +500,8 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_emb_mask = torch.cat([prompt_emb_mask, negative_prompt_emb_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
@@ -542,6 +549,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
@@ -552,6 +560,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred

diffsynth_engine/pipelines/sdxl_image.py

@@ -181,7 +181,7 @@ class SDXLImagePipeline(BasePipeline):

     @classmethod
     def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
-        init_device = "cpu" if config.offload_mode else config.device
+        init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
diffsynth_engine/pipelines/utils.py

@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)
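The new pad_and_concat is used by qwen_image.py above to batch a positive and a negative prompt embedding whose sequence lengths differ. A small self-contained check of the shape behavior, with tensor sizes chosen only for illustration:

    import torch
    from diffsynth_engine.pipelines.utils import pad_and_concat

    pos = torch.randn(1, 77, 3584)   # positive prompt embedding (sizes illustrative)
    neg = torch.randn(1, 50, 3584)   # shorter negative prompt embedding
    out = pad_and_concat(pos, neg)   # neg is zero-padded to length 77 along dim 1,
                                     # then the two are concatenated along dim 0
    assert out.shape == (2, 77, 3584)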