diffsynth-engine 0.6.1.dev41__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff covers the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- diffsynth_engine/configs/pipeline.py +5 -0
- diffsynth_engine/models/z_image/__init__.py +4 -0
- diffsynth_engine/models/z_image/siglip.py +72 -0
- diffsynth_engine/models/z_image/z_image_dit_omni_base.py +1132 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/z_image_omni_base.py +503 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.7.0.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.7.0.dist-info}/RECORD +11 -8
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.7.0.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.7.0.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/pipeline.py

```diff
@@ -307,6 +307,8 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
     vae_dtype: torch.dtype = torch.bfloat16
     encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
     encoder_dtype: torch.dtype = torch.bfloat16
+    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    image_encoder_dtype: torch.dtype = torch.bfloat16

     @classmethod
     def basic_config(
@@ -314,6 +316,7 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
         model_path: str | os.PathLike | List[str | os.PathLike],
         encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
@@ -324,6 +327,7 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
             device=device,
             encoder_path=encoder_path,
             vae_path=vae_path,
+            image_encoder_path=image_encoder_path,
             parallelism=parallelism,
             use_cfg_parallel=True if parallelism > 1 else False,
             use_fsdp=True if parallelism > 1 else False,
@@ -391,6 +395,7 @@ class ZImageStateDicts:
    model: Dict[str, torch.Tensor]
    encoder: Dict[str, torch.Tensor]
    vae: Dict[str, torch.Tensor]
+   image_encoder: Optional[Dict[str, torch.Tensor]] = None


 def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
```
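Taken together, these hunks thread an optional image encoder through `ZImagePipelineConfig`, presumably for consumption by the new `z_image_omni_base` pipeline. A minimal sketch of how a caller might populate the new fields; all checkpoint paths below are placeholders, not real file names:

```python
import torch
from diffsynth_engine.configs.pipeline import ZImagePipelineConfig

# All checkpoint paths here are placeholders for illustration only.
config = ZImagePipelineConfig.basic_config(
    model_path="./models/z_image_dit.safetensors",
    encoder_path="./models/qwen3_encoder.safetensors",
    vae_path="./models/z_image_vae.safetensors",
    image_encoder_path="./models/siglip2_vision.safetensors",  # new in 0.7.0
)
assert config.image_encoder_dtype == torch.bfloat16  # new field, bfloat16 by default
```

Both new fields default to values that keep existing callers working: `image_encoder_path` defaults to `None` and `image_encoder_dtype` to `torch.bfloat16`, matching the other encoder dtypes.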
diffsynth_engine/models/z_image/__init__.py

```diff
@@ -3,9 +3,13 @@ from .qwen3 import (
     Qwen3Config,
 )
 from .z_image_dit import ZImageDiT
+from .z_image_dit_omni_base import ZImageOmniBaseDiT
+from .siglip import Siglip2ImageEncoder

 __all__ = [
     "Qwen3Model",
     "Qwen3Config",
     "ZImageDiT",
+    "ZImageOmniBaseDiT",
+    "Siglip2ImageEncoder",
 ]
```
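With these re-exports in place, both new classes are importable straight from the subpackage:

```python
# Re-exported in 0.7.0 via diffsynth_engine/models/z_image/__init__.py.
from diffsynth_engine.models.z_image import ZImageOmniBaseDiT, Siglip2ImageEncoder
```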
diffsynth_engine/models/z_image/siglip.py (new file)

```diff
@@ -0,0 +1,72 @@
+from transformers import Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
+import torch
+
+
+class Siglip2ImageEncoder(Siglip2VisionModel):
+    def __init__(self, **kwargs):
+        config = Siglip2VisionConfig(
+            attention_dropout = 0.0,
+            dtype = "bfloat16",
+            hidden_act = "gelu_pytorch_tanh",
+            hidden_size = 1152,
+            intermediate_size = 4304,
+            layer_norm_eps = 1e-06,
+            model_type = "siglip2_vision_model",
+            num_attention_heads = 16,
+            num_channels = 3,
+            num_hidden_layers = 27,
+            num_patches = 256,
+            patch_size = 16,
+            transformers_version = "4.57.1"
+        )
+        super().__init__(config)
+        self.processor = Siglip2ImageProcessorFast(
+            **{
+                "data_format": "channels_first",
+                "default_to_square": True,
+                "device": None,
+                "disable_grouping": None,
+                "do_convert_rgb": None,
+                "do_normalize": True,
+                "do_pad": None,
+                "do_rescale": True,
+                "do_resize": True,
+                "image_mean": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "image_processor_type": "Siglip2ImageProcessorFast",
+                "image_std": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "input_data_format": None,
+                "max_num_patches": 256,
+                "pad_size": None,
+                "patch_size": 16,
+                "processor_class": "Siglip2Processor",
+                "resample": 2,
+                "rescale_factor": 0.00392156862745098,
+                "return_tensors": None,
+            }
+        )
+
+    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+        siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
+        shape = siglip_inputs.spatial_shapes[0]
+        hidden_state = super().forward(**siglip_inputs).last_hidden_state
+        B, N, C = hidden_state.shape
+        hidden_state = hidden_state[:, : shape[0] * shape[1]]
+        hidden_state = hidden_state.view(shape[0], shape[1], C)
+        hidden_state = hidden_state.to(torch_dtype)
+        return hidden_state
+
+    @classmethod
+    def from_state_dict(cls, state_dict, device: str, dtype: torch.dtype):
+        model = cls()
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
```
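A sketch of exercising the new encoder on its own, assuming a state dict compatible with the hard-coded `Siglip2VisionConfig` is available on disk (the checkpoint path below is a placeholder):

```python
import torch
from PIL import Image
from safetensors.torch import load_file
from diffsynth_engine.models.z_image import Siglip2ImageEncoder

# Placeholder path; any checkpoint matching the hard-coded Siglip2VisionConfig works.
state_dict = load_file("./models/siglip2_vision.safetensors")
encoder = Siglip2ImageEncoder.from_state_dict(state_dict, device="cuda", dtype=torch.bfloat16)

image = Image.open("reference.png").convert("RGB")
features = encoder(image)  # (height_patches, width_patches, 1152) in bfloat16
```

Note that `forward` processes one image per call: the `view(shape[0], shape[1], C)` reshape folds the batch dimension into the patch grid, which is only valid because `images=[image]` wraps a single entry.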