diffsynth-engine 0.6.1.dev34__py3-none-any.whl → 0.6.1.dev36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,7 @@ from .pipeline import (
     WanSpeech2VideoPipelineConfig,
     QwenImagePipelineConfig,
     HunyuanPipelineConfig,
+    ZImagePipelineConfig,
     BaseStateDicts,
     SDStateDicts,
     SDXLStateDicts,
@@ -17,6 +18,7 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    ZImageStateDicts,
     AttnImpl,
     SpargeAttentionParams,
     VideoSparseAttentionParams,
@@ -41,6 +43,7 @@ __all__ = [
     "WanSpeech2VideoPipelineConfig",
     "QwenImagePipelineConfig",
     "HunyuanPipelineConfig",
+    "ZImagePipelineConfig",
     "BaseStateDicts",
     "SDStateDicts",
     "SDXLStateDicts",
@@ -48,6 +51,7 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
+    "ZImageStateDicts",
     "AttnImpl",
     "SpargeAttentionParams",
     "VideoSparseAttentionParams",
@@ -298,6 +298,42 @@ class HunyuanPipelineConfig(BaseConfig):
     image_encoder_dtype: torch.dtype = torch.float16


+@dataclass
+class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype = torch.float16
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_dtype: torch.dtype = torch.float16
+    encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    encoder_dtype: torch.dtype = torch.float16
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
+    ) -> "ZImagePipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            encoder_path=encoder_path,
+            vae_path=vae_path,
+            parallelism=parallelism,
+            use_cfg_parallel=True if parallelism > 1 else False,
+            use_fsdp=True if parallelism > 1 else False,
+            offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
 @dataclass
 class BaseStateDicts:
     pass
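
The new ZImagePipelineConfig mirrors the other pipeline configs: required DiT weights plus optional text-encoder and VAE paths, all defaulting to float16, with basic_config as a convenience constructor. As a rough sketch (not part of this diff) of how it might be built, using hypothetical local paths and assuming the import path from the re-export above:

    import torch
    from diffsynth_engine import ZImagePipelineConfig  # import path assumed from the __init__ re-export above

    config = ZImagePipelineConfig.basic_config(
        model_path="path/to/z_image_dit.safetensors",      # hypothetical path
        encoder_path="path/to/qwen3_encoder.safetensors",  # hypothetical path
        vae_path="path/to/vae.safetensors",                # hypothetical path
        device="cuda",
    )
    # Dtypes default to torch.float16 for the DiT, encoder, and VAE; to override them,
    # instantiate ZImagePipelineConfig directly instead of going through basic_config.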
@@ -349,7 +385,14 @@ class QwenImageStateDicts:
     vae: Dict[str, torch.Tensor]


-def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig):
+@dataclass
+class ZImageStateDicts:
+    model: Dict[str, torch.Tensor]
+    encoder: Dict[str, torch.Tensor]
+    vae: Dict[str, torch.Tensor]
+
+
+def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
     assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
     config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg

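For the parallel path (a sketch under the same assumptions as above, not shown in this diff): basic_config turns on use_cfg_parallel and use_fsdp whenever parallelism > 1, and __post_init__ routes through the extended init_parallel_config:

    cfg = ZImagePipelineConfig.basic_config(
        model_path="path/to/z_image_dit.safetensors",  # hypothetical path
        parallelism=2,
    )
    # init_parallel_config accepts only 1, 2, 4, or 8; because basic_config set
    # use_cfg_parallel=True for parallelism > 1, it also flips cfg.batch_cfg to True.

    ZImagePipelineConfig.basic_config(model_path="...", parallelism=3)
    # AssertionError: parallelism must be 1, 2, 4 or 8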
@@ -0,0 +1,11 @@
+from .qwen3 import (
+    Qwen3Model,
+    Qwen3Config,
+)
+from .z_image_dit import ZImageDiT
+
+__all__ = [
+    "Qwen3Model",
+    "Qwen3Config",
+    "ZImageDiT",
+]
@@ -0,0 +1,124 @@
+# modified from transformers.models.qwen3.modeling_qwen3
+import torch
+import torch.nn as nn
+from typing import Dict, Tuple, Optional
+
+from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.utils.cache import Cache, DynamicCache
+from diffsynth_engine.utils import logging
+
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3RMSNorm, Qwen3RotaryEmbedding
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.masking_utils import create_causal_mask
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3ModelStateDictConverter(StateDictConverter):
+    def __init__(self):
+        super().__init__()
+
+    def _from_diffusers(self, state_dict):
+        new_state_dict = {}
+        for key, param in state_dict.items():
+            if key.startswith("model."):
+                key = key[len("model.") :]
+            new_state_dict[key] = param
+        return new_state_dict
+
+    def convert(self, state_dict):
+        return self._from_diffusers(state_dict)
+
+
+class Qwen3Model(PreTrainedModel):
+    converter = Qwen3ModelStateDictConverter()
+
+    def __init__(self, config: Qwen3Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+        super().__init__()
+        # for causal_mask
+        config._attn_implementation = "sdpa"
+        self.config = config
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, device=device, dtype=dtype
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen3DecoderLayer(layer_idx=layer_idx, config=config).to(device=device, dtype=dtype)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Qwen3RMSNorm(config.hidden_size, config.rms_norm_eps).to(device=device, dtype=dtype)
+        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Qwen3Config,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        model = cls(config=config, device="meta", dtype=dtype)
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        all_hidden_states = []
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            seq_len = inputs_embeds.size(1)
+            cache_position = torch.arange(seq_len, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        hidden_states = inputs_embeds
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        all_hidden_states.append(hidden_states)
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                attention_mask=causal_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+            )
+            all_hidden_states.append(hidden_states)
+
+        hidden_states = self.norm(hidden_states)
+        return {
+            "last_hidden_state": hidden_states,
+            "past_key_values": past_key_values,
+            "hidden_states": all_hidden_states,
+        }
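
To make the new text-encoder wrapper concrete, a minimal usage sketch (not from the diff): the checkpoint path, config values, and token ids below are placeholders, and Qwen3Model is assumed importable from the new module re-exported above.

    import torch
    from safetensors.torch import load_file
    from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

    config = Qwen3Config()  # assumption: a real config would match the bundled checkpoint
    raw = load_file("path/to/qwen3_text_encoder.safetensors")  # hypothetical path
    state_dict = Qwen3Model.converter.convert(raw)             # strips any leading "model." prefix
    model = Qwen3Model.from_state_dict(state_dict, config=config, device="cuda:0", dtype=torch.bfloat16)

    input_ids = torch.tensor([[1, 2, 3]], device="cuda:0")     # placeholder token ids
    with torch.no_grad():
        out = model(input_ids=input_ids)
    # forward() returns a dict: "last_hidden_state" (post-norm), "past_key_values"
    # (None unless use_cache=True), and "hidden_states" holding the embedding output
    # plus the output of every decoder layer.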