diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0

diffsynth_engine/models/flux/flux_dit.py
@@ -2,7 +2,7 @@ import json
  import torch
  import torch.nn as nn
  import numpy as np
- from typing import Any, Dict, Optional
+ from typing import Any, Dict, List, Optional
  from einops import rearrange

  from diffsynth_engine.models.basic.transformer_helper import (
@@ -28,7 +28,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(FLUX_DIT_CONFIG_FILE, "r") as f:
+ with open(FLUX_DIT_CONFIG_FILE, "r", encoding="utf-8") as f:
  config = json.load(f)


@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
  dim_b,
  num_heads,
  head_dim,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):

  self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
  self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

  def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
  return attn_out_a, attn_out_b

- def forward(self, image, text, rope_emb, image_emb):
+ def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
  q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
  k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
  v = torch.cat([v_b, v_a], dim=1)
  q, k = apply_rope(q, k, rope_emb)
- attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
  text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
  image_out, text_out = self.attention_callback(
@@ -231,19 +231,18 @@ class FluxDoubleTransformerBlock(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
  super().__init__()
- self.attn = FluxDoubleAttention(
- dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
- )
+ self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
  # Image
  self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
  self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
  self.ff_a = nn.Sequential(
- nn.Linear(dim, dim * 4), nn.GELU(approximate="tanh"), nn.Linear(dim * 4, dim, device=device, dtype=dtype)
+ nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(dim * 4, dim, device=device, dtype=dtype),
  )
  # Text
  self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -254,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
  nn.Linear(dim * 4, dim, device=device, dtype=dtype),
  )

- def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+ def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
  # AdaLayerNorm-Zero for Image and Text MSA
  image_in, gate_a = self.norm_msa_a(image, t_emb)
  text_in, gate_b = self.norm_msa_b(text, t_emb)
- image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+ image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
  image = image + gate_a * image_out
  text = text + gate_b * text_out

@@ -277,7 +276,6 @@ class FluxSingleAttention(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -286,15 +284,16 @@ class FluxSingleAttention(nn.Module):
  self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
  self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
  self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

  def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
  return attn_out

- def forward(self, x, rope_emb, image_emb):
+ def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
  q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
  q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
- attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
  attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
  return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -304,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
  self,
  dim,
  num_heads,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
  super().__init__()
  self.dim = dim
  self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
- self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+ self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
  self.mlp = nn.Sequential(
- nn.Linear(dim, dim * 4),
+ nn.Linear(dim, dim * 4, device=device, dtype=dtype),
  nn.GELU(approximate="tanh"),
  )
- self.proj_out = nn.Linear(dim * 5, dim)
+ self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

- def forward(self, x, t_emb, rope_emb, image_emb=None):
+ def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
  h, gate = self.norm(x, emb=t_emb)
- attn_output = self.attn(h, rope_emb, image_emb)
+ attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
  mlp_output = self.mlp(h)
  return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))

@@ -332,7 +330,6 @@ class FluxDiT(PreTrainedModel):
  def __init__(
  self,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -350,16 +347,10 @@ class FluxDiT(PreTrainedModel):
  self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

  self.blocks = nn.ModuleList(
- [
- FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
- for _ in range(19)
- ]
+ [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
  )
  self.single_blocks = nn.ModuleList(
- [
- FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
- for _ in range(38)
- ]
+ [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
  )
  self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
  self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -393,21 +384,20 @@ class FluxDiT(PreTrainedModel):

  def forward(
  self,
- hidden_states,
- timestep,
- prompt_emb,
- pooled_prompt_emb,
- image_emb,
- guidance,
- text_ids,
- image_ids=None,
- controlnet_double_block_output=None,
- controlnet_single_block_output=None,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ prompt_emb: torch.Tensor,
+ pooled_prompt_emb: torch.Tensor,
+ image_ids: torch.Tensor,
+ text_ids: torch.Tensor,
+ guidance: torch.Tensor,
+ image_emb: torch.Tensor | None = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_double_block_output: List[torch.Tensor] | None = None,
+ controlnet_single_block_output: List[torch.Tensor] | None = None,
  **kwargs,
  ):
- h, w = hidden_states.shape[-2:]
- if image_ids is None:
- image_ids = self.prepare_image_ids(hidden_states)
+ image_seq_len = hidden_states.shape[1]
  controlnet_double_block_output = (
  controlnet_double_block_output if controlnet_double_block_output is not None else ()
  )
@@ -426,10 +416,10 @@ class FluxDiT(PreTrainedModel):
  timestep,
  prompt_emb,
  pooled_prompt_emb,
- image_emb,
- guidance,
- text_ids,
  image_ids,
+ text_ids,
+ guidance,
+ image_emb,
  *controlnet_double_block_output,
  *controlnet_single_block_output,
  ),
@@ -446,7 +436,6 @@ class FluxDiT(PreTrainedModel):
  rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
  text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
  image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
- hidden_states = self.patchify(hidden_states)

  with sequence_parallel(
  (
@@ -471,14 +460,16 @@ class FluxDiT(PreTrainedModel):
  rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

  for i, block in enumerate(self.blocks):
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = block(
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  if len(controlnet_double_block_output) > 0:
  interval_control = len(self.blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
  hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
  hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
  for i, block in enumerate(self.single_blocks):
- hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+ hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
  if len(controlnet_single_block_output) > 0:
  interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
@@ -487,9 +478,8 @@ class FluxDiT(PreTrainedModel):
  hidden_states = hidden_states[:, prompt_emb.shape[1] :]
  hidden_states = self.final_norm_out(hidden_states, conditioning)
  hidden_states = self.final_proj_out(hidden_states)
- (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+ (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

- hidden_states = self.unpatchify(hidden_states, h, w)
  (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)
  return hidden_states

@@ -500,14 +490,8 @@ class FluxDiT(PreTrainedModel):
  device: str,
  dtype: torch.dtype,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  ):
- model = cls(
- device="meta",
- dtype=dtype,
- in_channel=in_channel,
- attn_kwargs=attn_kwargs,
- )
+ model = cls(device="meta", dtype=dtype, in_channel=in_channel)
  model = model.requires_grad_(False)
  model.load_state_dict(state_dict, assign=True)
  model.to(device=device, dtype=dtype, non_blocking=True)
@@ -517,5 +501,8 @@ class FluxDiT(PreTrainedModel):
  for block in self.blocks:
  block.compile(*args, **kwargs)

- def get_fsdp_modules(self):
- return ["blocks", "single_blocks"]
+ for block in self.single_blocks:
+ block.compile(*args, **kwargs)
+
+ def get_fsdp_module_cls(self):
+ return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
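
Note: the net effect of the flux_dit.py changes above is that attn_kwargs is no longer frozen into FluxDoubleAttention / FluxSingleAttention at construction time; each forward call now threads its own attn_kwargs down to attention_ops.attention. Below is a minimal toy sketch of that calling pattern, illustrative only: it is not a diffsynth_engine class, and torch.nn.functional.scaled_dot_product_attention stands in for the package's attention_ops.attention.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    # Toy module (not from the package): attention options arrive per call.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3)

    def forward(self, x: torch.Tensor, attn_kwargs: dict | None = None):
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        b, s, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(b, s, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, **attn_kwargs)  # kwargs resolved per call
        return out.transpose(1, 2).reshape(b, s, -1)

attn = ToyAttention(dim=32, num_heads=4)
x = torch.randn(1, 8, 32)
y_default = attn(x)                                  # no per-call overrides
y_causal = attn(x, attn_kwargs={"is_causal": True})  # override for this call only

This mirrors why the constructors above dropped their attn_kwargs parameter: a single module instance can now serve different attention options across calls.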

diffsynth_engine/models/flux/flux_dit_fbcache.py
@@ -1,6 +1,6 @@
  import torch
  import numpy as np
- from typing import Any, Dict, Optional
+ from typing import Any, Dict, List, Optional

  from diffsynth_engine.utils.gguf import gguf_inference
  from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
  def __init__(
  self,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  relative_l1_threshold: float = 0.05,
  ):
- super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+ super().__init__(in_channel=in_channel, device=device, dtype=dtype)
  self.relative_l1_threshold = relative_l1_threshold
  self.step_count = 0
  self.num_inference_steps = 0
@@ -48,21 +47,20 @@ class FluxDiTFBCache(FluxDiT):

  def forward(
  self,
- hidden_states,
- timestep,
- prompt_emb,
- pooled_prompt_emb,
- image_emb,
- guidance,
- text_ids,
- image_ids=None,
- controlnet_double_block_output=None,
- controlnet_single_block_output=None,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ prompt_emb: torch.Tensor,
+ pooled_prompt_emb: torch.Tensor,
+ image_ids: torch.Tensor,
+ text_ids: torch.Tensor,
+ guidance: torch.Tensor,
+ image_emb: torch.Tensor | None = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_double_block_output: List[torch.Tensor] | None = None,
+ controlnet_single_block_output: List[torch.Tensor] | None = None,
  **kwargs,
  ):
- h, w = hidden_states.shape[-2:]
- if image_ids is None:
- image_ids = self.prepare_image_ids(hidden_states)
+ image_seq_len = hidden_states.shape[1]
  controlnet_double_block_output = (
  controlnet_double_block_output if controlnet_double_block_output is not None else ()
  )
@@ -81,10 +79,10 @@ class FluxDiTFBCache(FluxDiT):
  timestep,
  prompt_emb,
  pooled_prompt_emb,
- image_emb,
- guidance,
- text_ids,
  image_ids,
+ text_ids,
+ guidance,
+ image_emb,
  *controlnet_double_block_output,
  *controlnet_single_block_output,
  ),
@@ -101,7 +99,6 @@ class FluxDiTFBCache(FluxDiT):
  rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
  text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
  image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
- hidden_states = self.patchify(hidden_states)

  with sequence_parallel(
  (
@@ -127,11 +124,13 @@ class FluxDiTFBCache(FluxDiT):

  # first block
  original_hidden_states = hidden_states
- hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = self.blocks[0](
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  first_hidden_states_residual = hidden_states - original_hidden_states

  (first_hidden_states_residual,) = sequence_parallel_unshard(
- (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
+ (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(image_seq_len,)
  )

  if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
@@ -152,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):

  first_hidden_states = hidden_states.clone()
  for i, block in enumerate(self.blocks[1:]):
- hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+ hidden_states, prompt_emb = block(
+ hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+ )
  if len(controlnet_double_block_output) > 0:
  interval_control = len(self.blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
  hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
  hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
  for i, block in enumerate(self.single_blocks):
- hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+ hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
  if len(controlnet_single_block_output) > 0:
  interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
  interval_control = int(np.ceil(interval_control))
@@ -172,9 +173,8 @@ class FluxDiTFBCache(FluxDiT):

  hidden_states = self.final_norm_out(hidden_states, conditioning)
  hidden_states = self.final_proj_out(hidden_states)
- (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+ (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

- hidden_states = self.unpatchify(hidden_states, h, w)
  (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

  return hidden_states
@@ -186,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
  device: str,
  dtype: torch.dtype,
  in_channel: int = 64,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  relative_l1_threshold: float = 0.05,
  ):
  model = cls(
  device="meta",
  dtype=dtype,
  in_channel=in_channel,
- attn_kwargs=attn_kwargs,
  relative_l1_threshold=relative_l1_threshold,
  )
  model = model.requires_grad_(False)
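
Note: relative_l1_threshold in flux_dit_fbcache.py governs first-block caching: the residual produced by the first double block is compared against the residual cached from the previous step, and the remaining blocks are skipped when the change is small. The comparison itself lies outside the hunks shown here; the sketch below is a hedged reconstruction of the usual relative-L1 test with hypothetical names, not code from this package.

import torch

def residual_is_similar(curr: torch.Tensor, prev: torch.Tensor, threshold: float = 0.05) -> bool:
    # Relative L1 distance between the current and cached first-block residuals.
    rel_l1 = (curr - prev).abs().mean() / prev.abs().mean()
    return rel_l1.item() < threshold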

diffsynth_engine/models/flux/flux_ipadapter.py
@@ -2,7 +2,7 @@ import torch
  from einops import rearrange
  from torch import nn
  from PIL import Image
- from typing import Any, Dict, List, Optional
+ from typing import Dict, List
  from functools import partial
  from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
  dim: int = 3072,
  head_num: int = 24,
  scale: float = 1.0,
- attn_kwargs: Optional[Dict[str, Any]] = None,
  device: str = "cuda:0",
  dtype: torch.dtype = torch.bfloat16,
  ):
@@ -28,12 +27,13 @@ class FluxIPAdapterAttention(nn.Module):
  self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
  self.head_num = head_num
  self.scale = scale
- self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

- def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+ def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
  key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
  value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
- attn_out = attention(query, key, value, **self.attn_kwargs)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ attn_out = attention(query, key, value, **attn_kwargs)
  return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

  @classmethod

diffsynth_engine/models/flux/flux_text_encoder.py
@@ -10,7 +10,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(FLUX_TEXT_ENCODER_CONFIG_FILE, "r") as f:
+ with open(FLUX_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
  config = json.load(f)



diffsynth_engine/models/flux/flux_vae.py
@@ -8,7 +8,7 @@ from diffsynth_engine.utils import logging

  logger = logging.get_logger(__name__)

- with open(FLUX_VAE_CONFIG_FILE, "r") as f:
+ with open(FLUX_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
  config = json.load(f)


@@ -25,11 +25,29 @@ class FluxVAEStateDictConverter(VAEStateDictConverter):
  new_state_dict[name_] = param
  return new_state_dict

+ def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ rename_dict = config["diffusers"]["rename_dict"]
+ new_state_dict = {}
+ for name, param in state_dict.items():
+ if name not in rename_dict:
+ continue
+ name_ = rename_dict[name]
+ if "transformer_blocks" in name_:
+ param = param.squeeze()
+ new_state_dict[name_] = param
+ return new_state_dict
+
  def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
  assert self.has_decoder or self.has_encoder, "Either decoder or encoder must be present"
- if "decoder.conv_in.weight" in state_dict or "encoder.conv_in.weight" in state_dict:
+ if "decoder.up.0.block.0.conv1.weight" in state_dict or "encoder.down.0.block.0.conv1.weight" in state_dict:
  state_dict = self._from_civitai(state_dict)
  logger.info("use civitai format state dict")
+ elif (
+ "decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict
+ or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict
+ ):
+ state_dict = self._from_diffusers(state_dict)
+ logger.info("use diffusers format state dict")
  else:
  logger.info("use diffsynth format state dict")
  return self._filter(state_dict)
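
Note: the convert() change above swaps the VAE format probe from decoder.conv_in.weight / encoder.conv_in.weight to block-level keys and adds a diffusers branch. The detection order can be summarized with the standalone helper below; the key names are copied from the hunk, but the helper itself is illustrative and not part of the package.

import torch
from typing import Dict

def detect_flux_vae_format(state_dict: Dict[str, torch.Tensor]) -> str:
    # Probe characteristic parameter names to pick the right key-renaming path.
    if ("decoder.up.0.block.0.conv1.weight" in state_dict
            or "encoder.down.0.block.0.conv1.weight" in state_dict):
        return "civitai"
    if ("decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict
            or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict):
        return "diffusers"
    return "diffsynth"  # fall through: keys are already in diffsynth format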

diffsynth_engine/models/hunyuan3d/dino_image_encoder.py
@@ -2,7 +2,7 @@ import torch.nn as nn
  import torchvision.transforms as transforms
  import collections.abc
  import math
- from typing import Optional, Tuple, Dict
+ from typing import Optional, Dict

  import torch
  from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
@@ -112,7 +112,9 @@ class Dinov2SelfAttention(nn.Module):
  def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool) -> None:
  super().__init__()
  if hidden_size % num_attention_heads != 0:
- raise ValueError(f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}.")
+ raise ValueError(
+ f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}."
+ )

  self.num_attention_heads = num_attention_heads
  self.attention_head_size = int(hidden_size / num_attention_heads)

diffsynth_engine/models/qwen_image/qwen2_5_vl.py
@@ -942,6 +942,8 @@ class Qwen2_5_VLModel(nn.Module):


  class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
+ _supports_parallelization = True
+
  def __init__(
  self,
  vision_config: Qwen2_5_VLVisionConfig,
@@ -1173,6 +1175,9 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):

  return position_ids, mrope_position_deltas

+ def get_fsdp_module_cls(self):
+ return {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer}
+
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
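
Note: both FluxDiT and Qwen2_5_VLForConditionalGeneration now expose get_fsdp_module_cls(), returning a set of transformer block classes instead of the old attribute-name list from get_fsdp_modules(). Below is a hedged sketch of how such a class set is typically plugged into PyTorch FSDP's class-based auto-wrap policy; the wiring is an assumption for illustration, not code taken from this package.

import functools
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def build_auto_wrap_policy(model):
    # e.g. {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
    layer_cls = model.get_fsdp_module_cls()
    return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_cls)

The resulting callable would then be passed as auto_wrap_policy when wrapping the model with FSDP inside an initialized distributed process group.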