diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. diffsynth_engine/__init__.py +12 -0
  2. diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
  3. diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
  4. diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
  5. diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
  6. diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
  7. diffsynth_engine/configs/__init__.py +16 -1
  8. diffsynth_engine/configs/controlnet.py +13 -0
  9. diffsynth_engine/configs/pipeline.py +37 -11
  10. diffsynth_engine/models/base.py +1 -1
  11. diffsynth_engine/models/basic/attention.py +105 -43
  12. diffsynth_engine/models/basic/transformer_helper.py +36 -2
  13. diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
  14. diffsynth_engine/models/flux/flux_controlnet.py +16 -30
  15. diffsynth_engine/models/flux/flux_dit.py +49 -62
  16. diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
  17. diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
  18. diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
  19. diffsynth_engine/models/flux/flux_vae.py +20 -2
  20. diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
  21. diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
  22. diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
  23. diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
  24. diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
  25. diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
  26. diffsynth_engine/models/sd/sd_unet.py +1 -1
  27. diffsynth_engine/models/sd3/sd3_dit.py +1 -1
  28. diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
  29. diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
  30. diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
  31. diffsynth_engine/models/vae/vae.py +1 -1
  32. diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
  33. diffsynth_engine/models/wan/wan_dit.py +65 -28
  34. diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
  35. diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
  36. diffsynth_engine/models/wan/wan_vae.py +2 -2
  37. diffsynth_engine/pipelines/base.py +73 -7
  38. diffsynth_engine/pipelines/flux_image.py +139 -120
  39. diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
  40. diffsynth_engine/pipelines/qwen_image.py +272 -87
  41. diffsynth_engine/pipelines/sdxl_image.py +1 -1
  42. diffsynth_engine/pipelines/utils.py +52 -0
  43. diffsynth_engine/pipelines/wan_s2v.py +25 -14
  44. diffsynth_engine/pipelines/wan_video.py +43 -19
  45. diffsynth_engine/tokenizers/base.py +6 -0
  46. diffsynth_engine/tokenizers/qwen2.py +12 -4
  47. diffsynth_engine/utils/constants.py +13 -12
  48. diffsynth_engine/utils/download.py +4 -2
  49. diffsynth_engine/utils/env.py +2 -0
  50. diffsynth_engine/utils/flag.py +6 -0
  51. diffsynth_engine/utils/loader.py +25 -6
  52. diffsynth_engine/utils/parallel.py +62 -29
  53. diffsynth_engine/utils/video.py +3 -1
  54. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
  55. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
  56. /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
  57. /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
  58. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
  59. /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
  60. /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
  61. /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
  62. /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
  63. /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
  64. /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
  65. /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
  66. /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
  67. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
  68. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
  69. {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
@@ -11,8 +11,14 @@ from .configs import (
11
11
  FluxStateDicts,
12
12
  WanStateDicts,
13
13
  QwenImageStateDicts,
14
+ AttnImpl,
15
+ SpargeAttentionParams,
16
+ VideoSparseAttentionParams,
17
+ LoraConfig,
14
18
  ControlNetParams,
15
19
  ControlType,
20
+ QwenImageControlNetParams,
21
+ QwenImageControlType,
16
22
  )
17
23
  from .pipelines import (
18
24
  SDImagePipeline,
@@ -54,8 +60,14 @@ __all__ = [
54
60
  "FluxStateDicts",
55
61
  "WanStateDicts",
56
62
  "QwenImageStateDicts",
63
+ "AttnImpl",
64
+ "SpargeAttentionParams",
65
+ "VideoSparseAttentionParams",
66
+ "LoraConfig",
57
67
  "ControlNetParams",
58
68
  "ControlType",
69
+ "QwenImageControlNetParams",
70
+ "QwenImageControlType",
59
71
  "SDImagePipeline",
60
72
  "SDControlNet",
61
73
  "SDXLImagePipeline",
@@ -6,5 +6,24 @@ def append_zero(x):
6
6
 
7
7
 
8
8
  class BaseScheduler:
9
+ def __init__(self):
10
+ self._stored_config = {}
11
+
12
+ def store_config(self):
13
+ self._stored_config = {
14
+ config_name: config_value
15
+ for config_name, config_value in vars(self).items()
16
+ if not config_name.startswith("_")
17
+ }
18
+
19
+ def update_config(self, config_dict):
20
+ for config_name, new_value in config_dict.items():
21
+ if hasattr(self, config_name):
22
+ setattr(self, config_name, new_value)
23
+
24
+ def restore_config(self):
25
+ for config_name, config_value in self._stored_config.items():
26
+ setattr(self, config_name, config_value)
27
+
9
28
  def schedule(self, num_inference_steps: int):
10
29
  raise NotImplementedError()
@@ -12,16 +12,23 @@ class RecifitedFlowScheduler(BaseScheduler):
12
12
  def __init__(
13
13
  self,
14
14
  shift=1.0,
15
- sigma_min=0.001,
16
- sigma_max=1.0,
15
+ sigma_min=None,
16
+ sigma_max=None,
17
17
  num_train_timesteps=1000,
18
18
  use_dynamic_shifting=False,
19
+ shift_terminal=None,
20
+ exponential_shift_mu=None,
19
21
  ):
22
+ super().__init__()
20
23
  self.shift = shift
21
24
  self.sigma_min = sigma_min
22
25
  self.sigma_max = sigma_max
23
26
  self.num_train_timesteps = num_train_timesteps
24
27
  self.use_dynamic_shifting = use_dynamic_shifting
28
+ self.shift_terminal = shift_terminal
29
+ # static mu for distill model
30
+ self.exponential_shift_mu = exponential_shift_mu
31
+ self.store_config()
25
32
 
26
33
  def _sigma_to_t(self, sigma):
27
34
  return sigma * self.num_train_timesteps
@@ -35,21 +42,30 @@ class RecifitedFlowScheduler(BaseScheduler):
35
42
  def _shift_sigma(self, sigma: torch.Tensor, shift: float):
36
43
  return shift * sigma / (1 + (shift - 1) * sigma)
37
44
 
45
+ def _stretch_shift_to_terminal(self, sigma: torch.Tensor):
46
+ one_minus_z = 1 - sigma
47
+ scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
48
+ return 1 - (one_minus_z / scale_factor)
49
+
38
50
  def schedule(
39
51
  self,
40
52
  num_inference_steps: int,
41
53
  mu: float | None = None,
42
- sigma_min: float | None = None,
43
- sigma_max: float | None = None,
54
+ sigma_min: float = 0.001,
55
+ sigma_max: float = 1.0,
44
56
  append_value: float = 0,
45
57
  ):
46
- sigma_min = self.sigma_min if sigma_min is None else sigma_min
47
- sigma_max = self.sigma_max if sigma_max is None else sigma_max
58
+ sigma_min = sigma_min if self.sigma_min is None else self.sigma_min
59
+ sigma_max = sigma_max if self.sigma_max is None else self.sigma_max
48
60
  sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
61
+ if self.exponential_shift_mu is not None:
62
+ mu = self.exponential_shift_mu
49
63
  if self.use_dynamic_shifting:
50
64
  sigmas = self._time_shift(mu, 1.0, sigmas) # FLUX
51
65
  else:
52
66
  sigmas = self._shift_sigma(sigmas, self.shift)
67
+ if self.shift_terminal is not None:
68
+ sigmas = self._stretch_shift_to_terminal(sigmas)
53
69
  timesteps = sigmas * self.num_train_timesteps
54
70
  sigmas = append(sigmas, append_value)
55
71
  return sigmas, timesteps
@@ -101,5 +101,24 @@
101
101
  "proj_mlp": "proj_in_besides_attn",
102
102
  "proj_out": "proj_out"
103
103
  }
104
- }
104
+ },
105
+ "preferred_kontext_resolutions": [
106
+ [672, 1568],
107
+ [688, 1504],
108
+ [720, 1456],
109
+ [752, 1392],
110
+ [800, 1328],
111
+ [832, 1248],
112
+ [880, 1184],
113
+ [944, 1104],
114
+ [1024, 1024],
115
+ [1104, 944],
116
+ [1184, 880],
117
+ [1248, 832],
118
+ [1328, 800],
119
+ [1392, 752],
120
+ [1456, 720],
121
+ [1504, 688],
122
+ [1568, 672]
123
+ ]
105
124
  }
@@ -5,6 +5,8 @@
5
5
  "decoder.conv_in.weight": "decoder.conv_in.weight",
6
6
  "decoder.conv_out.bias": "decoder.conv_out.bias",
7
7
  "decoder.conv_out.weight": "decoder.conv_out.weight",
8
+ "decoder.norm_out.bias": "decoder.conv_norm_out.bias",
9
+ "decoder.norm_out.weight": "decoder.conv_norm_out.weight",
8
10
  "decoder.mid.attn_1.k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias",
9
11
  "decoder.mid.attn_1.k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight",
10
12
  "decoder.mid.attn_1.norm.bias": "decoder.blocks.1.norm.bias",
@@ -31,8 +33,6 @@
31
33
  "decoder.mid.block_2.norm1.weight": "decoder.blocks.2.norm1.weight",
32
34
  "decoder.mid.block_2.norm2.bias": "decoder.blocks.2.norm2.bias",
33
35
  "decoder.mid.block_2.norm2.weight": "decoder.blocks.2.norm2.weight",
34
- "decoder.norm_out.bias": "decoder.conv_norm_out.bias",
35
- "decoder.norm_out.weight": "decoder.conv_norm_out.weight",
36
36
  "decoder.up.0.block.0.conv1.bias": "decoder.blocks.15.conv1.bias",
37
37
  "decoder.up.0.block.0.conv1.weight": "decoder.blocks.15.conv1.weight",
38
38
  "decoder.up.0.block.0.conv2.bias": "decoder.blocks.15.conv2.bias",
@@ -143,6 +143,8 @@
143
143
  "encoder.conv_in.weight": "encoder.conv_in.weight",
144
144
  "encoder.conv_out.bias": "encoder.conv_out.bias",
145
145
  "encoder.conv_out.weight": "encoder.conv_out.weight",
146
+ "encoder.norm_out.bias": "encoder.conv_norm_out.bias",
147
+ "encoder.norm_out.weight": "encoder.conv_norm_out.weight",
146
148
  "encoder.down.0.block.0.conv1.bias": "encoder.blocks.0.conv1.bias",
147
149
  "encoder.down.0.block.0.conv1.weight": "encoder.blocks.0.conv1.weight",
148
150
  "encoder.down.0.block.0.conv2.bias": "encoder.blocks.0.conv2.bias",
@@ -242,9 +244,255 @@
242
244
  "encoder.mid.block_2.norm1.bias": "encoder.blocks.13.norm1.bias",
243
245
  "encoder.mid.block_2.norm1.weight": "encoder.blocks.13.norm1.weight",
244
246
  "encoder.mid.block_2.norm2.bias": "encoder.blocks.13.norm2.bias",
245
- "encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight",
246
- "encoder.norm_out.bias": "encoder.conv_norm_out.bias",
247
- "encoder.norm_out.weight": "encoder.conv_norm_out.weight"
247
+ "encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight"
248
+ }
249
+ },
250
+ "diffusers": {
251
+ "rename_dict": {
252
+ "decoder.conv_in.bias": "decoder.conv_in.bias",
253
+ "decoder.conv_in.weight": "decoder.conv_in.weight",
254
+ "decoder.conv_out.bias": "decoder.conv_out.bias",
255
+ "decoder.conv_out.weight": "decoder.conv_out.weight",
256
+ "decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias",
257
+ "decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight",
258
+ "decoder.mid_block.attentions.0.to_k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias",
259
+ "decoder.mid_block.attentions.0.to_k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight",
260
+ "decoder.mid_block.attentions.0.group_norm.bias": "decoder.blocks.1.norm.bias",
261
+ "decoder.mid_block.attentions.0.group_norm.weight": "decoder.blocks.1.norm.weight",
262
+ "decoder.mid_block.attentions.0.to_out.0.bias": "decoder.blocks.1.transformer_blocks.0.to_out.bias",
263
+ "decoder.mid_block.attentions.0.to_out.0.weight": "decoder.blocks.1.transformer_blocks.0.to_out.weight",
264
+ "decoder.mid_block.attentions.0.to_q.bias": "decoder.blocks.1.transformer_blocks.0.to_q.bias",
265
+ "decoder.mid_block.attentions.0.to_q.weight": "decoder.blocks.1.transformer_blocks.0.to_q.weight",
266
+ "decoder.mid_block.attentions.0.to_v.bias": "decoder.blocks.1.transformer_blocks.0.to_v.bias",
267
+ "decoder.mid_block.attentions.0.to_v.weight": "decoder.blocks.1.transformer_blocks.0.to_v.weight",
268
+ "decoder.mid_block.resnets.0.conv1.bias": "decoder.blocks.0.conv1.bias",
269
+ "decoder.mid_block.resnets.0.conv1.weight": "decoder.blocks.0.conv1.weight",
270
+ "decoder.mid_block.resnets.0.conv2.bias": "decoder.blocks.0.conv2.bias",
271
+ "decoder.mid_block.resnets.0.conv2.weight": "decoder.blocks.0.conv2.weight",
272
+ "decoder.mid_block.resnets.0.norm1.bias": "decoder.blocks.0.norm1.bias",
273
+ "decoder.mid_block.resnets.0.norm1.weight": "decoder.blocks.0.norm1.weight",
274
+ "decoder.mid_block.resnets.0.norm2.bias": "decoder.blocks.0.norm2.bias",
275
+ "decoder.mid_block.resnets.0.norm2.weight": "decoder.blocks.0.norm2.weight",
276
+ "decoder.mid_block.resnets.1.conv1.bias": "decoder.blocks.2.conv1.bias",
277
+ "decoder.mid_block.resnets.1.conv1.weight": "decoder.blocks.2.conv1.weight",
278
+ "decoder.mid_block.resnets.1.conv2.bias": "decoder.blocks.2.conv2.bias",
279
+ "decoder.mid_block.resnets.1.conv2.weight": "decoder.blocks.2.conv2.weight",
280
+ "decoder.mid_block.resnets.1.norm1.bias": "decoder.blocks.2.norm1.bias",
281
+ "decoder.mid_block.resnets.1.norm1.weight": "decoder.blocks.2.norm1.weight",
282
+ "decoder.mid_block.resnets.1.norm2.bias": "decoder.blocks.2.norm2.bias",
283
+ "decoder.mid_block.resnets.1.norm2.weight": "decoder.blocks.2.norm2.weight",
284
+ "decoder.up_blocks.0.resnets.0.conv1.bias": "decoder.blocks.3.conv1.bias",
285
+ "decoder.up_blocks.0.resnets.0.conv1.weight": "decoder.blocks.3.conv1.weight",
286
+ "decoder.up_blocks.0.resnets.0.conv2.bias": "decoder.blocks.3.conv2.bias",
287
+ "decoder.up_blocks.0.resnets.0.conv2.weight": "decoder.blocks.3.conv2.weight",
288
+ "decoder.up_blocks.0.resnets.0.norm1.bias": "decoder.blocks.3.norm1.bias",
289
+ "decoder.up_blocks.0.resnets.0.norm1.weight": "decoder.blocks.3.norm1.weight",
290
+ "decoder.up_blocks.0.resnets.0.norm2.bias": "decoder.blocks.3.norm2.bias",
291
+ "decoder.up_blocks.0.resnets.0.norm2.weight": "decoder.blocks.3.norm2.weight",
292
+ "decoder.up_blocks.0.resnets.1.conv1.bias": "decoder.blocks.4.conv1.bias",
293
+ "decoder.up_blocks.0.resnets.1.conv1.weight": "decoder.blocks.4.conv1.weight",
294
+ "decoder.up_blocks.0.resnets.1.conv2.bias": "decoder.blocks.4.conv2.bias",
295
+ "decoder.up_blocks.0.resnets.1.conv2.weight": "decoder.blocks.4.conv2.weight",
296
+ "decoder.up_blocks.0.resnets.1.norm1.bias": "decoder.blocks.4.norm1.bias",
297
+ "decoder.up_blocks.0.resnets.1.norm1.weight": "decoder.blocks.4.norm1.weight",
298
+ "decoder.up_blocks.0.resnets.1.norm2.bias": "decoder.blocks.4.norm2.bias",
299
+ "decoder.up_blocks.0.resnets.1.norm2.weight": "decoder.blocks.4.norm2.weight",
300
+ "decoder.up_blocks.0.resnets.2.conv1.bias": "decoder.blocks.5.conv1.bias",
301
+ "decoder.up_blocks.0.resnets.2.conv1.weight": "decoder.blocks.5.conv1.weight",
302
+ "decoder.up_blocks.0.resnets.2.conv2.bias": "decoder.blocks.5.conv2.bias",
303
+ "decoder.up_blocks.0.resnets.2.conv2.weight": "decoder.blocks.5.conv2.weight",
304
+ "decoder.up_blocks.0.resnets.2.norm1.bias": "decoder.blocks.5.norm1.bias",
305
+ "decoder.up_blocks.0.resnets.2.norm1.weight": "decoder.blocks.5.norm1.weight",
306
+ "decoder.up_blocks.0.resnets.2.norm2.bias": "decoder.blocks.5.norm2.bias",
307
+ "decoder.up_blocks.0.resnets.2.norm2.weight": "decoder.blocks.5.norm2.weight",
308
+ "decoder.up_blocks.0.upsamplers.0.conv.bias": "decoder.blocks.6.conv.bias",
309
+ "decoder.up_blocks.0.upsamplers.0.conv.weight": "decoder.blocks.6.conv.weight",
310
+ "decoder.up_blocks.1.resnets.0.conv1.bias": "decoder.blocks.7.conv1.bias",
311
+ "decoder.up_blocks.1.resnets.0.conv1.weight": "decoder.blocks.7.conv1.weight",
312
+ "decoder.up_blocks.1.resnets.0.conv2.bias": "decoder.blocks.7.conv2.bias",
313
+ "decoder.up_blocks.1.resnets.0.conv2.weight": "decoder.blocks.7.conv2.weight",
314
+ "decoder.up_blocks.1.resnets.0.norm1.bias": "decoder.blocks.7.norm1.bias",
315
+ "decoder.up_blocks.1.resnets.0.norm1.weight": "decoder.blocks.7.norm1.weight",
316
+ "decoder.up_blocks.1.resnets.0.norm2.bias": "decoder.blocks.7.norm2.bias",
317
+ "decoder.up_blocks.1.resnets.0.norm2.weight": "decoder.blocks.7.norm2.weight",
318
+ "decoder.up_blocks.1.resnets.1.conv1.bias": "decoder.blocks.8.conv1.bias",
319
+ "decoder.up_blocks.1.resnets.1.conv1.weight": "decoder.blocks.8.conv1.weight",
320
+ "decoder.up_blocks.1.resnets.1.conv2.bias": "decoder.blocks.8.conv2.bias",
321
+ "decoder.up_blocks.1.resnets.1.conv2.weight": "decoder.blocks.8.conv2.weight",
322
+ "decoder.up_blocks.1.resnets.1.norm1.bias": "decoder.blocks.8.norm1.bias",
323
+ "decoder.up_blocks.1.resnets.1.norm1.weight": "decoder.blocks.8.norm1.weight",
324
+ "decoder.up_blocks.1.resnets.1.norm2.bias": "decoder.blocks.8.norm2.bias",
325
+ "decoder.up_blocks.1.resnets.1.norm2.weight": "decoder.blocks.8.norm2.weight",
326
+ "decoder.up_blocks.1.resnets.2.conv1.bias": "decoder.blocks.9.conv1.bias",
327
+ "decoder.up_blocks.1.resnets.2.conv1.weight": "decoder.blocks.9.conv1.weight",
328
+ "decoder.up_blocks.1.resnets.2.conv2.bias": "decoder.blocks.9.conv2.bias",
329
+ "decoder.up_blocks.1.resnets.2.conv2.weight": "decoder.blocks.9.conv2.weight",
330
+ "decoder.up_blocks.1.resnets.2.norm1.bias": "decoder.blocks.9.norm1.bias",
331
+ "decoder.up_blocks.1.resnets.2.norm1.weight": "decoder.blocks.9.norm1.weight",
332
+ "decoder.up_blocks.1.resnets.2.norm2.bias": "decoder.blocks.9.norm2.bias",
333
+ "decoder.up_blocks.1.resnets.2.norm2.weight": "decoder.blocks.9.norm2.weight",
334
+ "decoder.up_blocks.1.upsamplers.0.conv.bias": "decoder.blocks.10.conv.bias",
335
+ "decoder.up_blocks.1.upsamplers.0.conv.weight": "decoder.blocks.10.conv.weight",
336
+ "decoder.up_blocks.2.resnets.0.conv1.bias": "decoder.blocks.11.conv1.bias",
337
+ "decoder.up_blocks.2.resnets.0.conv1.weight": "decoder.blocks.11.conv1.weight",
338
+ "decoder.up_blocks.2.resnets.0.conv2.bias": "decoder.blocks.11.conv2.bias",
339
+ "decoder.up_blocks.2.resnets.0.conv2.weight": "decoder.blocks.11.conv2.weight",
340
+ "decoder.up_blocks.2.resnets.0.conv_shortcut.bias": "decoder.blocks.11.conv_shortcut.bias",
341
+ "decoder.up_blocks.2.resnets.0.conv_shortcut.weight": "decoder.blocks.11.conv_shortcut.weight",
342
+ "decoder.up_blocks.2.resnets.0.norm1.bias": "decoder.blocks.11.norm1.bias",
343
+ "decoder.up_blocks.2.resnets.0.norm1.weight": "decoder.blocks.11.norm1.weight",
344
+ "decoder.up_blocks.2.resnets.0.norm2.bias": "decoder.blocks.11.norm2.bias",
345
+ "decoder.up_blocks.2.resnets.0.norm2.weight": "decoder.blocks.11.norm2.weight",
346
+ "decoder.up_blocks.2.resnets.1.conv1.bias": "decoder.blocks.12.conv1.bias",
347
+ "decoder.up_blocks.2.resnets.1.conv1.weight": "decoder.blocks.12.conv1.weight",
348
+ "decoder.up_blocks.2.resnets.1.conv2.bias": "decoder.blocks.12.conv2.bias",
349
+ "decoder.up_blocks.2.resnets.1.conv2.weight": "decoder.blocks.12.conv2.weight",
350
+ "decoder.up_blocks.2.resnets.1.norm1.bias": "decoder.blocks.12.norm1.bias",
351
+ "decoder.up_blocks.2.resnets.1.norm1.weight": "decoder.blocks.12.norm1.weight",
352
+ "decoder.up_blocks.2.resnets.1.norm2.bias": "decoder.blocks.12.norm2.bias",
353
+ "decoder.up_blocks.2.resnets.1.norm2.weight": "decoder.blocks.12.norm2.weight",
354
+ "decoder.up_blocks.2.resnets.2.conv1.bias": "decoder.blocks.13.conv1.bias",
355
+ "decoder.up_blocks.2.resnets.2.conv1.weight": "decoder.blocks.13.conv1.weight",
356
+ "decoder.up_blocks.2.resnets.2.conv2.bias": "decoder.blocks.13.conv2.bias",
357
+ "decoder.up_blocks.2.resnets.2.conv2.weight": "decoder.blocks.13.conv2.weight",
358
+ "decoder.up_blocks.2.resnets.2.norm1.bias": "decoder.blocks.13.norm1.bias",
359
+ "decoder.up_blocks.2.resnets.2.norm1.weight": "decoder.blocks.13.norm1.weight",
360
+ "decoder.up_blocks.2.resnets.2.norm2.bias": "decoder.blocks.13.norm2.bias",
361
+ "decoder.up_blocks.2.resnets.2.norm2.weight": "decoder.blocks.13.norm2.weight",
362
+ "decoder.up_blocks.2.upsamplers.0.conv.bias": "decoder.blocks.14.conv.bias",
363
+ "decoder.up_blocks.2.upsamplers.0.conv.weight": "decoder.blocks.14.conv.weight",
364
+ "decoder.up_blocks.3.resnets.0.conv1.bias": "decoder.blocks.15.conv1.bias",
365
+ "decoder.up_blocks.3.resnets.0.conv1.weight": "decoder.blocks.15.conv1.weight",
366
+ "decoder.up_blocks.3.resnets.0.conv2.bias": "decoder.blocks.15.conv2.bias",
367
+ "decoder.up_blocks.3.resnets.0.conv2.weight": "decoder.blocks.15.conv2.weight",
368
+ "decoder.up_blocks.3.resnets.0.conv_shortcut.bias": "decoder.blocks.15.conv_shortcut.bias",
369
+ "decoder.up_blocks.3.resnets.0.conv_shortcut.weight": "decoder.blocks.15.conv_shortcut.weight",
370
+ "decoder.up_blocks.3.resnets.0.norm1.bias": "decoder.blocks.15.norm1.bias",
371
+ "decoder.up_blocks.3.resnets.0.norm1.weight": "decoder.blocks.15.norm1.weight",
372
+ "decoder.up_blocks.3.resnets.0.norm2.bias": "decoder.blocks.15.norm2.bias",
373
+ "decoder.up_blocks.3.resnets.0.norm2.weight": "decoder.blocks.15.norm2.weight",
374
+ "decoder.up_blocks.3.resnets.1.conv1.bias": "decoder.blocks.16.conv1.bias",
375
+ "decoder.up_blocks.3.resnets.1.conv1.weight": "decoder.blocks.16.conv1.weight",
376
+ "decoder.up_blocks.3.resnets.1.conv2.bias": "decoder.blocks.16.conv2.bias",
377
+ "decoder.up_blocks.3.resnets.1.conv2.weight": "decoder.blocks.16.conv2.weight",
378
+ "decoder.up_blocks.3.resnets.1.norm1.bias": "decoder.blocks.16.norm1.bias",
379
+ "decoder.up_blocks.3.resnets.1.norm1.weight": "decoder.blocks.16.norm1.weight",
380
+ "decoder.up_blocks.3.resnets.1.norm2.bias": "decoder.blocks.16.norm2.bias",
381
+ "decoder.up_blocks.3.resnets.1.norm2.weight": "decoder.blocks.16.norm2.weight",
382
+ "decoder.up_blocks.3.resnets.2.conv1.bias": "decoder.blocks.17.conv1.bias",
383
+ "decoder.up_blocks.3.resnets.2.conv1.weight": "decoder.blocks.17.conv1.weight",
384
+ "decoder.up_blocks.3.resnets.2.conv2.bias": "decoder.blocks.17.conv2.bias",
385
+ "decoder.up_blocks.3.resnets.2.conv2.weight": "decoder.blocks.17.conv2.weight",
386
+ "decoder.up_blocks.3.resnets.2.norm1.bias": "decoder.blocks.17.norm1.bias",
387
+ "decoder.up_blocks.3.resnets.2.norm1.weight": "decoder.blocks.17.norm1.weight",
388
+ "decoder.up_blocks.3.resnets.2.norm2.bias": "decoder.blocks.17.norm2.bias",
389
+ "decoder.up_blocks.3.resnets.2.norm2.weight": "decoder.blocks.17.norm2.weight",
390
+ "encoder.conv_in.bias": "encoder.conv_in.bias",
391
+ "encoder.conv_in.weight": "encoder.conv_in.weight",
392
+ "encoder.conv_out.bias": "encoder.conv_out.bias",
393
+ "encoder.conv_out.weight": "encoder.conv_out.weight",
394
+ "encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias",
395
+ "encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight",
396
+ "encoder.down_blocks.0.resnets.0.conv1.bias": "encoder.blocks.0.conv1.bias",
397
+ "encoder.down_blocks.0.resnets.0.conv1.weight": "encoder.blocks.0.conv1.weight",
398
+ "encoder.down_blocks.0.resnets.0.conv2.bias": "encoder.blocks.0.conv2.bias",
399
+ "encoder.down_blocks.0.resnets.0.conv2.weight": "encoder.blocks.0.conv2.weight",
400
+ "encoder.down_blocks.0.resnets.0.norm1.bias": "encoder.blocks.0.norm1.bias",
401
+ "encoder.down_blocks.0.resnets.0.norm1.weight": "encoder.blocks.0.norm1.weight",
402
+ "encoder.down_blocks.0.resnets.0.norm2.bias": "encoder.blocks.0.norm2.bias",
403
+ "encoder.down_blocks.0.resnets.0.norm2.weight": "encoder.blocks.0.norm2.weight",
404
+ "encoder.down_blocks.0.resnets.1.conv1.bias": "encoder.blocks.1.conv1.bias",
405
+ "encoder.down_blocks.0.resnets.1.conv1.weight": "encoder.blocks.1.conv1.weight",
406
+ "encoder.down_blocks.0.resnets.1.conv2.bias": "encoder.blocks.1.conv2.bias",
407
+ "encoder.down_blocks.0.resnets.1.conv2.weight": "encoder.blocks.1.conv2.weight",
408
+ "encoder.down_blocks.0.resnets.1.norm1.bias": "encoder.blocks.1.norm1.bias",
409
+ "encoder.down_blocks.0.resnets.1.norm1.weight": "encoder.blocks.1.norm1.weight",
410
+ "encoder.down_blocks.0.resnets.1.norm2.bias": "encoder.blocks.1.norm2.bias",
411
+ "encoder.down_blocks.0.resnets.1.norm2.weight": "encoder.blocks.1.norm2.weight",
412
+ "encoder.down_blocks.0.downsamplers.0.conv.bias": "encoder.blocks.2.conv.bias",
413
+ "encoder.down_blocks.0.downsamplers.0.conv.weight": "encoder.blocks.2.conv.weight",
414
+ "encoder.down_blocks.1.resnets.0.conv1.bias": "encoder.blocks.3.conv1.bias",
415
+ "encoder.down_blocks.1.resnets.0.conv1.weight": "encoder.blocks.3.conv1.weight",
416
+ "encoder.down_blocks.1.resnets.0.conv2.bias": "encoder.blocks.3.conv2.bias",
417
+ "encoder.down_blocks.1.resnets.0.conv2.weight": "encoder.blocks.3.conv2.weight",
418
+ "encoder.down_blocks.1.resnets.0.conv_shortcut.bias": "encoder.blocks.3.conv_shortcut.bias",
419
+ "encoder.down_blocks.1.resnets.0.conv_shortcut.weight": "encoder.blocks.3.conv_shortcut.weight",
420
+ "encoder.down_blocks.1.resnets.0.norm1.bias": "encoder.blocks.3.norm1.bias",
421
+ "encoder.down_blocks.1.resnets.0.norm1.weight": "encoder.blocks.3.norm1.weight",
422
+ "encoder.down_blocks.1.resnets.0.norm2.bias": "encoder.blocks.3.norm2.bias",
423
+ "encoder.down_blocks.1.resnets.0.norm2.weight": "encoder.blocks.3.norm2.weight",
424
+ "encoder.down_blocks.1.resnets.1.conv1.bias": "encoder.blocks.4.conv1.bias",
425
+ "encoder.down_blocks.1.resnets.1.conv1.weight": "encoder.blocks.4.conv1.weight",
426
+ "encoder.down_blocks.1.resnets.1.conv2.bias": "encoder.blocks.4.conv2.bias",
427
+ "encoder.down_blocks.1.resnets.1.conv2.weight": "encoder.blocks.4.conv2.weight",
428
+ "encoder.down_blocks.1.resnets.1.norm1.bias": "encoder.blocks.4.norm1.bias",
429
+ "encoder.down_blocks.1.resnets.1.norm1.weight": "encoder.blocks.4.norm1.weight",
430
+ "encoder.down_blocks.1.resnets.1.norm2.bias": "encoder.blocks.4.norm2.bias",
431
+ "encoder.down_blocks.1.resnets.1.norm2.weight": "encoder.blocks.4.norm2.weight",
432
+ "encoder.down_blocks.1.downsamplers.0.conv.bias": "encoder.blocks.5.conv.bias",
433
+ "encoder.down_blocks.1.downsamplers.0.conv.weight": "encoder.blocks.5.conv.weight",
434
+ "encoder.down_blocks.2.resnets.0.conv1.bias": "encoder.blocks.6.conv1.bias",
435
+ "encoder.down_blocks.2.resnets.0.conv1.weight": "encoder.blocks.6.conv1.weight",
436
+ "encoder.down_blocks.2.resnets.0.conv2.bias": "encoder.blocks.6.conv2.bias",
437
+ "encoder.down_blocks.2.resnets.0.conv2.weight": "encoder.blocks.6.conv2.weight",
438
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.bias": "encoder.blocks.6.conv_shortcut.bias",
439
+ "encoder.down_blocks.2.resnets.0.conv_shortcut.weight": "encoder.blocks.6.conv_shortcut.weight",
440
+ "encoder.down_blocks.2.resnets.0.norm1.bias": "encoder.blocks.6.norm1.bias",
441
+ "encoder.down_blocks.2.resnets.0.norm1.weight": "encoder.blocks.6.norm1.weight",
442
+ "encoder.down_blocks.2.resnets.0.norm2.bias": "encoder.blocks.6.norm2.bias",
443
+ "encoder.down_blocks.2.resnets.0.norm2.weight": "encoder.blocks.6.norm2.weight",
444
+ "encoder.down_blocks.2.resnets.1.conv1.bias": "encoder.blocks.7.conv1.bias",
445
+ "encoder.down_blocks.2.resnets.1.conv1.weight": "encoder.blocks.7.conv1.weight",
446
+ "encoder.down_blocks.2.resnets.1.conv2.bias": "encoder.blocks.7.conv2.bias",
447
+ "encoder.down_blocks.2.resnets.1.conv2.weight": "encoder.blocks.7.conv2.weight",
448
+ "encoder.down_blocks.2.resnets.1.norm1.bias": "encoder.blocks.7.norm1.bias",
449
+ "encoder.down_blocks.2.resnets.1.norm1.weight": "encoder.blocks.7.norm1.weight",
450
+ "encoder.down_blocks.2.resnets.1.norm2.bias": "encoder.blocks.7.norm2.bias",
451
+ "encoder.down_blocks.2.resnets.1.norm2.weight": "encoder.blocks.7.norm2.weight",
452
+ "encoder.down_blocks.2.downsamplers.0.conv.bias": "encoder.blocks.8.conv.bias",
453
+ "encoder.down_blocks.2.downsamplers.0.conv.weight": "encoder.blocks.8.conv.weight",
454
+ "encoder.down_blocks.3.resnets.0.conv1.bias": "encoder.blocks.9.conv1.bias",
455
+ "encoder.down_blocks.3.resnets.0.conv1.weight": "encoder.blocks.9.conv1.weight",
456
+ "encoder.down_blocks.3.resnets.0.conv2.bias": "encoder.blocks.9.conv2.bias",
457
+ "encoder.down_blocks.3.resnets.0.conv2.weight": "encoder.blocks.9.conv2.weight",
458
+ "encoder.down_blocks.3.resnets.0.norm1.bias": "encoder.blocks.9.norm1.bias",
459
+ "encoder.down_blocks.3.resnets.0.norm1.weight": "encoder.blocks.9.norm1.weight",
460
+ "encoder.down_blocks.3.resnets.0.norm2.bias": "encoder.blocks.9.norm2.bias",
461
+ "encoder.down_blocks.3.resnets.0.norm2.weight": "encoder.blocks.9.norm2.weight",
462
+ "encoder.down_blocks.3.resnets.1.conv1.bias": "encoder.blocks.10.conv1.bias",
463
+ "encoder.down_blocks.3.resnets.1.conv1.weight": "encoder.blocks.10.conv1.weight",
464
+ "encoder.down_blocks.3.resnets.1.conv2.bias": "encoder.blocks.10.conv2.bias",
465
+ "encoder.down_blocks.3.resnets.1.conv2.weight": "encoder.blocks.10.conv2.weight",
466
+ "encoder.down_blocks.3.resnets.1.norm1.bias": "encoder.blocks.10.norm1.bias",
467
+ "encoder.down_blocks.3.resnets.1.norm1.weight": "encoder.blocks.10.norm1.weight",
468
+ "encoder.down_blocks.3.resnets.1.norm2.bias": "encoder.blocks.10.norm2.bias",
469
+ "encoder.down_blocks.3.resnets.1.norm2.weight": "encoder.blocks.10.norm2.weight",
470
+ "encoder.mid_block.attentions.0.to_k.bias": "encoder.blocks.12.transformer_blocks.0.to_k.bias",
471
+ "encoder.mid_block.attentions.0.to_k.weight": "encoder.blocks.12.transformer_blocks.0.to_k.weight",
472
+ "encoder.mid_block.attentions.0.group_norm.bias": "encoder.blocks.12.norm.bias",
473
+ "encoder.mid_block.attentions.0.group_norm.weight": "encoder.blocks.12.norm.weight",
474
+ "encoder.mid_block.attentions.0.to_out.0.bias": "encoder.blocks.12.transformer_blocks.0.to_out.bias",
475
+ "encoder.mid_block.attentions.0.to_out.0.weight": "encoder.blocks.12.transformer_blocks.0.to_out.weight",
476
+ "encoder.mid_block.attentions.0.to_q.bias": "encoder.blocks.12.transformer_blocks.0.to_q.bias",
477
+ "encoder.mid_block.attentions.0.to_q.weight": "encoder.blocks.12.transformer_blocks.0.to_q.weight",
478
+ "encoder.mid_block.attentions.0.to_v.bias": "encoder.blocks.12.transformer_blocks.0.to_v.bias",
479
+ "encoder.mid_block.attentions.0.to_v.weight": "encoder.blocks.12.transformer_blocks.0.to_v.weight",
480
+ "encoder.mid_block.resnets.0.conv1.bias": "encoder.blocks.11.conv1.bias",
481
+ "encoder.mid_block.resnets.0.conv1.weight": "encoder.blocks.11.conv1.weight",
482
+ "encoder.mid_block.resnets.0.conv2.bias": "encoder.blocks.11.conv2.bias",
483
+ "encoder.mid_block.resnets.0.conv2.weight": "encoder.blocks.11.conv2.weight",
484
+ "encoder.mid_block.resnets.0.norm1.bias": "encoder.blocks.11.norm1.bias",
485
+ "encoder.mid_block.resnets.0.norm1.weight": "encoder.blocks.11.norm1.weight",
486
+ "encoder.mid_block.resnets.0.norm2.bias": "encoder.blocks.11.norm2.bias",
487
+ "encoder.mid_block.resnets.0.norm2.weight": "encoder.blocks.11.norm2.weight",
488
+ "encoder.mid_block.resnets.1.conv1.bias": "encoder.blocks.13.conv1.bias",
489
+ "encoder.mid_block.resnets.1.conv1.weight": "encoder.blocks.13.conv1.weight",
490
+ "encoder.mid_block.resnets.1.conv2.bias": "encoder.blocks.13.conv2.bias",
491
+ "encoder.mid_block.resnets.1.conv2.weight": "encoder.blocks.13.conv2.weight",
492
+ "encoder.mid_block.resnets.1.norm1.bias": "encoder.blocks.13.norm1.bias",
493
+ "encoder.mid_block.resnets.1.norm1.weight": "encoder.blocks.13.norm1.weight",
494
+ "encoder.mid_block.resnets.1.norm2.bias": "encoder.blocks.13.norm2.bias",
495
+ "encoder.mid_block.resnets.1.norm2.weight": "encoder.blocks.13.norm2.weight"
248
496
  }
249
497
  }
250
498
  }
@@ -0,0 +1,41 @@
1
+ {
2
+ "diffusers": {
3
+ "global_rename_dict": {
4
+ "patch_embedding": "patch_embedding",
5
+ "condition_embedder.text_embedder.linear_1": "text_embedding.0",
6
+ "condition_embedder.text_embedder.linear_2": "text_embedding.2",
7
+ "condition_embedder.time_embedder.linear_1": "time_embedding.0",
8
+ "condition_embedder.time_embedder.linear_2": "time_embedding.2",
9
+ "condition_embedder.time_proj": "time_projection.1",
10
+ "condition_embedder.image_embedder.norm1": "img_emb.proj.0",
11
+ "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1",
12
+ "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3",
13
+ "condition_embedder.image_embedder.norm2": "img_emb.proj.4",
14
+ "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos",
15
+ "proj_out": "head.head",
16
+ "scale_shift_table": "head.modulation"
17
+ },
18
+ "rename_dict": {
19
+ "attn1.to_q": "self_attn.q",
20
+ "attn1.to_k": "self_attn.k",
21
+ "attn1.to_v": "self_attn.v",
22
+ "attn1.to_out.0": "self_attn.o",
23
+ "attn1.norm_q": "self_attn.norm_q",
24
+ "attn1.norm_k": "self_attn.norm_k",
25
+ "to_gate_compress": "self_attn.gate_compress",
26
+ "attn2.to_q": "cross_attn.q",
27
+ "attn2.to_k": "cross_attn.k",
28
+ "attn2.to_v": "cross_attn.v",
29
+ "attn2.to_out.0": "cross_attn.o",
30
+ "attn2.norm_q": "cross_attn.norm_q",
31
+ "attn2.norm_k": "cross_attn.norm_k",
32
+ "attn2.add_k_proj": "cross_attn.k_img",
33
+ "attn2.add_v_proj": "cross_attn.v_img",
34
+ "attn2.norm_added_k": "cross_attn.norm_k_img",
35
+ "norm2": "norm3",
36
+ "ffn.net.0.proj": "ffn.0",
37
+ "ffn.net.2": "ffn.2",
38
+ "scale_shift_table": "modulation"
39
+ }
40
+ }
41
+ }
@@ -17,8 +17,17 @@ from .pipeline import (
17
17
  WanStateDicts,
18
18
  WanS2VStateDicts,
19
19
  QwenImageStateDicts,
20
+ AttnImpl,
21
+ SpargeAttentionParams,
22
+ VideoSparseAttentionParams,
23
+ LoraConfig,
24
+ )
25
+ from .controlnet import (
26
+ ControlType,
27
+ ControlNetParams,
28
+ QwenImageControlType,
29
+ QwenImageControlNetParams,
20
30
  )
21
- from .controlnet import ControlType, ControlNetParams
22
31
 
23
32
  __all__ = [
24
33
  "BaseConfig",
@@ -39,6 +48,12 @@ __all__ = [
39
48
  "WanStateDicts",
40
49
  "WanS2VStateDicts",
41
50
  "QwenImageStateDicts",
51
+ "AttnImpl",
52
+ "SpargeAttentionParams",
53
+ "VideoSparseAttentionParams",
54
+ "LoraConfig",
42
55
  "ControlType",
43
56
  "ControlNetParams",
57
+ "QwenImageControlType",
58
+ "QwenImageControlNetParams",
44
59
  ]
@@ -34,3 +34,16 @@ class ControlNetParams:
34
34
  control_start: float = 0
35
35
  control_end: float = 1
36
36
  processor_name: Optional[str] = None # only used for sdxl controlnet union now
37
+
38
+
39
+ class QwenImageControlType(Enum):
40
+ eligen = "eligen"
41
+ in_context = "in_context"
42
+
43
+
44
+ @dataclass
45
+ class QwenImageControlNetParams:
46
+ image: ImageType
47
+ model: str
48
+ control_type: QwenImageControlType
49
+ scale: float = 1.0
@@ -1,5 +1,6 @@
1
1
  import os
2
2
  import torch
3
+ from enum import Enum
3
4
  from dataclasses import dataclass, field
4
5
  from typing import List, Dict, Tuple, Optional
5
6
 
@@ -19,14 +20,36 @@ class BaseConfig:
19
20
  offload_to_disk: bool = False
20
21
 
21
22
 
23
+ class AttnImpl(Enum):
24
+ AUTO = "auto"
25
+ EAGER = "eager" # Native Attention
26
+ FA2 = "fa2" # Flash Attention 2
27
+ FA3 = "fa3" # Flash Attention 3
28
+ FA3_FP8 = "fa3_fp8" # Flash Attention 3 with FP8
29
+ XFORMERS = "xformers" # XFormers
30
+ SDPA = "sdpa" # Scaled Dot Product Attention
31
+ SAGE = "sage" # Sage Attention
32
+ SPARGE = "sparge" # Sparge Attention
33
+ VSA = "vsa" # Video Sparse Attention
34
+
35
+
36
+ @dataclass
37
+ class SpargeAttentionParams:
38
+ smooth_k: bool = True
39
+ cdfthreshd: float = 0.6
40
+ simthreshd1: float = 0.98
41
+ pvthreshd: float = 50.0
42
+
43
+
44
+ @dataclass
45
+ class VideoSparseAttentionParams:
46
+ sparsity: float = 0.9
47
+
48
+
22
49
  @dataclass
23
50
  class AttentionConfig:
24
- dit_attn_impl: str = "auto"
25
- # Sparge Attention
26
- sparge_smooth_k: bool = True
27
- sparge_cdfthreshd: float = 0.6
28
- sparge_simthreshd1: float = 0.98
29
- sparge_pvthreshd: float = 50.0
51
+ dit_attn_impl: AttnImpl = AttnImpl.AUTO
52
+ attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None
30
53
 
31
54
 
32
55
  @dataclass
@@ -221,14 +244,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
221
244
  encoder_dtype: torch.dtype = torch.bfloat16
222
245
  vae_dtype: torch.dtype = torch.float32
223
246
 
247
+ load_encoder: bool = True
248
+
224
249
  # override OptimizationConfig
225
250
  fbcache_relative_l1_threshold = 0.009
226
251
 
227
- # override BaseConfig
228
- vae_tiled: bool = True
229
- vae_tile_size: Tuple[int, int] = (34, 34)
230
- vae_tile_stride: Tuple[int, int] = (18, 16)
231
-
232
252
  @classmethod
233
253
  def basic_config(
234
254
  cls,
@@ -352,3 +372,9 @@ def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig |
352
372
  config.tp_degree = 1
353
373
  else:
354
374
  raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
375
+
376
+
377
+ @dataclass
378
+ class LoraConfig:
379
+ scale: float
380
+ scheduler_config: Optional[Dict] = None
@@ -57,7 +57,7 @@ class PreTrainedModel(nn.Module):
57
57
  def get_tp_plan(self):
58
58
  raise NotImplementedError(f"{self.__class__.__name__} does not support TP")
59
59
 
60
- def get_fsdp_modules(self):
60
+ def get_fsdp_module_cls(self):
61
61
  raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
62
62
 
63
63