diffsynth-engine 0.6.1.dev33__py3-none-any.whl → 0.6.1.dev35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +4 -0
- diffsynth_engine/conf/models/z_image/qwen3_config.json +30 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/merges.txt +151388 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer.json +757480 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json +239 -0
- diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json +1 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +44 -1
- diffsynth_engine/models/basic/attention.py +2 -2
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +5 -5
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +0 -1
- diffsynth_engine/models/z_image/__init__.py +11 -0
- diffsynth_engine/models/z_image/qwen3.py +124 -0
- diffsynth_engine/models/z_image/z_image_dit.py +602 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/qwen_image.py +4 -3
- diffsynth_engine/pipelines/z_image.py +377 -0
- diffsynth_engine/utils/constants.py +3 -0
- diffsynth_engine/utils/process_group.py +1 -1
- {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/RECORD +24 -15
- {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev33.dist-info → diffsynth_engine-0.6.1.dev35.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/__init__.py

```diff
@@ -10,6 +10,7 @@ from .pipeline import (
     WanSpeech2VideoPipelineConfig,
     QwenImagePipelineConfig,
     HunyuanPipelineConfig,
+    ZImagePipelineConfig,
     BaseStateDicts,
     SDStateDicts,
     SDXLStateDicts,
@@ -17,6 +18,7 @@ from .pipeline import (
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    ZImageStateDicts,
     AttnImpl,
     SpargeAttentionParams,
     VideoSparseAttentionParams,
@@ -41,6 +43,7 @@ __all__ = [
     "WanSpeech2VideoPipelineConfig",
     "QwenImagePipelineConfig",
     "HunyuanPipelineConfig",
+    "ZImagePipelineConfig",
     "BaseStateDicts",
     "SDStateDicts",
     "SDXLStateDicts",
@@ -48,6 +51,7 @@ __all__ = [
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
+    "ZImageStateDicts",
     "AttnImpl",
     "SpargeAttentionParams",
     "VideoSparseAttentionParams",
```
diffsynth_engine/configs/pipeline.py

```diff
@@ -298,6 +298,42 @@ class HunyuanPipelineConfig(BaseConfig):
     image_encoder_dtype: torch.dtype = torch.float16
 
 
+@dataclass
+class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
+    model_path: str | os.PathLike | List[str | os.PathLike]
+    model_dtype: torch.dtype = torch.float16
+    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    vae_dtype: torch.dtype = torch.float16
+    encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    encoder_dtype: torch.dtype = torch.float16
+
+    @classmethod
+    def basic_config(
+        cls,
+        model_path: str | os.PathLike | List[str | os.PathLike],
+        encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        device: str = "cuda",
+        parallelism: int = 1,
+        offload_mode: Optional[str] = None,
+        offload_to_disk: bool = False,
+    ) -> "ZImagePipelineConfig":
+        return cls(
+            model_path=model_path,
+            device=device,
+            encoder_path=encoder_path,
+            vae_path=vae_path,
+            parallelism=parallelism,
+            use_cfg_parallel=True if parallelism > 1 else False,
+            use_fsdp=True if parallelism > 1 else False,
+            offload_mode=offload_mode,
+            offload_to_disk=offload_to_disk,
+        )
+
+    def __post_init__(self):
+        init_parallel_config(self)
+
+
 @dataclass
 class BaseStateDicts:
     pass
@@ -349,7 +385,14 @@ class QwenImageStateDicts:
     vae: Dict[str, torch.Tensor]
 
 
-def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig):
+@dataclass
+class ZImageStateDicts:
+    model: Dict[str, torch.Tensor]
+    encoder: Dict[str, torch.Tensor]
+    vae: Dict[str, torch.Tensor]
+
+
+def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
     assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
     config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
 
```
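For reference, the new `ZImagePipelineConfig.basic_config` helper mirrors the other pipeline configs: it derives `use_cfg_parallel` and `use_fsdp` from `parallelism`, and `__post_init__` runs `init_parallel_config`, which asserts that `parallelism` is 1, 2, 4, or 8. A minimal usage sketch based on the fields above; the checkpoint paths are placeholders, not files shipped with the wheel:

```python
from diffsynth_engine.configs.pipeline import ZImagePipelineConfig

# Placeholder paths for illustration only; point these at real Z-Image checkpoints.
config = ZImagePipelineConfig.basic_config(
    model_path="checkpoints/z_image_dit.safetensors",
    encoder_path="checkpoints/z_image_qwen3_encoder.safetensors",
    vae_path="checkpoints/z_image_vae.safetensors",
    device="cuda",
    parallelism=2,      # values other than 1, 2, 4, 8 fail the assert in init_parallel_config
    offload_mode=None,
)
print(config.use_cfg_parallel, config.use_fsdp)  # True True, because parallelism > 1
```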
diffsynth_engine/models/basic/attention.py

```diff
@@ -343,7 +343,7 @@ def long_context_attention(
             f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
         )
     if SDPA_AVAILABLE:
-        return LongContextAttention(attn_type=AttnType.
+        return LongContextAttention(attn_type=AttnType.TORCH_EFFICIENT)(q, k, v, softmax_scale=scale)
     if FLASH_ATTN_2_AVAILABLE:
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     raise ValueError("No available long context attention implementation")
@@ -379,7 +379,7 @@ def long_context_attention(
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
-        return LongContextAttention(attn_type=AttnType.
+        return LongContextAttention(attn_type=AttnType.TORCH_EFFICIENT)(q, k, v, softmax_scale=scale)
     if attn_impl == "sage":
         return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
     if attn_impl == "sparge":
```
diffsynth_engine/models/qwen_image/qwen_image_dit.py

```diff
@@ -286,16 +286,15 @@ class QwenImageTransformerBlock(nn.Module):
             shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]
             scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
             gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
-            index_expanded = index.unsqueeze(-1)
             shift_0_exp = shift_0.unsqueeze(1)
             shift_1_exp = shift_1.unsqueeze(1)
             scale_0_exp = scale_0.unsqueeze(1)
             scale_1_exp = scale_1.unsqueeze(1)
             gate_0_exp = gate_0.unsqueeze(1)
             gate_1_exp = gate_1.unsqueeze(1)
-            shift_result = torch.where(
-            scale_result = torch.where(
-            gate_result = torch.where(
+            shift_result = torch.where(index == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index == 0, gate_0_exp, gate_1_exp)
         else:
             shift_result = shift.unsqueeze(1)
             scale_result = scale.unsqueeze(1)
@@ -514,6 +513,7 @@ class QwenImageDiT(PreTrainedModel):
             device=timestep.device,
             dtype=torch.int,
         )
+        modulate_index = modulate_index.unsqueeze(-1)
         rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
         image = self.img_in(image)
@@ -535,7 +535,7 @@ class QwenImageDiT(PreTrainedModel):
 
         # warning: Eligen does not work with sequence parallel because long context attention does not support attention masks
         img_freqs, txt_freqs = rotary_emb
-        with sequence_parallel((image, text, img_freqs, txt_freqs), seq_dims=(1, 1, 0, 0)):
+        with sequence_parallel((image, text, img_freqs, txt_freqs, modulate_index), seq_dims=(1, 1, 0, 0, 1)):
             rotary_emb = (img_freqs, txt_freqs)
             for block in self.transformer_blocks:
                 text, image = block(
```
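The block-level change above drops the per-block `index_expanded = index.unsqueeze(-1)`: `QwenImageDiT.forward` now unsqueezes `modulate_index` once and passes it through `sequence_parallel` with seq_dim 1, so each rank receives the slice of the index that matches its shard of the sequence. A small shape sketch of the resulting `torch.where` broadcast; the sizes are illustrative only, not the model's real dimensions:

```python
import torch

batch, seq_len, dim = 2, 6, 4
modulate_index = torch.tensor([[0, 0, 0, 1, 1, 1]] * batch)       # (batch, seq_len)
index = modulate_index.unsqueeze(-1)                               # (batch, seq_len, 1), done once in forward()
shift_0_exp = torch.zeros(batch, 1, dim)                           # modulation params for the first branch
shift_1_exp = torch.ones(batch, 1, dim)                            # modulation params for the second branch
shift_result = torch.where(index == 0, shift_0_exp, shift_1_exp)   # broadcasts to (batch, seq_len, dim)
print(shift_result.shape)                                          # torch.Size([2, 6, 4])
```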
diffsynth_engine/models/z_image/qwen3.py (new file)

```diff
@@ -0,0 +1,124 @@
+# modified from transformers.models.qwen3.modeling_qwen3
+import torch
+import torch.nn as nn
+from typing import Dict, Tuple, Optional
+
+from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.utils.cache import Cache, DynamicCache
+from diffsynth_engine.utils import logging
+
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3RMSNorm, Qwen3RotaryEmbedding
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.masking_utils import create_causal_mask
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3ModelStateDictConverter(StateDictConverter):
+    def __init__(self):
+        super().__init__()
+
+    def _from_diffusers(self, state_dict):
+        new_state_dict = {}
+        for key, param in state_dict.items():
+            if key.startswith("model."):
+                key = key[len("model.") :]
+            new_state_dict[key] = param
+        return new_state_dict
+
+    def convert(self, state_dict):
+        return self._from_diffusers(state_dict)
+
+
+class Qwen3Model(PreTrainedModel):
+    converter = Qwen3ModelStateDictConverter()
+
+    def __init__(self, config: Qwen3Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+        super().__init__()
+        # for causal_mask
+        config._attn_implementation = "sdpa"
+        self.config = config
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, device=device, dtype=dtype
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen3DecoderLayer(layer_idx=layer_idx, config=config).to(device=device, dtype=dtype)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Qwen3RMSNorm(config.hidden_size, config.rms_norm_eps).to(device=device, dtype=dtype)
+        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        config: Qwen3Config,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        model = cls(config=config, device="meta", dtype=dtype)
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        all_hidden_states = []
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            seq_len = inputs_embeds.size(1)
+            cache_position = torch.arange(seq_len, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        hidden_states = inputs_embeds
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        all_hidden_states.append(hidden_states)
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                attention_mask=causal_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+            )
+            all_hidden_states.append(hidden_states)
+
+        hidden_states = self.norm(hidden_states)
+        return {
+            "last_hidden_state": hidden_states,
+            "past_key_values": past_key_values,
+            "hidden_states": all_hidden_states,
+        }
```
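The new `Qwen3Model` wraps the Hugging Face `Qwen3DecoderLayer` stack and returns every intermediate hidden state, which the Z-Image pipeline presumably consumes as text-encoder features. A toy instantiation sketch, assuming a transformers version that provides `transformers.masking_utils.create_causal_mask`; the configuration values below are illustrative, not the shipped `conf/models/z_image/qwen3_config.json`:

```python
import torch
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from diffsynth_engine.models.z_image.qwen3 import Qwen3Model

# Tiny toy config so the example runs on CPU; the real encoder is loaded from
# qwen3_config.json plus checkpoint weights via Qwen3Model.from_state_dict.
config = Qwen3Config(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)
model = Qwen3Model(config, device="cpu", dtype=torch.float32)

input_ids = torch.randint(0, 1000, (1, 8))
out = model(input_ids=input_ids)
print(out["last_hidden_state"].shape)  # torch.Size([1, 8, 64])
print(len(out["hidden_states"]))       # embedding output plus one entry per decoder layer
```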