diffsynth-engine 0.6.1.dev32__py3-none-any.whl → 0.6.1.dev33__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
diffsynth_engine/configs/pipeline.py

@@ -251,11 +251,14 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
-    # svd
-    use_nunchaku: Optional[bool] = field(default=None, init=False)
-    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
-    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
-
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
+    # for 2511
+    use_zero_cond_t: bool = False
+
     @classmethod
     def basic_config(
         cls,
@@ -266,6 +269,7 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
         offload_to_disk: bool = False,
+        use_zero_cond_t: bool = False,
    ) -> "QwenImagePipelineConfig":
        return cls(
            model_path=model_path,
@@ -277,6 +281,7 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
            use_fsdp=True if parallelism > 1 else False,
            offload_mode=offload_mode,
            offload_to_disk=offload_to_disk,
+           use_zero_cond_t=use_zero_cond_t,
        )
 
    def __post_init__(self):
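Taken together, these three hunks add a `use_zero_cond_t` dataclass field and thread it through `basic_config`. A minimal usage sketch follows; the model path is a placeholder and the import path is inferred from the RECORD entry for `diffsynth_engine/configs/pipeline.py`, so treat both as assumptions:

```python
# Hypothetical usage sketch: only the use_zero_cond_t kwarg is new in this release.
from diffsynth_engine.configs.pipeline import QwenImagePipelineConfig  # path assumed from RECORD

config = QwenImagePipelineConfig.basic_config(
    model_path="/path/to/qwen-image-dit.safetensors",  # placeholder path
    offload_mode=None,
    use_zero_cond_t=True,  # opt in to the "2511" zero-conditioning-timestep behavior
)
assert config.use_zero_cond_t
```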
diffsynth_engine/models/basic/attention.py

@@ -94,6 +94,7 @@ if SPARGE_ATTN_AVAILABLE:
        )
        return out.transpose(1, 2)
 
+
 if AITER_AVAILABLE:
     from aiter import flash_attn_func as aiter_flash_attn
     from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
@@ -203,7 +204,7 @@ def attention(
        )
    if attn_mask is not None:
        raise RuntimeError("aiter_flash_attn does not support attention mask")
-   if attn_impl == "aiter" :
+   if attn_impl == "aiter":
        return aiter_flash_attn(q, k, v, softmax_scale=scale)
    else:
        origin_dtype = q.dtype
@@ -211,7 +212,7 @@ def attention(
        k = k.to(dtype=DTYPE_FP8)
        v = v.to(dtype=DTYPE_FP8)
        out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
-   return out.to(dtype=origin_dtype)
+       return out.to(dtype=origin_dtype)
    if attn_impl == "fa2":
        return flash_attn2(q, k, v, softmax_scale=scale)
    if attn_impl == "xformers":
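The second change here is whitespace-only in the rendered diff; the visible effect is that `return out.to(dtype=origin_dtype)` now sits inside the fp8 `else` branch. That branch follows a common downcast/compute/upcast pattern. Below is a standalone sketch of the pattern with plain torch casts standing in for the aiter kernel; the fp8 dtype choice is an assumption, since the library defines its own `DTYPE_FP8`:

```python
import torch

DTYPE_FP8 = torch.float8_e4m3fn  # assumed fp8 variant

def fp8_attention_roundtrip(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    origin_dtype = q.dtype
    q, k, v = (t.to(dtype=DTYPE_FP8) for t in (q, k, v))  # lossy downcast for the fast kernel
    out = q.to(torch.float32)  # stand-in for aiter_flash_attn_fp8(q, k, v, ...)
    return out.to(dtype=origin_dtype)  # restore the caller-visible dtype
```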
diffsynth_engine/models/qwen_image/qwen_image_dit.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from typing import Any, Dict, List, Tuple, Union, Optional
 from einops import rearrange
+from math import prod
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
@@ -243,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
+       zero_cond_t: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
@@ -275,10 +277,30 @@ class QwenImageTransformerBlock(nn.Module):
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
        self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype)
+       self.zero_cond_t = zero_cond_t
 
-   def _modulate(self, x, mod_params):
+   def _modulate(self, x, mod_params, index=None):
        shift, scale, gate = mod_params.chunk(3, dim=-1)
-       return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+       if index is not None:
+           actual_batch = shift.size(0) // 2
+           shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]
+           scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+           gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+           index_expanded = index.unsqueeze(-1)
+           shift_0_exp = shift_0.unsqueeze(1)
+           shift_1_exp = shift_1.unsqueeze(1)
+           scale_0_exp = scale_0.unsqueeze(1)
+           scale_1_exp = scale_1.unsqueeze(1)
+           gate_0_exp = gate_0.unsqueeze(1)
+           gate_1_exp = gate_1.unsqueeze(1)
+           shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+           scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+           gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+       else:
+           shift_result = shift.unsqueeze(1)
+           scale_result = scale.unsqueeze(1)
+           gate_result = gate.unsqueeze(1)
+       return x * (1 + scale_result) + shift_result, gate_result
 
    def forward(
        self,
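The indexed branch above selects, per image token, between two modulation parameter sets stacked along the batch dimension: the first half comes from the real timestep, the second half from a zeroed timestep (see the `timestep = torch.cat([timestep, timestep * 0], dim=0)` hunk later in this diff). A toy-shape demo of that selection, not library code:

```python
import torch

B, L, D = 1, 5, 4
shift = torch.randn(2 * B, D)            # rows: [real-t params | zero-t params]
index = torch.tensor([[0, 0, 0, 1, 1]])  # per-token choice: 0 = real t, 1 = zero t

shift_0, shift_1 = shift[:B], shift[B:]
picked = torch.where(
    index.unsqueeze(-1) == 0,  # [B, L, 1], broadcasts over the channel dim
    shift_0.unsqueeze(1),      # [B, 1, D]
    shift_1.unsqueeze(1),      # [B, 1, D]
)
assert picked.shape == (B, L, D)
assert torch.equal(picked[0, 0], shift_0[0]) and torch.equal(picked[0, 4], shift_1[0])
```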
@@ -288,12 +310,15 @@ class QwenImageTransformerBlock(nn.Module):
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        attn_kwargs: Optional[Dict[str, Any]] = None,
+       modulate_index: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+       if self.zero_cond_t:
+           temb = torch.chunk(temb, 2, dim=0)[0]
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
 
        img_normed = self.img_norm1(image)
-       img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+       img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, modulate_index)
 
        txt_normed = self.txt_norm1(text)
        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -305,12 +330,11 @@ class QwenImageTransformerBlock(nn.Module):
            attn_mask=attn_mask,
            attn_kwargs=attn_kwargs,
        )
-
        image = image + img_gate * img_attn_out
        text = text + txt_gate * txt_attn_out
 
        img_normed_2 = self.img_norm2(image)
-       img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+       img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, modulate_index)
 
        txt_normed_2 = self.txt_norm2(text)
        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
@@ -331,6 +355,7 @@ class QwenImageDiT(PreTrainedModel):
    def __init__(
        self,
        num_layers: int = 60,
+       zero_cond_t: bool = False,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
@@ -351,6 +376,7 @@ class QwenImageDiT(PreTrainedModel):
                dim=3072,
                num_attention_heads=24,
                attention_head_dim=128,
+               zero_cond_t=zero_cond_t,
                device=device,
                dtype=dtype,
            )
@@ -359,6 +385,7 @@ class QwenImageDiT(PreTrainedModel):
        )
        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+       self.zero_cond_t = zero_cond_t
 
    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
@@ -461,6 +488,9 @@ class QwenImageDiT(PreTrainedModel):
                use_cfg=use_cfg,
            ),
        ):
+           if self.zero_cond_t:
+               timestep = torch.cat([timestep, timestep * 0], dim=0)
+           modulate_index = None
            conditioning = self.time_text_embed(timestep, image.dtype)
            video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
            text_seq_len = text_seq_lens.max().item()
@@ -478,7 +508,12 @@ class QwenImageDiT(PreTrainedModel):
            img = self.patchify(img)
            image = torch.cat([image, img], dim=1)
            video_fhw += [(1, edit_h // 2, edit_w // 2)]
-
+       if self.zero_cond_t:
+           modulate_index = torch.tensor(
+               [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [video_fhw]],
+               device=timestep.device,
+               dtype=torch.int,
+           )
        rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
        image = self.img_in(image)
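The comprehension above is easier to read with concrete numbers. With one noisy latent plus one edit-condition image, `modulate_index` marks each token of the concatenated sequence with the modulation half it should use. A worked example with made-up sizes:

```python
from math import prod

# (frame, H/2, W/2) per image: the noisy latent, then one edit-condition image
video_fhw = [(1, 32, 32), (1, 24, 24)]
index_row = [0] * prod(video_fhw[0]) + [1] * sum(prod(s) for s in video_fhw[1:])
assert len(index_row) == 32 * 32 + 24 * 24 == 1600
# Latent tokens (index 0) get the real timestep's modulation; edit-condition
# tokens (index 1) get the zero-timestep modulation, matching
# timestep = torch.cat([timestep, timestep * 0], dim=0) earlier in forward().
```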
@@ -510,7 +545,10 @@ class QwenImageDiT(PreTrainedModel):
                rotary_emb=rotary_emb,
                attn_mask=attn_mask,
                attn_kwargs=attn_kwargs,
+               modulate_index=modulate_index,
            )
+       if self.zero_cond_t:
+           conditioning = conditioning.chunk(2, dim=0)[0]
        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)
        (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
@@ -527,8 +565,9 @@ class QwenImageDiT(PreTrainedModel):
        device: str,
        dtype: torch.dtype,
        num_layers: int = 60,
+       use_zero_cond_t: bool = False,
    ):
-       model = cls(device="meta", dtype=dtype, num_layers=num_layers)
+       model = cls(device="meta", dtype=dtype, num_layers=num_layers, zero_cond_t=use_zero_cond_t)
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
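`from_state_dict` builds the model on the meta device and then adopts the loaded tensors with `assign=True`, which avoids allocating uninitialized weights only to overwrite them. A self-contained sketch of that pattern, using a context manager where the library passes `device="meta"` explicitly:

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(3072, 64)  # structure only; no parameter storage allocated

state_dict = {"weight": torch.zeros(64, 3072), "bias": torch.zeros(64)}
layer.load_state_dict(state_dict, assign=True)  # adopt tensors instead of copying
assert layer.weight.device.type == "cpu"  # parameters now live where the loaded tensors do
```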
diffsynth_engine/pipelines/qwen_image.py

@@ -2,7 +2,6 @@ import json
 import torch
 import torch.distributed as dist
 import math
-import sys
 from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
@@ -45,7 +44,6 @@ from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE
 logger = logging.get_logger(__name__)
 
 
-
 class QwenImageLoRAConverter(LoRAStateDictConverter):
    def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        dit_dict = {}
@@ -205,7 +203,7 @@ class QwenImagePipeline(BasePipeline):
            else:
                config.use_nunchaku_attn = False
                logger.info("Disable nunchaku attention quantization.")
-
+
        else:
            config.use_nunchaku = False
 
@@ -318,6 +316,7 @@ class QwenImagePipeline(BasePipeline):
        elif config.use_nunchaku:
            if not NUNCHAKU_AVAILABLE:
                from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+
                raise ImportError(NUNCHAKU_IMPORT_ERROR)
 
            from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -337,6 +336,7 @@ class QwenImagePipeline(BasePipeline):
            state_dicts.model,
            device=("cpu" if config.use_fsdp else init_device),
            dtype=config.model_dtype,
+           use_zero_cond_t=config.use_zero_cond_t,
        )
        if config.use_fp8_linear and not config.use_nunchaku:
            enable_fp8_linear(dit)
@@ -704,7 +704,7 @@ class QwenImagePipeline(BasePipeline):
 
        context_latents = None
        for param in controlnet_params:
-           self.load_lora(param.model, param.scale, fused=False, save_original_weight=False)
+           self.load_lora(param.model, param.scale, fused=True, save_original_weight=False)
            if param.control_type == QwenImageControlType.in_context:
                width, height = param.image.size
                self.validate_image_size(height, width, minimum=64, multiple_of=16)
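The only behavioral change in this file is `fused=False` -> `fused=True` when loading ControlNet LoRA weights. Fusing merges the low-rank factors into the base weight once instead of applying them on every forward pass; the two are mathematically equivalent, as this sketch shows (a conceptual demo, not this library's `load_lora` internals):

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 8)                        # base weight
A, B = torch.randn(4, 8), torch.randn(8, 4)  # LoRA down/up factors, rank 4
x, scale = torch.randn(8), 0.8

unfused = W @ x + scale * (B @ (A @ x))  # factors applied at runtime
W_fused = W + scale * (B @ A)            # factors merged into the weight once
assert torch.allclose(unfused, W_fused @ x, atol=1e-5)
```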
diffsynth_engine-0.6.1.dev33.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev32
+Version: 0.6.1.dev33
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
diffsynth_engine-0.6.1.dev33.dist-info/RECORD

@@ -81,12 +81,12 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
 diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
 diffsynth_engine/configs/__init__.py,sha256=vSjJToEdq3JX7t81_z4nwNwIdD4bYnFjxnMZH7PXMKo,1309
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=7duSdoD0LIROtepsLW9PxYsK59p7qSv34BVz0k29vu4,13633
+diffsynth_engine/configs/pipeline.py,sha256=SLaxFd9mKuJgromrkXpJrsNGAGzMl51Twomc4Qo83Wc,13759
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diffsynth_engine/models/basic/attention.py,sha256=mvgk8LTqFwgtPdBeRv797IZNg9k7--X9wD92Hcr188c,15682
+diffsynth_engine/models/basic/attention.py,sha256=62Ar8_ydnn28F1qH9ueXtvISgNszQK3q8k14gCIXGEs,15681
 diffsynth_engine/models/basic/lora.py,sha256=Y6cBgrBsuDAP9FZz_fgK8vBi_EMg23saFIUSAsPIG-M,10670
 diffsynth_engine/models/basic/lora_nunchaku.py,sha256=7qhzGCzUIfDrwtWG0nspwdyZ7YUkaM4vMqzxZby2Zds,7510
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
@@ -111,7 +111,7 @@ diffsynth_engine/models/hunyuan3d/surface_extractor.py,sha256=b15mb1N4PYwAvDk1Gu
 diffsynth_engine/models/hunyuan3d/volume_decoder.py,sha256=sgflj1a8sIerqGSalBAVQOlyiIihkLOLXYysNbulCoQ,2355
 diffsynth_engine/models/qwen_image/__init__.py,sha256=_6f0LWaoLdDvD2CsjK2OzEIQryt9efge8DFS4_GUnHQ,582
 diffsynth_engine/models/qwen_image/qwen2_5_vl.py,sha256=Eu-r-c42t_q74Qpwz21ToCGHpvSi7VND4B1EI0e-ePA,57748
-diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=iJ-FinDyXa982Uao1is37bxUttyPu0Eldyd7qPJO_XQ,22582
+diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=JEyK_yOa0A5xaqlmxI3nfD7NdCaHuvLDA10aWVbnac4,24635
 diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py,sha256=LIv9X_BohKk5rcEzyl3ATLwd8MSoFX43wjkArQ68nq8,4828
 diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py,sha256=1y1BkPRrX4_RioKjM09D9f9PK9neug1nSGJka0D9bvM,13516
 diffsynth_engine/models/qwen_image/qwen_image_vae.py,sha256=eO7f4YqiYXfw7NncBNFTu-xEvdJ5uKY-SnfP15QY0tE,38443
@@ -146,7 +146,7 @@ diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHr
 diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
 diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
 diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
-diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
+diffsynth_engine/pipelines/qwen_image.py,sha256=lrqwF3fikgQouifb-8KwWCxQhNVZard_7buoJqxHD7s,35759
 diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
 diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
 diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
@@ -190,8 +190,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.dev32.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
-diffsynth_engine-0.6.1.dev32.dist-info/METADATA,sha256=ZEH2_1Zmgmk30J31qY1S0Ul9dD4rchav5AS3UclyCVg,1164
-diffsynth_engine-0.6.1.dev32.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-diffsynth_engine-0.6.1.dev32.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
-diffsynth_engine-0.6.1.dev32.dist-info/RECORD,,
+diffsynth_engine-0.6.1.dev33.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev33.dist-info/METADATA,sha256=pgyNkuwU3lMQA66waiIU3BVtw-7zN3s8pEvinWC_LpI,1164
+diffsynth_engine-0.6.1.dev33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev33.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev33.dist-info/RECORD,,