diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/wan/wan_dit.py
@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
 
+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
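
This hunk is the heart of the VSA (video sparse attention) refactor: backend selection moves from construction time (`self.attn_kwargs`) to call time, and `use_vsa` adds a learned `gate_compress` projection whose per-head output `g` is forwarded to the attention op. A minimal stand-in (not the library class) showing the per-call pattern:

    import torch
    from torch import nn
    import torch.nn.functional as F

    # Minimal stand-in illustrating the refactor: backend options arrive as a
    # per-call dict instead of being frozen in __init__.
    class TinySelfAttention(nn.Module):
        def __init__(self, dim: int, num_heads: int):
            super().__init__()
            self.num_heads = num_heads
            self.qkv = nn.Linear(dim, dim * 3)
            self.o = nn.Linear(dim, dim)

        def forward(self, x, attn_kwargs=None):
            attn_kwargs = attn_kwargs if attn_kwargs is not None else {}  # default empty, as in the diff
            if attn_kwargs.get("attn_impl", "sdpa") != "sdpa":
                raise NotImplementedError("stand-in only implements sdpa")
            b, s, d = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            view = lambda t: t.view(b, s, self.num_heads, -1).transpose(1, 2)
            out = F.scaled_dot_product_attention(view(q), view(k), view(v))
            return self.o(out.transpose(1, 2).reshape(b, s, d))

    attn = TinySelfAttention(dim=64, num_heads=4)
    y = attn(torch.randn(2, 16, 64), attn_kwargs={"attn_impl": "sdpa"})
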
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@
         self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
 
-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
             x = x + y
         return self.o(x)
 
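Note the asymmetry: cross-attention never runs VSA. A "vsa" request is downgraded to dense "sdpa" on a copy of the dict, presumably because video sparse attention assumes the 3-D video token layout of the self-attention path, which text-context tokens do not follow; copying leaves the caller's dict (still used by self-attention) untouched. The guard in isolation:

    from typing import Optional

    def cross_attn_kwargs(attn_kwargs: Optional[dict]) -> dict:
        # Mirrors the guard above: copy-on-write downgrade of "vsa" to "sdpa".
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        if attn_kwargs.get("attn_impl") == "vsa":
            attn_kwargs = attn_kwargs.copy()
            attn_kwargs["attn_impl"] = "sdpa"
        return attn_kwargs

    assert cross_attn_kwargs({"attn_impl": "vsa"}) == {"attn_impl": "sdpa"}
    assert cross_attn_kwargs(None) == {}
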
@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
 
-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention  mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
@@ -249,7 +261,26 @@ class Head(nn.Module):
 
 
 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict
 
 
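The converter now auto-detects Diffusers-layout Wan checkpoints with a single key probe (`condition_embedder.time_proj.weight`) and routes through the module-level `config` loaded from the new `wan_dit_keymap.json` at import time (see the `@@ -30,6 +31,9 @@` hunk above). From the access pattern, the keymap must provide two tables under a `diffusers` key; the entries below are hypothetical placeholders, not the shipped mapping:

    # Shape implied by the converter code; example entries are made up.
    keymap = {
        "diffusers": {
            # whole-key renames for non-block parameters
            "global_rename_dict": {
                "condition_embedder.time_proj": "time_embedding.0",
            },
            # per-block renames: "blocks.<idx>.<middle>" -> "blocks.<idx>.<new>"
            "rename_dict": {
                "attn1.to_q": "self_attn.q",
            },
        }
    }
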
@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@
 
         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
         (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -398,7 +439,7 @@
             raise ValueError(f"Unsupported model type: {model_type}")
 
         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config
 
@@ -409,12 +450,11 @@
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
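Since the model is instantiated on the `meta` device, `load_state_dict(..., assign=True)` is the only mode that can work (meta tensors have no storage to copy into), which is presumably why the `assign` parameter was dropped rather than defaulted. The pattern in isolation:

    import torch
    from torch import nn

    # Meta-device init allocates no memory; assign=True swaps real tensors in.
    with torch.device("meta"):
        layer = nn.Linear(1024, 1024)

    state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
    layer.load_state_dict(state_dict, assign=True)  # copy-based loading would fail on meta
    layer.to(device="cpu", dtype=torch.float32)
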
@@ -499,8 +539,5 @@
         for block in self.blocks:
             block.compile(*args, **kwargs)
 
-        for block in self.single_blocks:
-            block.compile(*args, **kwargs)
-
-    def get_fsdp_modules(self):
-        return ["blocks"]
+    def get_fsdp_module_cls(self):
+        return {DiTBlock}
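
Two things happen here: the removed `single_blocks` loop looks like a copy-paste leftover from the Flux DiT (WanDiT defines no `single_blocks`, so it was a latent AttributeError in `compile()`), and the FSDP hook changes shape, returning a set of module classes to shard instead of attribute names for the wrapper to look up, which also covers subclasses that nest blocks differently. A hedged sketch of how a consumer might use it; the wrap-policy helper below is an assumption, not necessarily what `utils/parallel.py` does:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Hypothetical consumer: wrap every module whose class the model reports,
    # rather than looking up fixed attribute names like "blocks".
    def shard(model):
        policy = ModuleWrapPolicy(model.get_fsdp_module_cls())  # e.g. {DiTBlock}
        return FSDP(model, auto_wrap_policy=policy)
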
diffsynth_engine/models/wan/wan_s2v_dit.py
@@ -360,7 +360,7 @@ class WanS2VDiT(WanDiT):
             raise ValueError(f"Unsupported model type: {model_type}")
 
         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config
 
diffsynth_engine/models/wan/wan_text_encoder.py
@@ -198,22 +198,22 @@ class WanTextEncoderStateDictConverter(StateDictConverter):
 
     def _from_diffusers(self, state_dict):
         rename_dict = {
-            "enc.output_norm.weight": "norm.weight",
-            "token_embd.weight": "token_embedding.weight",
+            "shared.weight": "token_embedding.weight",
+            "encoder.final_layer_norm.weight": "norm.weight",
         }
         for i in range(self.num_encoder_layers):
             rename_dict.update(
                 {
-                    f"enc.blk.{i}.attn_q.weight": f"blocks.{i}.attn.q.weight",
-                    f"enc.blk.{i}.attn_k.weight": f"blocks.{i}.attn.k.weight",
-                    f"enc.blk.{i}.attn_v.weight": f"blocks.{i}.attn.v.weight",
-                    f"enc.blk.{i}.attn_o.weight": f"blocks.{i}.attn.o.weight",
-                    f"enc.blk.{i}.ffn_up.weight": f"blocks.{i}.ffn.fc1.weight",
-                    f"enc.blk.{i}.ffn_down.weight": f"blocks.{i}.ffn.fc2.weight",
-                    f"enc.blk.{i}.ffn_gate.weight": f"blocks.{i}.ffn.gate.0.weight",
-                    f"enc.blk.{i}.attn_norm.weight": f"blocks.{i}.norm1.weight",
-                    f"enc.blk.{i}.ffn_norm.weight": f"blocks.{i}.norm2.weight",
-                    f"enc.blk.{i}.attn_rel_b.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
+                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
+                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
+                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                 }
             )
 
@@ -224,7 +224,7 @@ class WanTextEncoderStateDictConverter(StateDictConverter):
         return new_state_dict
 
     def convert(self, state_dict):
-        if "enc.output_norm.weight" in state_dict:
+        if "encoder.final_layer_norm.weight" in state_dict:
             logger.info("use diffusers format state dict")
             return self._from_diffusers(state_dict)
         return state_dict
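
`_from_diffusers` previously matched GGUF-style T5 keys (`enc.blk.{i}.attn_q`, `token_embd`, ...); it now matches the Hugging Face T5 layout (`encoder.block.{i}.layer.0.SelfAttention.q`, `shared`, ...), and the detection probe in `convert` moves accordingly. A quick way to sanity-check which family a checkpoint belongs to, mirroring the converter's probes (illustrative helper, not library code):

    def detect_t5_format(state_dict: dict) -> str:
        # Heuristic based on the probe keys used across both versions.
        if "encoder.final_layer_norm.weight" in state_dict:
            return "huggingface"  # new detection target
        if "enc.output_norm.weight" in state_dict:
            return "gguf"         # old target, no longer converted here
        return "native"
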
diffsynth_engine/models/wan/wan_vae.py
@@ -12,7 +12,7 @@ from diffsynth_engine.utils.constants import WAN2_1_VAE_CONFIG_FILE, WAN2_2_VAE_
 
 CACHE_T = 2
 
-with open(WAN_VAE_KEYMAP_FILE, "r") as f:
+with open(WAN_VAE_KEYMAP_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)
 
 
@@ -855,7 +855,7 @@ class WanVideoVAE(PreTrainedModel):
             raise ValueError(f"Unsupported model type: {model_type}")
 
         config_file = MODEL_CONFIG_FILES[model_type]
-        with open(config_file, "r") as f:
+        with open(config_file, "r", encoding="utf-8") as f:
             config = json.load(f)
         return config
 
diffsynth_engine/pipelines/base.py
@@ -2,10 +2,18 @@ import os
 import torch
 import numpy as np
 from einops import rearrange
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image
 
-from diffsynth_engine.configs import BaseConfig, BaseStateDicts
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ class BasePipeline:
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,14 +57,49 @@
             raise NotImplementedError()
 
     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts, pipeline_config: BaseConfig) -> "BasePipeline":
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()
 
-    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
-        for lora_path, lora_scale in lora_list:
-            logger.info(f"loading lora from {lora_path} with scale {lora_scale}")
+    def update_weights(self, state_dicts: BaseStateDicts) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def update_component(
+        component: torch.nn.Module,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+    ) -> None:
+        if component and state_dict:
+            component.load_state_dict(state_dict, assign=True)
+            component.to(device=device, dtype=dtype, non_blocking=True)
+
+    def load_loras(
+        self,
+        lora_list: List[Tuple[str, Union[float, LoraConfig]]],
+        fused: bool = True,
+        save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
+    ):
+        if not lora_converter:
+            lora_converter = self.lora_converter
+
+        for lora_path, lora_item in lora_list:
+            if isinstance(lora_item, float):
+                lora_scale = lora_item
+                scheduler_config = None
+            if isinstance(lora_item, LoraConfig):
+                lora_scale = lora_item.scale
+                scheduler_config = lora_item.scheduler_config
+
+            logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
             state_dict = load_file(lora_path, device=self.device)
-            lora_state_dict = self.lora_converter.convert(state_dict)
+
+            if scheduler_config is not None:
+                self.apply_scheduler_config(scheduler_config)
+                logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
+
+            lora_state_dict = lora_converter.convert(state_dict)
             for model_name, state_dict in lora_state_dict.items():
                 model = getattr(self, model_name)
                 lora_args = []
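
`load_loras` now accepts `(path, float)` or `(path, LoraConfig)` entries, takes an injectable `lora_converter`, and applies any scheduler overrides carried by the config before converting weights. Two reviewer notes: an entry whose second element is neither a `float` nor a `LoraConfig` (e.g. an `int`) would leave `lora_scale` unbound, and the log line now says "with LoraConfig" even for plain-float entries. A usage sketch; the `LoraConfig` constructor arguments and the scheduler keys are assumptions based on the attributes read here:

    from diffsynth_engine.configs import LoraConfig

    # `pipe` stands for any BasePipeline subclass instance.
    pipe.load_loras([
        ("style_a.safetensors", 0.8),  # legacy form: bare scale
        ("style_b.safetensors", LoraConfig(scale=0.6, scheduler_config={"shift": 5.0})),
    ])
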
@@ -78,6 +122,9 @@
     def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
         self.load_loras([(path, scale)], fused, save_original_weight)
 
+    def apply_scheduler_config(self, scheduler_config: Dict):
+        pass
+
     def unload_loras(self):
         raise NotImplementedError()
 
@@ -222,6 +269,25 @@
         )
         return init_latents, latents, sigmas, timesteps
 
+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
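
`get_attn_kwargs` is the pipeline-side producer for the `attn_kwargs` consumed by `WanDiT.forward` above: it translates the validated config (an `AttnImpl` plus its matching params dataclass) into the flat dict the attention ops expect, and for VSA it derives block metadata from the latent grid using patch size `(1, 2, 2)`. A hedged end-to-end sketch; shapes and values are made-up examples, and the pipeline is assumed to hold a `WanDiT` as `pipe.dit`:

    import torch

    # Illustrative wiring, not runnable against real weights.
    latents = torch.randn(1, 16, 21, 60, 104)   # (batch, channels, frames, h, w)
    context = torch.randn(1, 512, 4096)         # T5 text embeddings
    timestep = torch.tensor([1000.0])

    attn_kwargs = pipe.get_attn_kwargs(latents)  # e.g. {"attn_impl": "vsa", ...}
    noise_pred = pipe.dit(
        x=latents, context=context, timestep=timestep, attn_kwargs=attn_kwargs
    )  # attn_kwargs is threaded into every DiTBlock per call
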