diffsynth-engine 0.6.1.dev41__py3-none-any.whl → 0.6.1.dev42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/configs/pipeline.py +5 -0
- diffsynth_engine/models/z_image/__init__.py +4 -0
- diffsynth_engine/models/z_image/siglip.py +72 -0
- diffsynth_engine/models/z_image/z_image_dit_omni_base.py +1132 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/z_image_omni_base.py +503 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/RECORD +11 -8
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/__init__.py
@@ -8,6 +8,7 @@ from .wan_dmd import WanDMDPipeline
 from .qwen_image import QwenImagePipeline
 from .hunyuan3d_shape import Hunyuan3DShapePipeline
 from .z_image import ZImagePipeline
+from .z_image_omni_base import ZImageOmniBasePipeline
 
 __all__ = [
     "BasePipeline",
@@ -21,4 +22,5 @@ __all__ = [
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
     "ZImagePipeline",
+    "ZImageOmniBasePipeline",
 ]
diffsynth_engine/pipelines/z_image_omni_base.py ADDED
@@ -0,0 +1,503 @@
+import torch
+import torch.distributed as dist
+import math
+import json
+from typing import Callable, List, Dict, Tuple, Optional, Union
+from tqdm import tqdm
+from PIL import Image
+
+from diffsynth_engine.configs import (
+    ZImagePipelineConfig,
+    ZImageStateDicts,
+)
+from diffsynth_engine.models.basic.lora import LoRAContext
+
+from diffsynth_engine.models.z_image import (
+    ZImageOmniBaseDiT,
+    Qwen3Model,
+    Qwen3Config,
+    Siglip2ImageEncoder,
+)
+from transformers import Qwen2Tokenizer
+from diffsynth_engine.utils.constants import (
+    Z_IMAGE_TEXT_ENCODER_CONFIG_FILE,
+    Z_IMAGE_TOKENIZER_CONF_PATH,
+)
+from diffsynth_engine.models.flux import FluxVAEDecoder, FluxVAEEncoder
+from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
+from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
+from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
+from diffsynth_engine.utils.parallel import ParallelWrapper
+from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
+from diffsynth_engine.utils.download import fetch_model
+
+logger = logging.get_logger(__name__)
+
+
+class ZImageOmniBaseLoRAConverter(LoRAStateDictConverter):
+    def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            if "lora_A.weight" in key:
+                lora_b_key = key.replace("lora_A.weight", "lora_B.weight")
+                target_key = key.replace(".lora_A.weight", "").replace("diffusion_model.", "")
+
+                # if "attention.to_out.0" in target_key:
+                # target_key = target_key.replace("attention.to_out.0", "attention.to_out")
+                # if "adaLN_modulation.0" in target_key:
+                # target_key = target_key.replace("adaLN_modulation.0", "adaLN_modulation")
+
+                up = lora_state_dict[lora_b_key]
+                rank = up.shape[1]
+
+                dit_dict[target_key] = {
+                    "down": param,
+                    "up": up,
+                    "rank": rank,
+                    "alpha": lora_state_dict.get(key.replace("lora_A.weight", "alpha"), rank),
+                }
+
+        return {"dit": dit_dict}
+
+    def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            if "lora_A.default.weight" in key:
+                lora_b_key = key.replace("lora_A.default.weight", "lora_B.default.weight")
+                target_key = key.replace(".lora_A.default.weight", "")
+
+                # if "attention.to_out.0" in target_key:
+                # target_key = target_key.replace("attention.to_out.0", "attention.to_out")
+
+                up = lora_state_dict[lora_b_key]
+                rank = up.shape[1]
+
+                dit_dict[target_key] = {
+                    "down": param,
+                    "up": up,
+                    "rank": rank,
+                    "alpha": lora_state_dict.get(key.replace("lora_A.default.weight", "alpha"), rank),
+                }
+
+        return {"dit": dit_dict}
+
+
+    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        key = list(lora_state_dict.keys())[0]
+        if key.startswith("diffusion_model."):
+            return self._from_diffusers(lora_state_dict)
+        else:
+            return self._from_diffsynth(lora_state_dict)
+
+
+class ZImageOmniBasePipeline(BasePipeline):
+    lora_converter = ZImageOmniBaseLoRAConverter()
+
+    def __init__(
+        self,
+        config: ZImagePipelineConfig,
+        tokenizer: Qwen2Tokenizer,
+        text_encoder: Qwen3Model,
+        image_encoder: Siglip2ImageEncoder,
+        dit: ZImageOmniBaseDiT,
+        vae_decoder: FluxVAEDecoder,
+        vae_encoder: FluxVAEEncoder,
+    ):
+        super().__init__(
+            vae_tiled=config.vae_tiled,
+            vae_tile_size=config.vae_tile_size,
+            vae_tile_stride=config.vae_tile_stride,
+            device=config.device,
+            dtype=config.model_dtype,
+        )
+        self.config = config
+
+        # Scheduler
+        self.noise_scheduler = RecifitedFlowScheduler(shift=6.0, use_dynamic_shifting=False)
+        self.sampler = FlowMatchEulerSampler()
+        self.tokenizer = tokenizer
+        # Models
+        self.text_encoder = text_encoder
+        self.image_encoder = image_encoder
+        self.dit = dit
+        self.vae_decoder = vae_decoder
+        self.vae_encoder = vae_encoder
+
+        self.model_names = ["text_encoder", "dit", "vae_decoder"]
+
+    @classmethod
+    def from_pretrained(cls, model_path_or_config: str | ZImagePipelineConfig) -> "ZImageOmniBasePipeline":
+        if isinstance(model_path_or_config, str):
+            config = ZImagePipelineConfig(model_path=model_path_or_config)
+        else:
+            config = model_path_or_config
+
+        logger.info(f"Loading state dict from {config.model_path} ...")
+
+        model_state_dict = cls.load_model_checkpoint(
+            config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False
+        )
+
+        if config.vae_path is None:
+            config.vae_path = fetch_model(config.model_path, path="vae/diffusion_pytorch_model.safetensors")
+        logger.info(f"Loading VAE from {config.vae_path} ...")
+        vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
+
+        if config.encoder_path is None:
+            config.encoder_path = fetch_model(config.model_path, path="text_encoder")
+        logger.info(f"Loading Text Encoder from {config.encoder_path} ...")
+        text_encoder_state_dict = cls.load_model_checkpoint(
+            config.encoder_path, device="cpu", dtype=config.encoder_dtype
+        )
+
+        if config.image_encoder_path is None:
+            config.image_encoder_path = fetch_model(config.model_path, path="siglip/model.safetensors")
+        logger.info(f"Loading Image Encoder from {config.image_encoder_path} ...")
+        image_encoder_state_dict = cls.load_model_checkpoint(
+            config.image_encoder_path, device="cpu", dtype=config.encoder_dtype
+        )
+
+        state_dicts = ZImageStateDicts(
+            model=model_state_dict,
+            vae=vae_state_dict,
+            encoder=text_encoder_state_dict,
+            image_encoder=image_encoder_state_dict,
+        )
+        return cls.from_state_dict(state_dicts, config)
+
+    @classmethod
+    def from_state_dict(cls, state_dicts: ZImageStateDicts, config: ZImagePipelineConfig) -> "ZImageOmniBasePipeline":
+        assert config.parallelism <= 1, "Z-Image-Omni-Base doesn't support parallelism > 1"
+        pipe = cls._from_state_dict(state_dicts, config)
+        return pipe
+
+    @classmethod
+    def _from_state_dict(cls, state_dicts: ZImageStateDicts, config: ZImagePipelineConfig) -> "ZImageOmniBasePipeline":
+        init_device = "cpu" if config.offload_mode is not None else config.device
+        with open(Z_IMAGE_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
+            qwen3_config = Qwen3Config(**json.load(f))
+        text_encoder = Qwen3Model.from_state_dict(
+            state_dicts.encoder, config=qwen3_config, device=init_device, dtype=config.encoder_dtype
+        )
+        tokenizer = Qwen2Tokenizer.from_pretrained(Z_IMAGE_TOKENIZER_CONF_PATH)
+        vae_decoder = FluxVAEDecoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
+        vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
+        image_encoder = Siglip2ImageEncoder.from_state_dict(state_dicts.image_encoder, device=init_device, dtype=config.image_encoder_dtype)
+
+        with LoRAContext():
+            dit = ZImageOmniBaseDiT.from_state_dict(
+                state_dicts.model,
+                device=("cpu" if config.use_fsdp else init_device),
+                dtype=config.model_dtype,
+            )
+            if config.use_fp8_linear:
+                enable_fp8_linear(dit)
+
+        pipe = cls(
+            config=config,
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            dit=dit,
+            vae_decoder=vae_decoder,
+            vae_encoder=vae_encoder,
+            image_encoder=image_encoder,
+        )
+        pipe.eval()
+
+        if config.offload_mode is not None:
+            pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
+
+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.use_torch_compile:
+            pipe.compile()
+
+        return pipe
+
+    def update_weights(self, state_dicts: ZImageStateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(
+            self.text_encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype
+        )
+        self.update_component(self.vae_decoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
+    def compile(self):
+        if hasattr(self.dit, "compile_repeated_blocks"):
+            self.dit.compile_repeated_blocks()
+
+    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+        assert self.config.tp_degree is None or self.config.tp_degree == 1, (
+            "load LoRA is not allowed when tensor parallel is enabled; "
+            "set tp_degree=None or tp_degree=1 during pipeline initialization"
+        )
+        assert not (self.config.use_fsdp and fused), (
+            "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
+            "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
+        )
+        super().load_loras(lora_list, fused, save_original_weight)
+
+    def unload_loras(self):
+        if hasattr(self.dit, "unload_loras"):
+            self.dit.unload_loras()
+        self.noise_scheduler.restore_config()
+
+    def apply_scheduler_config(self, scheduler_config: Dict):
+        self.noise_scheduler.update_config(scheduler_config)
+
+    def prepare_latents(
+        self,
+        latents: torch.Tensor,
+        num_inference_steps: int,
+    ):
+        sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps, sigma_min=0, sigma_max=1.0)
+
+        sigmas = sigmas.to(device=self.device, dtype=self.dtype)
+        timesteps = timesteps.to(device=self.device, dtype=self.dtype)
+        latents = latents.to(device=self.device, dtype=self.dtype)
+
+        return latents, sigmas, timesteps
+
+    def encode_prompt(
+        self,
+        prompt: str,
+        edit_image = None,
+        max_sequence_length: int = 512,
+    ):
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
+        if edit_image is None:
+            num_condition_images = 0
+        elif isinstance(edit_image, list):
+            num_condition_images = len(edit_image)
+        else:
+            num_condition_images = 1
+
+        for i, prompt_item in enumerate(prompt):
+            if num_condition_images == 0:
+                prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
+            elif num_condition_images > 0:
+                prompt_list = ["<|im_start|>user\n<|vision_start|>"]
+                prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
+                prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
+                prompt_list += ["<|vision_end|><|im_end|>"]
+                prompt[i] = prompt_list
+
+        flattened_prompt = []
+        prompt_list_lengths = []
+
+        for i in range(len(prompt)):
+            prompt_list_lengths.append(len(prompt[i]))
+            flattened_prompt.extend(prompt[i])
+
+        text_inputs = self.tokenizer(
+            flattened_prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs["input_ids"].to(self.device)
+        prompt_masks = text_inputs["attention_mask"].to(self.device).bool()
+
+        prompt_embeds = self.text_encoder(
+            input_ids=text_input_ids,
+            attention_mask=prompt_masks,
+            output_hidden_states=True,
+        )["hidden_states"][-2]
+
+        embeddings_list = []
+        start_idx = 0
+        for i in range(len(prompt_list_lengths)):
+            batch_embeddings = []
+            end_idx = start_idx + prompt_list_lengths[i]
+            for j in range(start_idx, end_idx):
+                batch_embeddings.append(prompt_embeds[j][prompt_masks[j]])
+            embeddings_list.append(batch_embeddings)
+            start_idx = end_idx
+
+        return embeddings_list
+
+    def calculate_dimensions(self, target_area, ratio):
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
+    def auto_resize_image(self, image):
+        image = image.convert("RGB")
+        width, height = image.size
+        target_width, target_height = self.calculate_dimensions(1024*1024, width / height)
+        image = image.resize((target_width, target_height))
+        return image
+
+    def encode_image(self, edit_image):
+        if edit_image is None:
+            return None, None
+        if not isinstance(edit_image, list):
+            edit_image = [edit_image]
+        edit_image = [self.auto_resize_image(i) for i in edit_image]
+
+        image_emb = []
+        for i in edit_image:
+            image_emb.append(self.image_encoder(i, device=self.device))
+
+        image_latents = []
+        for image_ in edit_image:
+            image_ = self.preprocess_image(image_).to(dtype=self.dtype, device=self.device)
+            image_latents.append(self.vae_encoder(image_).transpose(0, 1))
+
+        return image_emb, image_latents
+
+    def predict_noise_with_cfg(
+        self,
+        latents: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: List[torch.Tensor],
+        image_emb: List[torch.Tensor],
+        image_latents: List[torch.Tensor],
+        negative_prompt_emb: List[torch.Tensor],
+        cfg_scale: float = 5.0,
+        cfg_truncation: float = 1.0,
+        cfg_normalization: float = 0.0,  # 0.0 means disabled
+        batch_cfg: bool = False,
+    ):
+        t = timestep.expand(latents.shape[0])
+        t = (1000 - t) / 1000
+        progress = t[0].item()
+
+        current_cfg_scale = cfg_scale
+        if cfg_truncation <= 1.0 and progress > cfg_truncation:
+            current_cfg_scale = 0.0
+
+        do_cfg = current_cfg_scale > 0 and negative_prompt_emb is not None
+
+        if not do_cfg:
+            latents_input = [[latents.transpose(0, 1)]]
+            image_emb = [image_emb]
+            image_latents = [image_latents] if image_latents is not None else None
+            comb_pred = self.predict_noise(latents_input, t, prompt_emb, image_emb, image_latents)[0]
+        else:
+            if not batch_cfg:
+                latents_input = [[latents.transpose(0, 1)]]
+                image_emb = [image_emb]
+                image_latents = [image_latents] if image_latents is not None else None
+
+                positive_noise_pred = self.predict_noise(latents_input, t, prompt_emb, image_emb, image_latents)[0]
+                negative_noise_pred = self.predict_noise(latents_input, t, negative_prompt_emb, image_emb, image_latents)[0]
+            else:
+                latents_input = [[latents.transpose(0, 1)], [latents.transpose(0, 1)]]
+                t = torch.concat([t, t], dim=0)
+                prompt_input = prompt_emb + negative_prompt_emb
+                image_emb = [image_emb, image_emb]
+                image_latents = [image_latents, image_latents] if image_latents is not None else None
+
+                noise_pred = self.predict_noise(latents_input, t, prompt_input, image_emb, image_latents)
+
+                positive_noise_pred, negative_noise_pred = noise_pred[0], noise_pred[1]
+
+            comb_pred = positive_noise_pred + current_cfg_scale * (positive_noise_pred - negative_noise_pred)
+
+            if cfg_normalization is not None and cfg_normalization > 0:
+                cond_norm = torch.linalg.vector_norm(positive_noise_pred)
+                new_norm = torch.linalg.vector_norm(comb_pred)
+                max_allowed_norm = cond_norm * cfg_normalization
+                new_norm = torch.where(new_norm < 1e-6, torch.ones_like(new_norm), new_norm)
+                scale_factor = max_allowed_norm / new_norm
+                scale_factor = torch.clamp(scale_factor, max=1.0)
+                comb_pred = comb_pred * scale_factor
+
+        comb_pred = -comb_pred.squeeze(1).unsqueeze(0)
+        return comb_pred
+
+    def predict_noise(
+        self,
+        latents: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: List[torch.Tensor],
+        image_emb: List[torch.Tensor] = None,
+        image_latents: List[torch.Tensor] = None,
+    ):
+        self.load_models_to_device(["dit"])
+        if image_latents is not None:
+            latents = [i + j for i, j in zip(image_latents, latents)]
+            image_noise_mask = [[0] * len(i) + [1] for i in image_latents]
+        else:
+            image_noise_mask = [[1], [1]]
+
+        noise_pred = self.dit(
+            x=latents,
+            t=timestep,
+            cap_feats=prompt_emb,
+            siglip_feats=image_emb,
+            image_noise_mask=image_noise_mask,
+        )
+        return noise_pred
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        edit_image: Optional[List[Image.Image]] = None,
+        height: int = 1024,
+        width: int = 1024,
+        num_inference_steps: int = 50,
+        cfg_scale: float = 5.0,
+        cfg_normalization: bool = False,
+        cfg_truncation: float = 1.0,
+        seed: Optional[int] = None,
+        progress_callback: Optional[Callable] = None,
+    ):
+        self.validate_image_size(height, width, multiple_of=16)
+
+        self.load_models_to_device(["text_encoder"])
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, edit_image=edit_image), self.encode_prompt(negative_prompt, edit_image=edit_image)
+        self.model_lifecycle_finish(["text_encoder"])
+
+        image_emb, image_latents = self.encode_image(edit_image)
+
+        noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
+            device=self.device
+        )
+        latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps)
+
+        self.sampler.initialize(sigmas=sigmas)
+
+        self.load_models_to_device(["dit"])
+        hide_progress = dist.is_initialized() and dist.get_rank() != 0
+
+        for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
+            noise_pred = self.predict_noise_with_cfg(
+                latents=latents,
+                timestep=timestep,
+                prompt_emb=prompt_embeds,
+                negative_prompt_emb=negative_prompt_embeds,
+                image_emb=image_emb,
+                image_latents=image_latents,
+                batch_cfg=self.config.batch_cfg,
+                cfg_scale=cfg_scale,
+                cfg_truncation=cfg_truncation,
+                cfg_normalization=cfg_normalization,
+            )
+            latents = self.sampler.step(latents, noise_pred, i)
+            if progress_callback is not None:
+                progress_callback(i, len(timesteps), "DENOISING")
+
+        self.model_lifecycle_finish(["dit"])
+
+        self.load_models_to_device(["vae_decoder"])
+        vae_output = self.decode_image(latents)
+        image = self.vae_output_to_image(vae_output)
+        # Offload all models
+        self.load_models_to_device([])
+        return image
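For orientation, here is a minimal usage sketch of the pipeline added above, based only on the from_pretrained and __call__ signatures visible in this diff; the checkpoint path, prompts, and output filename below are illustrative assumptions, not values shipped in the wheel:

# Hedged sketch: assumes a locally available Z-Image-Omni-Base checkpoint.
from diffsynth_engine.pipelines import ZImageOmniBasePipeline

# from_pretrained accepts either a model path string or a ZImagePipelineConfig
pipe = ZImageOmniBasePipeline.from_pretrained("/path/to/Z-Image-Omni-Base")  # placeholder path

# Plain text-to-image call; passing edit_image=[<PIL image>, ...] would instead
# condition on reference images, which encode_prompt wraps in <|vision_start|>/<|vision_end|> tokens.
image = pipe(
    prompt="a watercolor lighthouse at dusk",
    negative_prompt="blurry, low quality",
    height=1024,
    width=1024,
    num_inference_steps=50,
    cfg_scale=5.0,
    seed=42,
)
image.save("z_image_omni_base_sample.png")  # result assumed to be a PIL image (vae_output_to_image)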
{diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/RECORD
@@ -86,7 +86,7 @@ diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json,sha256=
 diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json,sha256=yhDX6fs-0YV13R4neiV5wW0QjjLydDloSvoOELFECRA,2776833
 diffsynth_engine/configs/__init__.py,sha256=biluGSEw78PPwO7XFlms16iuWXDiM0Eg_qsOMMTY0NQ,1409
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=
+diffsynth_engine/configs/pipeline.py,sha256=K0lH6Vg33H97oLAxtD8Hi1PpopRNY70DioaVZUP4uDM,15687
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
@@ -147,10 +147,12 @@ diffsynth_engine/models/wan/wan_image_encoder.py,sha256=Vdd39lv_QvOsmPxihZWZZbpP
 diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-n8BTvqb98YExU,23615
 diffsynth_engine/models/wan/wan_text_encoder.py,sha256=ePeOifbTI_o650mckzugyWPuHn5vhM-uFMcDVCijxPM,11394
 diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
-diffsynth_engine/models/z_image/__init__.py,sha256=
+diffsynth_engine/models/z_image/__init__.py,sha256=7sQvTYf984sK6ke3Wr-_Pt3Qkqw_s540wPswn4nThkY,305
 diffsynth_engine/models/z_image/qwen3.py,sha256=PmT6m46Fc7KZXNzG7ig23Mzj6QfHnMmrpX_MM0UuuYg,4580
+diffsynth_engine/models/z_image/siglip.py,sha256=PjB6ECXXJKgEpU9gF5Fyyt8twjKNA5_jCAG_8qQkoc8,2661
 diffsynth_engine/models/z_image/z_image_dit.py,sha256=kGtYzmfzk_FDe7KWfXpJagN7k7ROXl5J01IhRRs-Bsk,23806
-diffsynth_engine/
+diffsynth_engine/models/z_image/z_image_dit_omni_base.py,sha256=cfdUFTwGFYRiyBhB_4ptn0lAvYuLAulF6zf0ABqlAzs,44854
+diffsynth_engine/pipelines/__init__.py,sha256=cjKBhZabdKPB9j8R_JbeW3Fu6rKDPwVdvOyp5nOdUMI,804
 diffsynth_engine/pipelines/base.py,sha256=h6xOqT1LMFGrJYoTD68_VoHcfRX04je8KUE_y3BUZfM,17279
 diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
 diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
@@ -162,6 +164,7 @@ diffsynth_engine/pipelines/wan_dmd.py,sha256=T_i4xp_tASFSaKZxg50FEAk5TOn89JSNv-4
 diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
 diffsynth_engine/pipelines/wan_video.py,sha256=9nUV6h2zBbGu3gvVSM_oqdoruCdBWoa7t6vrJYJt8QY,32391
 diffsynth_engine/pipelines/z_image.py,sha256=VvqjxsKRsmP2tfWg9nDlcQu5oEzIRFa2wtuArzjQAlk,16151
+diffsynth_engine/pipelines/z_image_omni_base.py,sha256=KwLLz1o50SK8XvVBT9KE4b1QCbzZsb2OJ0UZ90anGTc,20414
 diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
 diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -200,8 +203,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev42.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev42.dist-info/METADATA,sha256=6WMcsZ3FoKJUqNWK-oPkM7QtFOgKFaSv6D7OlcN4EYw,1164
+diffsynth_engine-0.6.1.dev42.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev42.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev42.dist-info/RECORD,,
{diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/WHEEL RENAMED
File without changes
{diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/licenses/LICENSE RENAMED
File without changes
{diffsynth_engine-0.6.1.dev41.dist-info → diffsynth_engine-0.6.1.dev42.dist-info}/top_level.txt RENAMED
File without changes