diffsynth-engine 0.6.1.dev41__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -307,6 +307,8 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
     vae_dtype: torch.dtype = torch.bfloat16
     encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
     encoder_dtype: torch.dtype = torch.bfloat16
+    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    image_encoder_dtype: torch.dtype = torch.bfloat16
 
     @classmethod
     def basic_config(
@@ -314,6 +316,7 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
         model_path: str | os.PathLike | List[str | os.PathLike],
         encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
@@ -324,6 +327,7 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
             device=device,
             encoder_path=encoder_path,
             vae_path=vae_path,
+            image_encoder_path=image_encoder_path,
             parallelism=parallelism,
             use_cfg_parallel=True if parallelism > 1 else False,
             use_fsdp=True if parallelism > 1 else False,
@@ -391,6 +395,7 @@ class ZImageStateDicts:
     model: Dict[str, torch.Tensor]
     encoder: Dict[str, torch.Tensor]
     vae: Dict[str, torch.Tensor]
+    image_encoder: Optional[Dict[str, torch.Tensor]] = None
 
 
 def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
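Taken together, the hunks above add an optional SigLIP2 image-encoder slot to ZImagePipelineConfig (a path plus a dtype defaulting to torch.bfloat16), thread the new image_encoder_path argument through basic_config, and reserve an optional image_encoder entry in ZImageStateDicts. A minimal sketch of how the extended basic_config might be called follows; the import path and the checkpoint filenames are placeholders, not values taken from this diff.

# Hypothetical call to the extended basic_config; import path and file names are assumptions.
from diffsynth_engine.configs import ZImagePipelineConfig

config = ZImagePipelineConfig.basic_config(
    model_path="checkpoints/z_image_dit.safetensors",              # placeholder path
    encoder_path="checkpoints/text_encoder.safetensors",           # placeholder path
    vae_path="checkpoints/vae.safetensors",                        # placeholder path
    image_encoder_path="checkpoints/siglip2_vision.safetensors",   # new in 0.7.0, placeholder path
    device="cuda",
)
# Per the diff, basic_config does not expose image_encoder_dtype; it keeps its torch.bfloat16 default.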
@@ -3,9 +3,13 @@ from .qwen3 import (
     Qwen3Config,
 )
 from .z_image_dit import ZImageDiT
+from .z_image_dit_omni_base import ZImageOmniBaseDiT
+from .siglip import Siglip2ImageEncoder
 
 __all__ = [
     "Qwen3Model",
     "Qwen3Config",
     "ZImageDiT",
+    "ZImageOmniBaseDiT",
+    "Siglip2ImageEncoder",
 ]
@@ -0,0 +1,72 @@
+from transformers import Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
+import torch
+
+
+class Siglip2ImageEncoder(Siglip2VisionModel):
+    def __init__(self, **kwargs):
+        config = Siglip2VisionConfig(
+            attention_dropout = 0.0,
+            dtype = "bfloat16",
+            hidden_act = "gelu_pytorch_tanh",
+            hidden_size = 1152,
+            intermediate_size = 4304,
+            layer_norm_eps = 1e-06,
+            model_type = "siglip2_vision_model",
+            num_attention_heads = 16,
+            num_channels = 3,
+            num_hidden_layers = 27,
+            num_patches = 256,
+            patch_size = 16,
+            transformers_version = "4.57.1"
+        )
+        super().__init__(config)
+        self.processor = Siglip2ImageProcessorFast(
+            **{
+                "data_format": "channels_first",
+                "default_to_square": True,
+                "device": None,
+                "disable_grouping": None,
+                "do_convert_rgb": None,
+                "do_normalize": True,
+                "do_pad": None,
+                "do_rescale": True,
+                "do_resize": True,
+                "image_mean": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "image_processor_type": "Siglip2ImageProcessorFast",
+                "image_std": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "input_data_format": None,
+                "max_num_patches": 256,
+                "pad_size": None,
+                "patch_size": 16,
+                "processor_class": "Siglip2Processor",
+                "resample": 2,
+                "rescale_factor": 0.00392156862745098,
+                "return_tensors": None,
+            }
+        )
+
+    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+        siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
+        shape = siglip_inputs.spatial_shapes[0]
+        hidden_state = super().forward(**siglip_inputs).last_hidden_state
+        B, N, C = hidden_state.shape
+        hidden_state = hidden_state[:, : shape[0] * shape[1]]
+        hidden_state = hidden_state.view(shape[0], shape[1], C)
+        hidden_state = hidden_state.to(torch_dtype)
+        return hidden_state
+
+    @classmethod
+    def from_state_dict(cls, state_dict, device: str, dtype: torch.dtype):
+        model = cls()
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
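The new siglip.py module wraps Siglip2VisionModel with a hard-coded SigLIP2 vision configuration and its fast image processor; forward preprocesses a single image, runs the vision tower, and returns the patch-grid features as a (height_patches, width_patches, 1152) tensor in the requested dtype, while from_state_dict builds the module directly from pre-loaded weights. A rough usage sketch follows; the safetensors checkpoint path and the PIL input are assumptions, not part of the diff.

# Hypothetical usage of Siglip2ImageEncoder; checkpoint path and PIL input are assumptions.
import torch
from PIL import Image
from safetensors.torch import load_file

state_dict = load_file("checkpoints/siglip2_vision.safetensors")  # placeholder path
encoder = Siglip2ImageEncoder.from_state_dict(state_dict, device="cuda", dtype=torch.bfloat16)

image = Image.open("reference.png").convert("RGB")  # placeholder image
features = encoder(image)  # invokes forward(); shape (h_patches, w_patches, 1152)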