diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/qwen_image/qwen_image_dit.py
@@ -1,15 +1,20 @@
  import torch
  import torch.nn as nn
- from typing import Any, Dict, Tuple, Union, Optional
+ from typing import Any, Dict, List, Tuple, Union, Optional
  from einops import rearrange

  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
  from diffsynth_engine.models.basic import attention as attention_ops
  from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
- from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
+ from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
  from diffsynth_engine.utils.gguf import gguf_inference
  from diffsynth_engine.utils.fp8_linear import fp8_inference
- from diffsynth_engine.utils.parallel import cfg_parallel, cfg_parallel_unshard
+ from diffsynth_engine.utils.parallel import (
+     cfg_parallel,
+     cfg_parallel_unshard,
+     sequence_parallel,
+     sequence_parallel_unshard,
+ )


  class QwenImageDiTStateDictConverter(StateDictConverter):
@@ -139,7 +144,7 @@ class QwenFeedForward(nn.Module):
          super().__init__()
          inner_dim = int(dim * 4)
          self.net = nn.ModuleList([])
-         self.net.append(ApproximateGELU(dim, inner_dim, device=device, dtype=dtype))
+         self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
          self.net.append(nn.Dropout(dropout))
          self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -150,8 +155,8 @@ class QwenFeedForward(nn.Module):


  def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-     x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-     x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+     x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+     x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
      return x_out.type_as(x)
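Note on the two changed lines above: the rotation is a complex multiply, and the new `freqs_cis.unsqueeze(1)` broadcasts a per-position frequency table over the head dimension now that q/k are kept in (batch, seq, heads, dim) order. A minimal standalone sketch of the shape flow; the tensor sizes and the assumed (seq, dim/2) frequency shape are illustrative, not the model's real values:

    import torch

    b, s, h, d = 2, 16, 4, 64
    x = torch.randn(b, s, h, d)

    # one complex rotation per (position, channel pair); shape assumed (s, d/2)
    angles = torch.randn(s, d // 2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)          # complex64, (s, d/2)

    x_complex = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))   # (b, s, h, d/2)
    # unsqueeze(1) -> (s, 1, d/2), which broadcasts over batch and heads
    x_out = torch.view_as_real(x_complex * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d)
    assert x_out.shape == x.shape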


@@ -162,7 +167,6 @@ class QwenDoubleStreamAttention(nn.Module):
          dim_b,
          num_heads,
          head_dim,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
      ):
@@ -184,44 +188,42 @@ class QwenDoubleStreamAttention(nn.Module):

          self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
          self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-         self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

      def forward(
          self,
          image: torch.FloatTensor,
          text: torch.FloatTensor,
-         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attn_mask: Optional[torch.Tensor] = None,
+         attn_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
          img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
          txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-         img_q = rearrange(img_q, "b s (h d) -> b h s d", h=self.num_heads)
-         img_k = rearrange(img_k, "b s (h d) -> b h s d", h=self.num_heads)
-         img_v = rearrange(img_v, "b s (h d) -> b h s d", h=self.num_heads)
+         img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+         img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+         img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-         txt_q = rearrange(txt_q, "b s (h d) -> b h s d", h=self.num_heads)
-         txt_k = rearrange(txt_k, "b s (h d) -> b h s d", h=self.num_heads)
-         txt_v = rearrange(txt_v, "b s (h d) -> b h s d", h=self.num_heads)
+         txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+         txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+         txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

          img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
          txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)

-         if image_rotary_emb is not None:
-             img_freqs, txt_freqs = image_rotary_emb
+         if rotary_emb is not None:
+             img_freqs, txt_freqs = rotary_emb
              img_q = apply_rotary_emb_qwen(img_q, img_freqs)
              img_k = apply_rotary_emb_qwen(img_k, img_freqs)
              txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
              txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-         joint_q = torch.cat([txt_q, img_q], dim=2)
-         joint_k = torch.cat([txt_k, img_k], dim=2)
-         joint_v = torch.cat([txt_v, img_v], dim=2)
+         joint_q = torch.cat([txt_q, img_q], dim=1)
+         joint_k = torch.cat([txt_k, img_k], dim=1)
+         joint_v = torch.cat([txt_v, img_v], dim=1)

-         joint_q = joint_q.transpose(1, 2)
-         joint_k = joint_k.transpose(1, 2)
-         joint_v = joint_v.transpose(1, 2)
-
-         joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, **self.attn_kwargs)
+         attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+         joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)

          joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
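This hunk switches q/k/v from a head-first (b, h, s, d) layout to a sequence-first (b, s, h, d) layout, concatenates text and image tokens along the sequence axis (dim=1 instead of dim=2), and threads an optional additive attn_mask plus per-call attn_kwargs into the attention op. A hedged, self-contained sketch of the same pattern using plain PyTorch SDPA; attention_ops.attention is the engine's own wrapper, and the only assumption here is that it consumes (b, s, h, d) inputs:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, s_txt, s_img, h, d = 1, 8, 32, 4, 64
    txt_q, img_q = torch.randn(b, s_txt, h, d), torch.randn(b, s_img, h, d)
    txt_k, img_k = torch.randn_like(txt_q), torch.randn_like(img_q)
    txt_v, img_v = torch.randn_like(txt_q), torch.randn_like(img_q)

    joint_q = torch.cat([txt_q, img_q], dim=1)   # (b, s_txt + s_img, h, d)
    joint_k = torch.cat([txt_k, img_k], dim=1)
    joint_v = torch.cat([txt_v, img_v], dim=1)

    # additive mask broadcast over heads: 0 keeps a pair, -inf blocks it
    attn_mask = torch.zeros(b, 1, s_txt + s_img, s_txt + s_img)

    out = F.scaled_dot_product_attention(
        joint_q.transpose(1, 2), joint_k.transpose(1, 2), joint_v.transpose(1, 2),
        attn_mask=attn_mask,
    ).transpose(1, 2)                            # back to (b, s, h, d)
    out = rearrange(out, "b s h d -> b s (h d)")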

@@ -241,7 +243,6 @@ class QwenImageTransformerBlock(nn.Module):
          num_attention_heads: int,
          attention_head_dim: int,
          eps: float = 1e-6,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
      ):
@@ -261,7 +262,6 @@ class QwenImageTransformerBlock(nn.Module):
              dim_b=dim,
              num_heads=num_attention_heads,
              head_dim=attention_head_dim,
-             attn_kwargs=attn_kwargs,
              device=device,
              dtype=dtype,
          )
@@ -285,7 +285,9 @@ class QwenImageTransformerBlock(nn.Module):
          image: torch.Tensor,
          text: torch.Tensor,
          temb: torch.Tensor,
-         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attn_mask: Optional[torch.Tensor] = None,
+         attn_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
          txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -299,7 +301,9 @@ class QwenImageTransformerBlock(nn.Module):
          img_attn_out, txt_attn_out = self.attn(
              image=img_modulated,
              text=txt_modulated,
-             image_rotary_emb=image_rotary_emb,
+             rotary_emb=rotary_emb,
+             attn_mask=attn_mask,
+             attn_kwargs=attn_kwargs,
          )

          image = image + img_gate * img_attn_out
@@ -327,7 +331,6 @@ class QwenImageDiT(PreTrainedModel):
      def __init__(
          self,
          num_layers: int = 60,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
      ):
@@ -348,7 +351,6 @@ class QwenImageDiT(PreTrainedModel):
                  dim=3072,
                  num_attention_heads=24,
                  attention_head_dim=128,
-                 attn_kwargs=attn_kwargs,
                  device=device,
                  dtype=dtype,
              )
@@ -368,13 +370,75 @@ class QwenImageDiT(PreTrainedModel):
          )
          return hidden_states

+     def process_entity_masks(
+         self,
+         text: torch.Tensor,
+         text_seq_lens: torch.LongTensor,
+         rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+         video_fhw: List[Tuple[int, int, int]],
+         entity_text: List[torch.Tensor],
+         entity_seq_lens: List[torch.LongTensor],
+         entity_masks: List[torch.Tensor],
+         device: str,
+         dtype: torch.dtype,
+     ):
+         entity_seq_lens = [seq_lens.max().item() for seq_lens in entity_seq_lens]
+         text_seq_lens = entity_seq_lens + [text_seq_lens.max().item()]
+         entity_text = [
+             self.txt_in(self.txt_norm(text[:, :seq_len])) for text, seq_len in zip(entity_text, entity_seq_lens)
+         ]
+         text = torch.cat(entity_text + [text], dim=1)
+
+         entity_txt_freqs = [self.pos_embed(video_fhw, seq_len, device)[1] for seq_len in entity_seq_lens]
+         img_freqs, txt_freqs = rotary_emb
+         txt_freqs = torch.cat(entity_txt_freqs + [txt_freqs], dim=0)
+         rotary_emb = (img_freqs, txt_freqs)
+
+         global_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype)
+         patched_masks = [self.patchify(mask) for mask in entity_masks + [global_mask]]
+         batch_size, image_seq_len = patched_masks[0].shape[:2]
+         total_seq_len = sum(text_seq_lens) + image_seq_len
+         attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), device=device, dtype=torch.bool)
+
+         # text-image attention mask
+         img_start, img_end = sum(text_seq_lens), total_seq_len
+         cumsum = [0]
+         for seq_len in text_seq_lens:
+             cumsum.append(cumsum[-1] + seq_len)
+         for i, patched_mask in enumerate(patched_masks):
+             txt_start, txt_end = cumsum[i], cumsum[i + 1]
+             mask = torch.sum(patched_mask, dim=-1) > 0
+             mask = mask.unsqueeze(1).repeat(1, text_seq_lens[i], 1)
+             # text-to-image attention
+             attention_mask[:, txt_start:txt_end, img_start:img_end] = mask
+             # image-to-text attention
+             attention_mask[:, img_start:img_end, txt_start:txt_end] = mask.transpose(1, 2)
+         # entity text tokens should not attend to each other
+         for i in range(len(text_seq_lens)):
+             for j in range(len(text_seq_lens)):
+                 if i == j:
+                     continue
+                 i_start, i_end = cumsum[i], cumsum[i + 1]
+                 j_start, j_end = cumsum[j], cumsum[j + 1]
+                 attention_mask[:, i_start:i_end, j_start:j_end] = False
+
+         attn_mask = torch.zeros_like(attention_mask, device=device, dtype=dtype)
+         attn_mask[~attention_mask] = -torch.inf
+         attn_mask = attn_mask.unsqueeze(1)
+         return text, rotary_emb, attn_mask
+
      def forward(
          self,
          image: torch.Tensor,
          edit: torch.Tensor = None,
-         text: torch.Tensor = None,
          timestep: torch.LongTensor = None,
-         txt_seq_lens: torch.LongTensor = None,
+         text: torch.Tensor = None,
+         text_seq_lens: torch.LongTensor = None,
+         context_latents: Optional[torch.Tensor] = None,
+         entity_text: Optional[List[torch.Tensor]] = None,
+         entity_seq_lens: Optional[List[torch.LongTensor]] = None,
+         entity_masks: Optional[List[torch.Tensor]] = None,
+         attn_kwargs: Optional[Dict[str, Any]] = None,
      ):
          h, w = image.shape[-2:]
          fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
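process_entity_masks above builds a boolean allow/deny matrix over the joint text+image sequence and then converts it to the additive form attention kernels expect (0 keeps a pair, -inf removes it). A tiny standalone sketch of that conversion; the sequence lengths are made up:

    import torch

    total_seq_len = 6
    allowed = torch.ones(1, total_seq_len, total_seq_len, dtype=torch.bool)
    allowed[:, 0:2, 4:6] = False          # e.g. block one text span from some image tokens

    attn_mask = torch.zeros_like(allowed, dtype=torch.float32)
    attn_mask[~allowed] = -torch.inf      # disallowed pairs get -inf logits
    attn_mask = attn_mask.unsqueeze(1)    # add a head dim for broadcasting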
@@ -385,37 +449,72 @@ class QwenImageDiT(PreTrainedModel):
              cfg_parallel(
                  (
                      image,
-                     edit,
-                     text,
+                     *(edit if edit is not None else ()),
                      timestep,
-                     txt_seq_lens,
+                     text,
+                     text_seq_lens,
+                     *(entity_text if entity_text is not None else ()),
+                     *(entity_seq_lens if entity_seq_lens is not None else ()),
+                     *(entity_masks if entity_masks is not None else ()),
+                     context_latents,
                  ),
                  use_cfg=use_cfg,
              ),
          ):
              conditioning = self.time_text_embed(timestep, image.dtype)
              video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
-             max_length = txt_seq_lens.max().item()
+             text_seq_len = text_seq_lens.max().item()
              image = self.patchify(image)
              image_seq_len = image.shape[1]
+             if context_latents is not None:
+                 context_latents = context_latents.to(dtype=image.dtype)
+                 context_latents = self.patchify(context_latents)
+                 image = torch.cat([image, context_latents], dim=1)
+                 video_fhw += [(1, h // 2, w // 2)]
              if edit is not None:
-                 edit = edit.to(dtype=image.dtype)
-                 edit = self.patchify(edit)
-                 image = torch.cat([image, edit], dim=1)
-                 video_fhw += video_fhw
+                 for img in edit:
+                     img = img.to(dtype=image.dtype)
+                     edit_h, edit_w = img.shape[-2:]
+                     img = self.patchify(img)
+                     image = torch.cat([image, img], dim=1)
+                     video_fhw += [(1, edit_h // 2, edit_w // 2)]

-             image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)
+             rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)

              image = self.img_in(image)
-             text = self.txt_in(self.txt_norm(text[:, :max_length]))
+             text = self.txt_in(self.txt_norm(text[:, :text_seq_len]))

-             for block in self.transformer_blocks:
-                 text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
-             image = self.norm_out(image, conditioning)
-             image = self.proj_out(image)
-             if edit is not None:
-                 image = image[:, :image_seq_len]
+             attn_mask = None
+             if entity_text is not None:
+                 text, rotary_emb, attn_mask = self.process_entity_masks(
+                     text,
+                     text_seq_lens,
+                     rotary_emb,
+                     video_fhw,
+                     entity_text,
+                     entity_seq_lens,
+                     entity_masks,
+                     image.device,
+                     image.dtype,
+                 )

+             # warning: Eligen does not work with sequence parallel because long context attention does not support attention masks
+             img_freqs, txt_freqs = rotary_emb
+             with sequence_parallel((image, text, img_freqs, txt_freqs), seq_dims=(1, 1, 0, 0)):
+                 rotary_emb = (img_freqs, txt_freqs)
+                 for block in self.transformer_blocks:
+                     text, image = block(
+                         image=image,
+                         text=text,
+                         temb=conditioning,
+                         rotary_emb=rotary_emb,
+                         attn_mask=attn_mask,
+                         attn_kwargs=attn_kwargs,
+                     )
+                 image = self.norm_out(image, conditioning)
+                 image = self.proj_out(image)
+                 (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
+             image = image[:, :image_seq_len]
              image = self.unpatchify(image, h, w)

              (image,) = cfg_parallel_unshard((image,), use_cfg=use_cfg)
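The new sequence_parallel context shards the image/text tokens and their rotary tables along the sequence dimension across ranks, runs the transformer blocks on the local shard, and sequence_parallel_unshard gathers the image tokens back before unpatchify. The real helpers operate on a distributed process group; the sketch below only illustrates the shard/gather arithmetic on a single process:

    import torch

    world_size = 2
    image = torch.randn(1, 4096, 3072)                      # (batch, seq, dim)

    shards = list(torch.chunk(image, world_size, dim=1))    # each rank keeps one shard
    # ... each rank would run the transformer blocks on its shard ...
    gathered = torch.cat(shards, dim=1)                     # unshard: all-gather along seq dim
    assert gathered.shape == image.shape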
@@ -428,14 +527,8 @@ class QwenImageDiT(PreTrainedModel):
          device: str,
          dtype: torch.dtype,
          num_layers: int = 60,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
      ):
-         model = cls(
-             device="meta",
-             dtype=dtype,
-             num_layers=num_layers,
-             attn_kwargs=attn_kwargs,
-         )
+         model = cls(device="meta", dtype=dtype, num_layers=num_layers)
          model = model.requires_grad_(False)
          model.load_state_dict(state_dict, assign=True)
          model.to(device=device, dtype=dtype, non_blocking=True)
@@ -445,5 +538,5 @@ class QwenImageDiT(PreTrainedModel):
          for block in self.transformer_blocks:
              block.compile(*args, **kwargs)

-     def get_fsdp_modules(self):
-         return ["transformer_blocks"]
+     def get_fsdp_module_cls(self):
+         return {QwenImageTransformerBlock}
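get_fsdp_module_cls now returns a set of module classes rather than attribute names. One plausible way a loader could consume that set (how diffsynth-engine uses it internally is not shown in this diff, so treat this as an assumption) is to build a torch FSDP auto-wrap policy from it:

    import torch.nn as nn
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    class Block(nn.Module):                 # stand-in for QwenImageTransformerBlock
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(Block() for _ in range(2))

        def get_fsdp_module_cls(self):      # mirrors the new method's shape: a set of classes
            return {Block}

    model = Model()
    auto_wrap_policy = ModuleWrapPolicy(model.get_fsdp_module_cls())
    # FSDP(model, auto_wrap_policy=auto_wrap_policy) would then wrap each Block instance,
    # whereas the old get_fsdp_modules() returned attribute names like ["transformer_blocks"].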
diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py
@@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
      def __init__(
          self,
          num_layers: int = 60,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
          relative_l1_threshold: float = 0.05,
      ):
-         super().__init__(num_layers=num_layers, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+         super().__init__(num_layers=num_layers, device=device, dtype=dtype)
          self.relative_l1_threshold = relative_l1_threshold
          self.step_count = 0
          self.num_inference_steps = 0
@@ -43,6 +42,7 @@ class QwenImageDiTFBCache(QwenImageDiT):
          text: torch.Tensor = None,
          timestep: torch.LongTensor = None,
          txt_seq_lens: torch.LongTensor = None,
+         attn_kwargs: Optional[Dict[str, Any]] = None,
      ):
          h, w = image.shape[-2:]
          fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -72,7 +72,11 @@ class QwenImageDiTFBCache(QwenImageDiT):
              # first block
              original_hidden_states = image
              text, image = self.transformer_blocks[0](
-                 image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb
+                 image=image,
+                 text=text,
+                 temb=conditioning,
+                 image_rotary_emb=image_rotary_emb,
+                 attn_kwargs=attn_kwargs,
              )
              first_hidden_states_residual = image - original_hidden_states

@@ -94,7 +98,13 @@ class QwenImageDiTFBCache(QwenImageDiT):
                  first_hidden_states = image.clone()

                  for block in self.transformer_blocks[1:]:
-                     text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
+                     text, image = block(
+                         image=image,
+                         text=text,
+                         temb=conditioning,
+                         image_rotary_emb=image_rotary_emb,
+                         attn_kwargs=attn_kwargs,
+                     )

                  previous_residual = image - first_hidden_states
                  self.previous_residual = previous_residual
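For context on attn_kwargs being threaded through both the first block and the remaining blocks: this class implements first-block caching, where the residual produced by block 0 is compared with the previous step's residual and, when the relative change stays under relative_l1_threshold, the cached previous_residual is reused instead of running blocks [1:]. The comparison helper itself is not part of this hunk; the snippet below is a hedged sketch of the commonly used relative-L1 criterion, not the package's exact code:

    import torch

    def is_relative_l1_below(curr: torch.Tensor, prev: torch.Tensor, threshold: float) -> bool:
        # mean absolute change of the first-block residual, relative to its previous magnitude
        rel_l1 = (curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)
        return rel_l1.item() < threshold

    prev_residual = torch.randn(1, 16, 32)
    curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)
    if is_relative_l1_below(curr_residual, prev_residual, threshold=0.05):
        pass  # reuse self.previous_residual instead of running blocks [1:]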
@@ -114,14 +124,12 @@ class QwenImageDiTFBCache(QwenImageDiT):
          device: str,
          dtype: torch.dtype,
          num_layers: int = 60,
-         attn_kwargs: Optional[Dict[str, Any]] = None,
          relative_l1_threshold: float = 0.05,
      ):
          model = cls(
              device="meta",
              dtype=dtype,
              num_layers=num_layers,
-             attn_kwargs=attn_kwargs,
              relative_l1_threshold=relative_l1_threshold,
          )
          model = model.requires_grad_(False)
diffsynth_engine/models/qwen_image/qwen_image_vae.py
@@ -12,7 +12,7 @@ from diffsynth_engine.utils.constants import QWEN_IMAGE_VAE_KEYMAP_FILE

  CACHE_T = 2

- with open(QWEN_IMAGE_VAE_KEYMAP_FILE, "r") as f:
+ with open(QWEN_IMAGE_VAE_KEYMAP_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)
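The encoding="utf-8" additions in this and the following config-loading modules matter because open() without an explicit encoding falls back to the platform's locale encoding (e.g. cp1252 on some Windows setups), which can mis-decode UTF-8 JSON. A small illustrative round-trip; keymap.json is a hypothetical file name, not one from the package:

    import json, locale

    print(locale.getpreferredencoding(False))    # e.g. "cp1252" on some Windows setups

    with open("keymap.json", "w", encoding="utf-8") as f:    # write UTF-8 JSON explicitly
        json.dump({"键": "value"}, f, ensure_ascii=False)

    with open("keymap.json", "r", encoding="utf-8") as f:    # explicit encoding round-trips safely
        config = json.load(f)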


diffsynth_engine/models/sd/sd_text_encoder.py
@@ -10,7 +10,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SD_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+ with open(SD_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/sd/sd_unet.py
@@ -18,7 +18,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SD_UNET_CONFIG_FILE) as f:
+ with open(SD_UNET_CONFIG_FILE, encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/sd3/sd3_dit.py
@@ -13,7 +13,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SD3_DIT_CONFIG_FILE, "r") as f:
+ with open(SD3_DIT_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/sd3/sd3_text_encoder.py
@@ -11,7 +11,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SD3_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+ with open(SD3_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/sdxl/sdxl_text_encoder.py
@@ -10,7 +10,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SDXL_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+ with open(SDXL_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/sdxl/sdxl_unet.py
@@ -18,7 +18,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(SDXL_UNET_CONFIG_FILE, "r") as f:
+ with open(SDXL_UNET_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/vae/vae.py
@@ -12,7 +12,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(VAE_CONFIG_FILE, "r") as f:
+ with open(VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
      config = json.load(f)


diffsynth_engine/models/wan/wan_audio_encoder.py
@@ -223,7 +223,6 @@ class Wav2Vec2StateDictConverter:

  class Wav2Vec2Model(PreTrainedModel):
      converter = Wav2Vec2StateDictConverter()
-     _supports_parallelization = False

      def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
          super().__init__()
@@ -267,9 +266,13 @@ def linear_interpolation(features: torch.Tensor, input_fps: int, output_fps: int
      return output_features.transpose(1, 2)  # [1, output_len, 512]


- def extract_audio_feat(audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0") -> torch.Tensor:
+ def extract_audio_feat(
+     audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0"
+ ) -> torch.Tensor:
      video_rate = 30
-     input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(audio_input.var(dim=1, keepdim=True) + 1e-7)
+     input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(
+         audio_input.var(dim=1, keepdim=True) + 1e-7
+     )
      feat = torch.cat(model(input_values.to(device)))
      feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
      return feat.to(dtype)  # Encoding for the motion
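The reflowed input_values expression above is per-sample standardisation, (x - mean) / sqrt(var + 1e-7), applied along the time axis before Wav2Vec2 feature extraction. A quick standalone check of that formula, with a made-up waveform:

    import torch

    audio_input = torch.randn(1, 16000) * 3.0 + 5.0
    eps = 1e-7
    input_values = (audio_input - audio_input.mean(dim=1, keepdim=True)) / torch.sqrt(
        audio_input.var(dim=1, keepdim=True) + eps
    )
    print(input_values.mean(dim=1), input_values.std(dim=1))   # approximately 0 and 1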