diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev2__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
@@ -59,8 +59,11 @@ class Qwen3Model(PreTrainedModel):
          device: str = "cuda:0",
          dtype: torch.dtype = torch.bfloat16,
      ):
-         model = cls(config=config, device="meta", dtype=dtype)
+         with torch.device("meta"):
+             model = cls(config=config, device="meta", dtype=dtype)
          model.requires_grad_(False)
+
+         model.rotary_emb = Qwen3RotaryEmbedding(config=config, device=device)
          model.load_state_dict(state_dict, assign=True)
          model.to(device=device, dtype=dtype, non_blocking=True)
          return model
@@ -584,7 +584,8 @@ class ZImageDiT(PreTrainedModel):
          dtype: torch.dtype,
          **kwargs,
      ):
-         model = cls(device="meta", dtype=dtype, **kwargs)
+         with torch.device("meta"):
+             model = cls(device="meta", dtype=dtype, **kwargs)
          model = model.requires_grad_(False)
          model.load_state_dict(state_dict, assign=True)
          model.to(device=device, dtype=dtype, non_blocking=True)
@@ -1,5 +1,6 @@
  from .base import BasePipeline, LoRAStateDictConverter
  from .flux_image import FluxImagePipeline
+ from .flux2_klein_image import Flux2KleinPipeline
  from .sdxl_image import SDXLImagePipeline
  from .sd_image import SDImagePipeline
  from .wan_video import WanVideoPipeline
@@ -14,6 +15,7 @@ __all__ = [
      "BasePipeline",
      "LoRAStateDictConverter",
      "FluxImagePipeline",
+     "Flux2KleinPipeline",
      "SDXLImagePipeline",
      "SDImagePipeline",
      "WanVideoPipeline",
@@ -0,0 +1,634 @@
+ import torch
+ import math
+ import json
+ import torchvision
+ from typing import Callable, List, Dict, Tuple, Optional, Union
+ from tqdm import tqdm
+ from PIL import Image
+ import numpy as np
+ from einops import rearrange
+
+ from diffsynth_engine.configs import (
+     Flux2KleinPipelineConfig,
+     Flux2StateDicts,
+ )
+ from diffsynth_engine.models.basic.lora import LoRAContext
+
+ from diffsynth_engine.models.flux2 import (
+     Flux2DiT,
+     Flux2VAE,
+ )
+ from diffsynth_engine.models.z_image import (
+     Qwen3Model,
+     Qwen3Config,
+ )
+ from transformers import AutoTokenizer
+ from diffsynth_engine.utils.constants import (
+     Z_IMAGE_TEXT_ENCODER_CONFIG_FILE,
+     Z_IMAGE_TOKENIZER_CONF_PATH,
+     FLUX2_TEXT_ENCODER_8B_CONF_PATH,
+ )
+ from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
+ from diffsynth_engine.pipelines.utils import calculate_shift
+ from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
+ from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
+ from diffsynth_engine.utils.parallel import ParallelWrapper
+ from diffsynth_engine.utils import logging
+ from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
+ from diffsynth_engine.utils.download import fetch_model
+
+ logger = logging.get_logger(__name__)
+
+
+ class Flux2LoRAConverter(LoRAStateDictConverter):
+     def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+         dit_dict = {}
+         for key, param in lora_state_dict.items():
+             if "lora_A.weight" in key:
+                 lora_b_key = key.replace("lora_A.weight", "lora_B.weight")
+                 target_key = key.replace(".lora_A.weight", "").replace("diffusion_model.", "")
+
+                 up = lora_state_dict[lora_b_key]
+                 rank = up.shape[1]
+
+                 dit_dict[target_key] = {
+                     "down": param,
+                     "up": up,
+                     "rank": rank,
+                     "alpha": lora_state_dict.get(key.replace("lora_A.weight", "alpha"), rank),
+                 }
+
+         return {"dit": dit_dict}
+
+     def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+         dit_dict = {}
+         for key, param in lora_state_dict.items():
+             if "lora_A.default.weight" in key:
+                 lora_b_key = key.replace("lora_A.default.weight", "lora_B.default.weight")
+                 target_key = key.replace(".lora_A.default.weight", "")
+
+                 up = lora_state_dict[lora_b_key]
+                 rank = up.shape[1]
+
+                 dit_dict[target_key] = {
+                     "down": param,
+                     "up": up,
+                     "rank": rank,
+                     "alpha": lora_state_dict.get(key.replace("lora_A.default.weight", "alpha"), rank),
+                 }
+
+         return {"dit": dit_dict}
+
+     def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+         key = list(lora_state_dict.keys())[0]
+         if key.startswith("diffusion_model."):
+             return self._from_diffusers(lora_state_dict)
+         else:
+             return self._from_diffsynth(lora_state_dict)
+
+
+ def model_fn_flux2(
+     dit: Flux2DiT,
+     latents=None,
+     timestep=None,
+     embedded_guidance=None,
+     prompt_embeds=None,
+     text_ids=None,
+     image_ids=None,
+     edit_latents=None,
+     edit_image_ids=None,
+     use_gradient_checkpointing=False,
+     use_gradient_checkpointing_offload=False,
+     **kwargs,
+ ):
+     image_seq_len = latents.shape[1]
+     if edit_latents is not None:
+         latents = torch.concat([latents, edit_latents], dim=1)
+         image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
+     embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
+     model_output = dit(
+         hidden_states=latents,
+         timestep=timestep / 1000,
+         guidance=embedded_guidance,
+         encoder_hidden_states=prompt_embeds,
+         txt_ids=text_ids,
+         img_ids=image_ids,
+         use_gradient_checkpointing=use_gradient_checkpointing,
+         use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+     )
+     model_output = model_output[:, :image_seq_len]
+     return model_output
+
+
+ class Flux2KleinPipeline(BasePipeline):
+     lora_converter = Flux2LoRAConverter()
+
+     def __init__(
+         self,
+         config: Flux2KleinPipelineConfig,
+         tokenizer: AutoTokenizer,
+         text_encoder: Qwen3Model,
+         dit: Flux2DiT,
+         vae: Flux2VAE,
+     ):
+         super().__init__(
+             vae_tiled=config.vae_tiled,
+             vae_tile_size=config.vae_tile_size,
+             vae_tile_stride=config.vae_tile_stride,
+             device=config.device,
+             dtype=config.model_dtype,
+         )
+         self.config = config
+
+         # Scheduler
+         self.noise_scheduler = RecifitedFlowScheduler(shift=1.0, use_dynamic_shifting=True, exponential_shift_mu=None)
+         self.sampler = FlowMatchEulerSampler()
+         self.tokenizer = tokenizer
+         # Models
+         self.text_encoder = text_encoder
+         self.dit = dit
+         self.vae = vae
+
+         self.model_names = ["text_encoder", "dit", "vae"]
+
+     @classmethod
+     def from_pretrained(cls, model_path_or_config: str | Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+         if isinstance(model_path_or_config, str):
+             config = Flux2KleinPipelineConfig(model_path=model_path_or_config)
+         else:
+             config = model_path_or_config
+
+         logger.info(f"Loading state dict from {config.model_path} ...")
+
+         model_state_dict = cls.load_model_checkpoint(
+             config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False
+         )
+
+         if config.vae_path is None:
+             config.vae_path = fetch_model("black-forest-labs/FLUX.2-klein-4B", path="vae/*.safetensors")
+         logger.info(f"Loading VAE from {config.vae_path} ...")
+         vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
+
+         if config.encoder_path is None:
+             if config.model_size == "4B":
+                 config.encoder_path = fetch_model("black-forest-labs/FLUX.2-klein-4B", path="text_encoder/*.safetensors")
+             else:
+                 config.encoder_path = fetch_model("black-forest-labs/FLUX.2-klein-9B", path="text_encoder/*.safetensors")
+         logger.info(f"Loading Text Encoder from {config.encoder_path} ...")
+         text_encoder_state_dict = cls.load_model_checkpoint(
+             config.encoder_path, device="cpu", dtype=config.encoder_dtype
+         )
+
+         state_dicts = Flux2StateDicts(
+             model=model_state_dict,
+             vae=vae_state_dict,
+             encoder=text_encoder_state_dict,
+         )
+         return cls.from_state_dict(state_dicts, config)
+
+     @classmethod
+     def from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+         assert config.parallelism <= 1, "Flux2 doesn't support parallelism > 1"
+         pipe = cls._from_state_dict(state_dicts, config)
+         return pipe
+
+     @classmethod
+     def _from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+         init_device = "cpu" if config.offload_mode is not None else config.device
+         if config.model_size == "4B":
+             with open(Z_IMAGE_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
+                 qwen3_config = Qwen3Config(**json.load(f))
+             dit_config = {}
+         else:
+             with open(FLUX2_TEXT_ENCODER_8B_CONF_PATH, "r", encoding="utf-8") as f:
+                 qwen3_config = Qwen3Config(**json.load(f))
+             state_dicts.encoder.pop("lm_head.weight")
+             dit_config = {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
+         text_encoder = Qwen3Model.from_state_dict(
+             state_dicts.encoder, config=qwen3_config, device=init_device, dtype=config.encoder_dtype
+         )
+         tokenizer = AutoTokenizer.from_pretrained(Z_IMAGE_TOKENIZER_CONF_PATH, local_files_only=True)
+         vae = Flux2VAE.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
+
+         with LoRAContext():
+             dit = Flux2DiT.from_state_dict(
+                 state_dicts.model,
+                 device=("cpu" if config.use_fsdp else init_device),
+                 dtype=config.model_dtype,
+                 **dit_config,
+             )
+         if config.use_fp8_linear:
+             enable_fp8_linear(dit)
+
+         pipe = cls(
+             config=config,
+             tokenizer=tokenizer,
+             text_encoder=text_encoder,
+             dit=dit,
+             vae=vae,
+         )
+         pipe.eval()
+
+         if config.offload_mode is not None:
+             pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
+
+         if config.model_dtype == torch.float8_e4m3fn:
+             pipe.dtype = torch.bfloat16
+             pipe.enable_fp8_autocast(
+                 model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+             )
+
+         if config.use_torch_compile:
+             pipe.compile()
+
+         return pipe
+
+     def update_weights(self, state_dicts: Flux2StateDicts) -> None:
+         self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+         self.update_component(
+             self.text_encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype
+         )
+         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
+     def compile(self):
+         if hasattr(self.dit, "compile_repeated_blocks"):
+             self.dit.compile_repeated_blocks()
+
+     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
+             "load LoRA is not allowed when tensor parallel is enabled; "
+             "set tp_degree=None or tp_degree=1 during pipeline initialization"
+         )
+         assert not (self.config.use_fsdp and fused), (
+             "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
+             "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
+         )
+         super().load_loras(lora_list, fused, save_original_weight)
+
+     def unload_loras(self):
+         if hasattr(self.dit, "unload_loras"):
+             self.dit.unload_loras()
+         self.noise_scheduler.restore_config()
+
+     def apply_scheduler_config(self, scheduler_config: Dict):
+         self.noise_scheduler.update_config(scheduler_config)
+
+     def prepare_latents(
+         self,
+         latents: torch.Tensor,
+         num_inference_steps: int,
+         denoising_strength: float = 1.0,
+         height: int = 1024,
+         width: int = 1024,
+     ):
+         # Compute dynamic shift length for FLUX.2 scheduler
+         dynamic_shift_len = (height // 16) * (width // 16)
+
+         # Match original FLUX.2 scheduler parameters
+         sigma_min = 1.0 / num_inference_steps
+         sigma_max = 1.0
+
+         sigmas, timesteps = self.noise_scheduler.schedule(
+             num_inference_steps,
+             sigma_min=sigma_min,
+             sigma_max=sigma_max,
+             mu=self._compute_empirical_mu(dynamic_shift_len, num_inference_steps)
+         )
+
+         # Apply denoising strength by truncating the schedule
+         if denoising_strength < 1.0:
+             num_actual_steps = max(1, int(num_inference_steps * denoising_strength))
+             sigmas = sigmas[:num_actual_steps + 1]
+             timesteps = timesteps[:num_actual_steps]
+
+         sigmas = sigmas.to(device=self.device, dtype=self.dtype)
+         timesteps = timesteps.to(device=self.device, dtype=self.dtype)
+         latents = latents.to(device=self.device, dtype=self.dtype)
+
+         return latents, sigmas, timesteps
+
+     def _compute_empirical_mu(self, image_seq_len: int, num_steps: int) -> float:
+         """Compute empirical mu for FLUX.2 scheduler (matching original implementation)"""
+         a1, b1 = 8.73809524e-05, 1.89833333
+         a2, b2 = 0.00016927, 0.45666666
+
+         if image_seq_len > 4300:
+             mu = a2 * image_seq_len + b2
+             return float(mu)
+
+         m_200 = a2 * image_seq_len + b2
+         m_10 = a1 * image_seq_len + b1
+
+         a = (m_200 - m_10) / 190.0
+         b = m_200 - 200.0 * a
+         mu = a * num_steps + b
+
+         return float(mu)
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         max_sequence_length: int = 512,
+     ):
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+
+         all_input_ids = []
+         all_attention_masks = []
+
+         for single_prompt in prompt:
+             messages = [{"role": "user", "content": single_prompt}]
+             text = self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True,
+                 enable_thinking=False,
+             )
+             inputs = self.tokenizer(
+                 text,
+                 return_tensors="pt",
+                 padding="max_length",
+                 truncation=True,
+                 max_length=max_sequence_length,
+             )
+
+             all_input_ids.append(inputs["input_ids"])
+             all_attention_masks.append(inputs["attention_mask"])
+
+         input_ids = torch.cat(all_input_ids, dim=0).to(self.device)
+         attention_mask = torch.cat(all_attention_masks, dim=0).to(self.device)
+
+         # Forward pass through the model
+         with torch.inference_mode():
+             output = self.text_encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 output_hidden_states=True,
+                 use_cache=False,
+             )
+
+         # Use outputs from intermediate layers (9, 18, 27) for Qwen3 (matching original behavior)
+         hidden_states = output["hidden_states"] if isinstance(output, dict) else output.hidden_states
+         out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)
+         out = out.to(dtype=self.dtype, device=self.device)
+
+         batch_size, num_channels, seq_len, hidden_dim = out.shape
+         prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+         # Prepare text IDs
+         text_ids = self.prepare_text_ids(prompt_embeds)
+         text_ids = text_ids.to(self.device)
+
+         return prompt_embeds, text_ids
+
+     def prepare_text_ids(
+         self,
+         x: torch.Tensor,  # (B, L, D) or (L, D)
+         t_coord: Optional[torch.Tensor] = None,
+     ):
+         B, L, _ = x.shape
+         out_ids = []
+
+         for i in range(B):
+             t = torch.arange(1) if t_coord is None else t_coord[i]
+             h = torch.arange(1)
+             w = torch.arange(1)
+             l = torch.arange(L)
+
+             coords = torch.cartesian_prod(t, h, w, l)
+             out_ids.append(coords)
+
+         return torch.stack(out_ids)
+
+     def calculate_dimensions(self, target_area, ratio):
+         width = math.sqrt(target_area * ratio)
+         height = width / ratio
+         width = round(width / 32) * 32
+         height = round(height / 32) * 32
+         return width, height
+
+     def prepare_image_ids(self, height, width):
+         t = torch.arange(1)  # [0] - time dimension
+         h = torch.arange(height)
+         w = torch.arange(width)
+         l = torch.arange(1)  # [0] - layer dimension
+
+         # Create position IDs: (H*W, 4)
+         image_ids = torch.cartesian_prod(t, h, w, l)
+
+         # Expand to batch: (B, H*W, 4)
+         image_ids = image_ids.unsqueeze(0).expand(1, -1, -1)
+
+         return image_ids
+
+     def predict_noise(
+         self,
+         latents: torch.Tensor,
+         timestep: torch.Tensor,
+         prompt_embeds: torch.Tensor,
+         text_ids: torch.Tensor,
+         image_ids: torch.Tensor,
+         embedded_guidance: float = 4.0,
+         edit_latents: torch.Tensor = None,
+         edit_image_ids: torch.Tensor = None,
+     ):
+         self.load_models_to_device(["dit"])
+
+         # Handle edit images by concatenating latents and image IDs
+         if edit_latents is not None and edit_image_ids is not None:
+             latents = torch.concat([latents, edit_latents], dim=1)
+             image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
+
+         embedded_guidance_tensor = torch.tensor([embedded_guidance], device=latents.device)
+
+         noise_pred = self.dit(
+             hidden_states=latents,
+             timestep=timestep / 1000,
+             guidance=embedded_guidance_tensor,
+             encoder_hidden_states=prompt_embeds,
+             txt_ids=text_ids,
+             img_ids=image_ids,
+         )
+
+         # Return only the original image sequence length if edit images were used
+         if edit_latents is not None:
+             noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]
+
+         return noise_pred
+
+     def encode_edit_image(
+         self,
+         edit_image: Union[Image.Image, List[Image.Image]],
+         edit_image_auto_resize: bool = True,
+     ):
+         """Encode edit image(s) to latents for FLUX.2 pipeline"""
+         if edit_image is None:
+             return None, None
+
+         self.load_models_to_device(["vae"])
+
+         if isinstance(edit_image, Image.Image):
+             edit_image = [edit_image]
+
+         resized_edit_image, edit_latents = [], []
+         for image in edit_image:
+             # Preprocess
+             if edit_image_auto_resize:
+                 image = self.edit_image_auto_resize(image)
+             resized_edit_image.append(image)
+             # Encode
+             image_tensor = self.preprocess_image(image).to(dtype=self.dtype, device=self.device)
+             latents = self.vae.encode(image_tensor)
+             edit_latents.append(latents)
+
+         edit_image_ids = self.process_edit_image_ids(edit_latents)
+         edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
+
+         return edit_latents, edit_image_ids
+
+     def edit_image_auto_resize(self, edit_image):
+         """Auto resize edit image to optimal dimensions"""
+         calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
+         return self.crop_and_resize(edit_image, calculated_height, calculated_width)
+
+     def crop_and_resize(self, image, target_height, target_width):
+         """Crop and resize image to target dimensions"""
+         width, height = image.size
+         scale = max(target_width / width, target_height / height)
+         image = torchvision.transforms.functional.resize(
+             image,
+             (round(height*scale), round(width*scale)),
+             interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+         )
+         image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
+         return image
+
+     def process_edit_image_ids(self, image_latents, scale=10):
+         """Process image IDs for edit images"""
+         t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+         t_coords = [t.view(-1) for t in t_coords]
+
+         image_latent_ids = []
+         for x, t in zip(image_latents, t_coords):
+             x = x.squeeze(0)
+             _, height, width = x.shape
+
+             x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+             image_latent_ids.append(x_ids)
+
+         image_latent_ids = torch.cat(image_latent_ids, dim=0)
+         image_latent_ids = image_latent_ids.unsqueeze(0)
+
+         return image_latent_ids
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         height: int = 1024,
+         width: int = 1024,
+         num_inference_steps: int = 30,
+         cfg_scale: float = 1.0,
+         embedded_guidance: float = 4.0,
+         denoising_strength: float = 1.0,
+         seed: Optional[int] = None,
+         progress_callback: Optional[Callable] = None,
+         # Edit image parameters
+         edit_image: Union[Image.Image, List[Image.Image]] = None,
+         edit_image_auto_resize: bool = True,
+     ):
+         self.validate_image_size(height, width, multiple_of=16)
+
+         # Encode prompts
+         self.load_models_to_device(["text_encoder"])
+         prompt_embeds, text_ids = self.encode_prompt(prompt)
+         if negative_prompt is not None:
+             negative_prompt_embeds, negative_text_ids = self.encode_prompt(negative_prompt)
+         else:
+             negative_prompt_embeds, negative_text_ids = None, None
+         self.model_lifecycle_finish(["text_encoder"])
+
+         # Encode edit images if provided
+         edit_latents, edit_image_ids = None, None
+         if edit_image is not None:
+             edit_latents, edit_image_ids = self.encode_edit_image(edit_image, edit_image_auto_resize)
+         if edit_latents is not None:
+             edit_latents = edit_latents.to(device=self.device, dtype=self.dtype)
+             edit_image_ids = edit_image_ids.to(device=self.device, dtype=self.dtype)
+
+         # Generate initial noise
+         noise = self.generate_noise((1, 128, height // 16, width // 16), seed=seed, device="cpu", dtype=self.dtype).to(
+             device=self.device
+         )
+         noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
+
+         # Prepare latents with noise scheduling
+         latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps, denoising_strength, height, width)
+
+         self.sampler.initialize(sigmas=sigmas)
+
+         # Prepare image IDs
+         image_ids = self.prepare_image_ids(height // 16, width // 16).to(self.device)
+
+         # Denoising loop
+         self.load_models_to_device(["dit"])
+         for i, timestep in enumerate(tqdm(timesteps)):
+             timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
+
+             if cfg_scale > 1.0 and negative_prompt_embeds is not None:
+                 # CFG prediction
+                 latents_input = torch.cat([latents] * 2, dim=0)
+                 timestep_input = torch.cat([timestep] * 2, dim=0)
+                 prompt_embeds_input = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+                 text_ids_input = torch.cat([text_ids, negative_text_ids], dim=0)
+                 image_ids_input = torch.cat([image_ids] * 2, dim=0)
+
+                 # Handle edit images for CFG
+                 edit_latents_input = None
+                 edit_image_ids_input = None
+                 if edit_latents is not None:
+                     edit_latents_input = torch.cat([edit_latents] * 2, dim=0)
+                     edit_image_ids_input = torch.cat([edit_image_ids] * 2, dim=0)
+
+                 noise_pred = self.predict_noise(
+                     latents=latents_input,
+                     timestep=timestep_input,
+                     prompt_embeds=prompt_embeds_input,
+                     text_ids=text_ids_input,
+                     image_ids=image_ids_input,
+                     embedded_guidance=embedded_guidance,
+                     edit_latents=edit_latents_input,
+                     edit_image_ids=edit_image_ids_input,
+                 )
+
+                 # Split predictions and apply CFG
+                 noise_pred_positive, noise_pred_negative = noise_pred.chunk(2)
+                 noise_pred = noise_pred_negative + cfg_scale * (noise_pred_positive - noise_pred_negative)
+             else:
+                 # Non-CFG prediction
+                 noise_pred = self.predict_noise(
+                     latents=latents,
+                     timestep=timestep,
+                     prompt_embeds=prompt_embeds,
+                     text_ids=text_ids,
+                     image_ids=image_ids,
+                     embedded_guidance=embedded_guidance,
+                     edit_latents=edit_latents,
+                     edit_image_ids=edit_image_ids,
+                 )
+
+             latents = self.sampler.step(latents, noise_pred, i)
+             if progress_callback is not None:
+                 progress_callback(i, len(timesteps), "DENOISING")
+
+         self.model_lifecycle_finish(["dit"])
+
+         # Decode final latents
+         self.load_models_to_device(["vae"])
+         latents = rearrange(latents, "B (H W) C -> B C H W", H=height//16, W=width//16)
+         vae_output = self.vae.decode(latents)
+         image = self.vae_output_to_image(vae_output)
+
+         # Offload all models
+         self.load_models_to_device([])
+         return image
@@ -21,6 +21,7 @@ VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json")
  FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json")
  FLUX_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_text_encoder.json")
  FLUX_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_vae.json")
+ FLUX2_TEXT_ENCODER_8B_CONF_PATH = os.path.join(CONF_PATH, "models", "flux2", "qwen3_8B_config.json")
  SD_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_text_encoder.json")
  SD_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_unet.json")
  SD3_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_dit.json")
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: diffsynth_engine
- Version: 0.7.0
+ Version: 0.7.1.dev2
  Author: MuseAI x ModelScope
  Classifier: Programming Language :: Python :: 3
  Classifier: Operating System :: OS Independent
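
For orientation, here is a minimal usage sketch of the Flux2KleinPipeline exported in this release, based on the from_pretrained and __call__ signatures added above. The checkpoint path and prompt are placeholders, and saving via PIL assumes the image return convention of the existing pipelines.

from diffsynth_engine.pipelines import Flux2KleinPipeline

# Placeholder path to a local FLUX.2-klein DiT checkpoint; a Flux2KleinPipelineConfig
# can be passed instead to control dtype, offloading, and VAE/text-encoder paths.
pipe = Flux2KleinPipeline.from_pretrained("/path/to/flux2-klein/dit.safetensors")

image = pipe(
    prompt="a watercolor lighthouse at dawn",
    height=1024,
    width=1024,
    num_inference_steps=30,
    embedded_guidance=4.0,
    seed=42,
)
image.save("flux2_klein.png")  # assumes a PIL.Image return, as in the other pipelines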