diffsynth-engine 0.6.1.dev33__py3-none-any.whl → 0.6.1.dev35__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (24)
  1. diffsynth_engine/__init__.py +4 -0
  2. diffsynth_engine/conf/models/z_image/qwen3_config.json +30 -0
  3. diffsynth_engine/conf/tokenizers/z_image/tokenizer/merges.txt +151388 -0
  4. diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer.json +757480 -0
  5. diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json +239 -0
  6. diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json +1 -0
  7. diffsynth_engine/configs/__init__.py +4 -0
  8. diffsynth_engine/configs/pipeline.py +44 -1
  9. diffsynth_engine/models/basic/attention.py +2 -2
  10. diffsynth_engine/models/qwen_image/qwen_image_dit.py +5 -5
  11. diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -1
  12. diffsynth_engine/models/z_image/__init__.py +11 -0
  13. diffsynth_engine/models/z_image/qwen3.py +124 -0
  14. diffsynth_engine/models/z_image/z_image_dit.py +602 -0
  15. diffsynth_engine/pipelines/__init__.py +2 -0
  16. diffsynth_engine/pipelines/qwen_image.py +4 -3
  17. diffsynth_engine/pipelines/z_image.py +377 -0
  18. diffsynth_engine/utils/constants.py +3 -0
  19. diffsynth_engine/utils/process_group.py +1 -1
  20. {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/METADATA +1 -1
  21. {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/RECORD +24 -15
  22. {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/WHEEL +0 -0
  23. {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/licenses/LICENSE +0 -0
  24. {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/top_level.txt +0 -0
diffsynth_engine/__init__.py

@@ -10,6 +10,7 @@ from .pipeline import (
     WanSpeech2VideoPipelineConfig,
     QwenImagePipelineConfig,
     HunyuanPipelineConfig,
+    ZImagePipelineConfig,
     BaseStateDicts,
     SDStateDicts,
     SDXLStateDicts,
@@ -17,6 +18,7 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    ZImageStateDicts,
     AttnImpl,
     SpargeAttentionParams,
     VideoSparseAttentionParams,
@@ -41,6 +43,7 @@ __all__ = [
     "WanSpeech2VideoPipelineConfig",
     "QwenImagePipelineConfig",
     "HunyuanPipelineConfig",
+    "ZImagePipelineConfig",
     "BaseStateDicts",
     "SDStateDicts",
     "SDXLStateDicts",
@@ -48,6 +51,7 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
+    "ZImageStateDicts",
     "AttnImpl",
     "SpargeAttentionParams",
     "VideoSparseAttentionParams",

diffsynth_engine/configs/pipeline.py

@@ -298,6 +298,42 @@ class HunyuanPipelineConfig(BaseConfig):
     image_encoder_dtype: torch.dtype = torch.float16
 
 
+@dataclass
+class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype = torch.float16
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_dtype: torch.dtype = torch.float16
+    encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    encoder_dtype: torch.dtype = torch.float16
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
+    ) -> "ZImagePipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            encoder_path=encoder_path,
+            vae_path=vae_path,
+            parallelism=parallelism,
+            use_cfg_parallel=True if parallelism > 1 else False,
+            use_fsdp=True if parallelism > 1 else False,
+            offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
 @dataclass
 class BaseStateDicts:
     pass
@@ -349,7 +385,14 @@ class QwenImageStateDicts:
     vae: Dict[str, torch.Tensor]
 
 
-def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig):
+@dataclass
+class ZImageStateDicts:
+    model: Dict[str, torch.Tensor]
+    encoder: Dict[str, torch.Tensor]
+    vae: Dict[str, torch.Tensor]
+
+
+def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
     assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
     config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
 
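For orientation, the new config mirrors the existing *PipelineConfig classes. A minimal usage sketch based only on the basic_config() signature added above; the checkpoint paths below are placeholders, not files shipped with the wheel:

from diffsynth_engine import ZImagePipelineConfig

# Hypothetical paths for illustration only.
config = ZImagePipelineConfig.basic_config(
    model_path="checkpoints/z_image_dit.safetensors",
    encoder_path="checkpoints/z_image_qwen3_encoder.safetensors",
    vae_path="checkpoints/z_image_vae.safetensors",
    device="cuda",
    parallelism=2,       # parallelism > 1 switches on use_cfg_parallel and use_fsdp
    offload_mode=None,
)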

diffsynth_engine/models/basic/attention.py

@@ -343,7 +343,7 @@ def long_context_attention(
             f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
         )
     if SDPA_AVAILABLE:
-        return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        return LongContextAttention(attn_type=AttnType.TORCH_EFFICIENT)(q, k, v, softmax_scale=scale)
     if FLASH_ATTN_2_AVAILABLE:
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     raise ValueError("No available long context attention implementation")
@@ -379,7 +379,7 @@ def long_context_attention(
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
-        return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
+        return LongContextAttention(attn_type=AttnType.TORCH_EFFICIENT)(q, k, v, softmax_scale=scale)
     if attn_impl == "sage":
        return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
     if attn_impl == "sparge":

diffsynth_engine/models/qwen_image/qwen_image_dit.py

@@ -286,16 +286,15 @@ class QwenImageTransformerBlock(nn.Module):
             shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]
             scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
             gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
-            index_expanded = index.unsqueeze(-1)
             shift_0_exp = shift_0.unsqueeze(1)
             shift_1_exp = shift_1.unsqueeze(1)
             scale_0_exp = scale_0.unsqueeze(1)
             scale_1_exp = scale_1.unsqueeze(1)
             gate_0_exp = gate_0.unsqueeze(1)
             gate_1_exp = gate_1.unsqueeze(1)
-            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
-            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
-            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+            shift_result = torch.where(index == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index == 0, gate_0_exp, gate_1_exp)
         else:
             shift_result = shift.unsqueeze(1)
             scale_result = scale.unsqueeze(1)
@@ -514,6 +513,7 @@ class QwenImageDiT(PreTrainedModel):
             device=timestep.device,
             dtype=torch.int,
         )
+        modulate_index = modulate_index.unsqueeze(-1)
         rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
         image = self.img_in(image)
@@ -535,7 +535,7 @@
 
         # warning: Eligen does not work with sequence parallel because long context attention does not support attention masks
         img_freqs, txt_freqs = rotary_emb
-        with sequence_parallel((image, text, img_freqs, txt_freqs), seq_dims=(1, 1, 0, 0)):
+        with sequence_parallel((image, text, img_freqs, txt_freqs, modulate_index), seq_dims=(1, 1, 0, 0, 1)):
             rotary_emb = (img_freqs, txt_freqs)
             for block in self.transformer_blocks:
                 text, image = block(
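This change moves the unsqueeze of the per-token index out of the per-block path into QwenImageDiT.forward, so the index tensor can be registered with sequence_parallel and split along its sequence dimension (seq_dim 1) together with the other activations. A toy sketch of the broadcasting pattern the block now relies on; the shapes are illustrative only, not the model's real dimensions:

import torch

batch, seq, dim = 1, 6, 4
# (batch, seq, 1): 0/1 marks which of the two modulation sets each token uses;
# the trailing unsqueeze now happens once in forward() instead of in every block.
index = torch.tensor([[0, 0, 0, 1, 1, 1]]).unsqueeze(-1)
shift_0_exp = torch.zeros(batch, 1, dim)   # modulation for index == 0, broadcast over seq
shift_1_exp = torch.ones(batch, 1, dim)    # modulation for index == 1
shift_result = torch.where(index == 0, shift_0_exp, shift_1_exp)  # -> (batch, seq, dim)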

diffsynth_engine/models/qwen_image/qwen_image_vae.py

@@ -685,7 +685,6 @@ class VideoVAE(nn.Module):
         x = patchify(x, patch_size=2 if self.in_channels == 12 else 1)
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
-
         for i in range(iter_):
             if i == 0:
                 out = self.encoder(x[:, :, :1, :, :], feat_cache=feat_cache)

diffsynth_engine/models/z_image/__init__.py

@@ -0,0 +1,11 @@
+from .qwen3 import (
+    Qwen3Model,
+    Qwen3Config,
+)
+from .z_image_dit import ZImageDiT
+
+__all__ = [
+    "Qwen3Model",
+    "Qwen3Config",
+    "ZImageDiT",
+]

diffsynth_engine/models/z_image/qwen3.py

@@ -0,0 +1,124 @@
+# modified from transformers.models.qwen3.modeling_qwen3
+import torch
+import torch.nn as nn
+from typing import Dict, Tuple, Optional
+
+from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.utils.cache import Cache, DynamicCache
+from diffsynth_engine.utils import logging
+
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3RMSNorm, Qwen3RotaryEmbedding
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.masking_utils import create_causal_mask
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3ModelStateDictConverter(StateDictConverter):
+    def __init__(self):
+        super().__init__()
+
+    def _from_diffusers(self, state_dict):
+        new_state_dict = {}
+        for key, param in state_dict.items():
+            if key.startswith("model."):
+                key = key[len("model.") :]
+            new_state_dict[key] = param
+        return new_state_dict
+
+    def convert(self, state_dict):
+        return self._from_diffusers(state_dict)
+
+
+class Qwen3Model(PreTrainedModel):
+    converter = Qwen3ModelStateDictConverter()
+
+    def __init__(self, config: Qwen3Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+        super().__init__()
+        # for causal_mask
+        config._attn_implementation = "sdpa"
+        self.config = config
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, device=device, dtype=dtype
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen3DecoderLayer(layer_idx=layer_idx, config=config).to(device=device, dtype=dtype)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Qwen3RMSNorm(config.hidden_size, config.rms_norm_eps).to(device=device, dtype=dtype)
+        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Qwen3Config,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        model = cls(config=config, device="meta", dtype=dtype)
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        all_hidden_states = []
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            seq_len = inputs_embeds.size(1)
+            cache_position = torch.arange(seq_len, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        hidden_states = inputs_embeds
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        all_hidden_states.append(hidden_states)
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                attention_mask=causal_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+            )
+            all_hidden_states.append(hidden_states)
+
+        hidden_states = self.norm(hidden_states)
+        return {
+            "last_hidden_state": hidden_states,
+            "past_key_values": past_key_values,
+            "hidden_states": all_hidden_states,
+        }
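A rough usage sketch of the encoder wrapper above. The tiny config values are made up for illustration; in the actual pipeline the config comes from conf/models/z_image/qwen3_config.json and the weights from the encoder checkpoint, and a transformers version that provides masking_utils.create_causal_mask is assumed:

import torch
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from diffsynth_engine.models.z_image import Qwen3Model

# Deliberately small, arbitrary config for a quick smoke test.
config = Qwen3Config(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)
model = Qwen3Model(config, device="cpu", dtype=torch.float32)
input_ids = torch.randint(0, 1000, (1, 8))
outputs = model(input_ids=input_ids)
last_hidden = outputs["last_hidden_state"]   # (1, 8, 64)
per_layer = outputs["hidden_states"]         # embedding output plus one tensor per decoder layer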