diffsynth-engine 0.6.1.dev34__py3-none-any.whl → 0.6.1.dev35__py3-none-any.whl
- diffsynth_engine/__init__.py +4 -0
- diffsynth_engine/conf/models/z_image/qwen3_config.json +30 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/merges.txt +151388 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer.json +757480 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json +239 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json +1 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +44 -1
- diffsynth_engine/models/z_image/__init__.py +11 -0
- diffsynth_engine/models/z_image/qwen3.py +124 -0
- diffsynth_engine/models/z_image/z_image_dit.py +602 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/z_image.py +377 -0
- diffsynth_engine/utils/constants.py +3 -0
- diffsynth_engine/utils/process_group.py +1 -1
- {diffsynth_engine-0.6.1.dev34.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev34.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/RECORD +20 -11
- {diffsynth_engine-0.6.1.dev34.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev34.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev34.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/__init__.py

@@ -10,6 +10,7 @@ from .pipeline import (
     WanSpeech2VideoPipelineConfig,
     QwenImagePipelineConfig,
     HunyuanPipelineConfig,
+    ZImagePipelineConfig,
     BaseStateDicts,
     SDStateDicts,
     SDXLStateDicts,
@@ -17,6 +18,7 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    ZImageStateDicts,
     AttnImpl,
     SpargeAttentionParams,
     VideoSparseAttentionParams,
@@ -41,6 +43,7 @@ __all__ = [
     "WanSpeech2VideoPipelineConfig",
     "QwenImagePipelineConfig",
     "HunyuanPipelineConfig",
+    "ZImagePipelineConfig",
     "BaseStateDicts",
     "SDStateDicts",
     "SDXLStateDicts",
@@ -48,6 +51,7 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
+    "ZImageStateDicts",
     "AttnImpl",
     "SpargeAttentionParams",
     "VideoSparseAttentionParams",
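With these re-exports, the new symbols can be imported directly from the configs package. A minimal sketch (the mirror re-export from the top-level package is an assumption, based on the +4 -0 change to diffsynth_engine/__init__.py in the file list above):

    # hypothetical usage after upgrading to 0.6.1.dev35
    from diffsynth_engine.configs import ZImagePipelineConfig, ZImageStateDicts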
diffsynth_engine/configs/pipeline.py

@@ -298,6 +298,42 @@ class HunyuanPipelineConfig(BaseConfig):
     image_encoder_dtype: torch.dtype = torch.float16
 
 
+@dataclass
+class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype = torch.float16
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_dtype: torch.dtype = torch.float16
+    encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    encoder_dtype: torch.dtype = torch.float16
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
+    ) -> "ZImagePipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            encoder_path=encoder_path,
+            vae_path=vae_path,
+            parallelism=parallelism,
+            use_cfg_parallel=True if parallelism > 1 else False,
+            use_fsdp=True if parallelism > 1 else False,
+            offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
 @dataclass
 class BaseStateDicts:
     pass
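A minimal sketch of how the new config might be built through basic_config (the .safetensors paths are placeholders, not shipped filenames; per the hunk above, any parallelism > 1 switches on both use_cfg_parallel and use_fsdp):

    from diffsynth_engine.configs import ZImagePipelineConfig

    config = ZImagePipelineConfig.basic_config(
        model_path="models/z_image_dit.safetensors",      # placeholder path
        encoder_path="models/qwen3_encoder.safetensors",  # placeholder path
        vae_path="models/z_image_vae.safetensors",        # placeholder path
        device="cuda",
        parallelism=2,   # > 1 implies use_cfg_parallel=True and use_fsdp=True
        offload_mode=None,
    )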
@@ -349,7 +385,14 @@ class QwenImageStateDicts:
     vae: Dict[str, torch.Tensor]
 
 
-def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig):
+@dataclass
+class ZImageStateDicts:
+    model: Dict[str, torch.Tensor]
+    encoder: Dict[str, torch.Tensor]
+    vae: Dict[str, torch.Tensor]
+
+
+def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
     assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
     config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
 
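Because ZImagePipelineConfig.__post_init__ calls init_parallel_config, the existing parallelism check now covers Z-Image as well; a sketch of the failure mode (placeholder path):

    try:
        ZImagePipelineConfig.basic_config(model_path="dit.safetensors", parallelism=3)
    except AssertionError as err:
        print(err)  # parallelism must be 1, 2, 4 or 8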
diffsynth_engine/models/z_image/qwen3.py (new file)

@@ -0,0 +1,124 @@
+# modified from transformers.models.qwen3.modeling_qwen3
+import torch
+import torch.nn as nn
+from typing import Dict, Tuple, Optional
+
+from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.utils.cache import Cache, DynamicCache
+from diffsynth_engine.utils import logging
+
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3RMSNorm, Qwen3RotaryEmbedding
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.masking_utils import create_causal_mask
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3ModelStateDictConverter(StateDictConverter):
+    def __init__(self):
+        super().__init__()
+
+    def _from_diffusers(self, state_dict):
+        new_state_dict = {}
+        for key, param in state_dict.items():
+            if key.startswith("model."):
+                key = key[len("model.") :]
+            new_state_dict[key] = param
+        return new_state_dict
+
+    def convert(self, state_dict):
+        return self._from_diffusers(state_dict)
+
+
+class Qwen3Model(PreTrainedModel):
+    converter = Qwen3ModelStateDictConverter()
+
+    def __init__(self, config: Qwen3Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+        super().__init__()
+        # for causal_mask
+        config._attn_implementation = "sdpa"
+        self.config = config
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, device=device, dtype=dtype
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen3DecoderLayer(layer_idx=layer_idx, config=config).to(device=device, dtype=dtype)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Qwen3RMSNorm(config.hidden_size, config.rms_norm_eps).to(device=device, dtype=dtype)
+        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Qwen3Config,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        model = cls(config=config, device="meta", dtype=dtype)
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        all_hidden_states = []
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            seq_len = inputs_embeds.size(1)
+            cache_position = torch.arange(seq_len, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        hidden_states = inputs_embeds
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        all_hidden_states.append(hidden_states)
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                attention_mask=causal_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+            )
+            all_hidden_states.append(hidden_states)
+
+        hidden_states = self.norm(hidden_states)
+        return {
+            "last_hidden_state": hidden_states,
+            "past_key_values": past_key_values,
+            "hidden_states": all_hidden_states,
+        }
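One nit in the new module: forward is annotated as returning Tuple[torch.Tensor, Optional[Cache]] but actually returns a dict carrying the final hidden state, the KV cache, and all intermediate hidden states. A sketch of driving the encoder on its own (the import path, the checkpoint filename, and loading the shipped qwen3_config.json are assumptions):

    import torch
    from safetensors.torch import load_file
    from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
    from diffsynth_engine.models.z_image import Qwen3Model  # assumed export

    # assumed: config JSON shipped under conf/models/z_image (see file list);
    # the checkpoint filename is hypothetical
    config = Qwen3Config.from_json_file("diffsynth_engine/conf/models/z_image/qwen3_config.json")
    state_dict = Qwen3Model.converter.convert(load_file("qwen3_encoder.safetensors"))

    model = Qwen3Model.from_state_dict(state_dict, config=config, device="cuda:0", dtype=torch.bfloat16)
    input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda:0")
    out = model(input_ids=input_ids)
    text_embeds = out["last_hidden_state"]  # shape (1, 16, hidden_size)
    per_layer = out["hidden_states"]        # embedding output + one entry per layer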