diffsynth-engine 0.7.0__py3-none-any.whl → 0.7.1.dev1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- diffsynth_engine/__init__.py +6 -0
- diffsynth_engine/conf/models/flux2/qwen3_8B_config.json +68 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +50 -1
- diffsynth_engine/models/flux2/__init__.py +7 -0
- diffsynth_engine/models/flux2/flux2_dit.py +1065 -0
- diffsynth_engine/models/flux2/flux2_vae.py +1992 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/flux2_klein_image.py +634 -0
- diffsynth_engine/utils/constants.py +1 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/RECORD +15 -10
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/WHEEL +1 -1
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.7.0.dist-info → diffsynth_engine-0.7.1.dev1.dist-info}/top_level.txt +0 -0
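The headline change is the new `Flux2KleinPipeline` (FLUX.2 Klein text-to-image and image-editing pipeline) added below in `diffsynth_engine/pipelines/flux2_klein_image.py`. A minimal usage sketch based on the API visible in that file; the checkpoint path is a hypothetical placeholder, and the keyword arguments mirror the defaults of `Flux2KleinPipeline.__call__`:

```python
from diffsynth_engine.pipelines import Flux2KleinPipeline

# Hypothetical local path to the FLUX.2-klein DiT checkpoint (not part of this diff);
# from_pretrained also accepts a Flux2KleinPipelineConfig instance.
pipe = Flux2KleinPipeline.from_pretrained("/path/to/flux2_klein_dit.safetensors")

# Keyword arguments mirror the __call__ defaults shown in the new pipeline file.
image = pipe(
    prompt="a watercolor lighthouse at dawn",
    height=1024,
    width=1024,
    num_inference_steps=30,
    embedded_guidance=4.0,
    seed=42,
)
image.save("flux2_klein_output.png")
```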
diffsynth_engine/pipelines/__init__.py

```diff
@@ -1,5 +1,6 @@
 from .base import BasePipeline, LoRAStateDictConverter
 from .flux_image import FluxImagePipeline
+from .flux2_klein_image import Flux2KleinPipeline
 from .sdxl_image import SDXLImagePipeline
 from .sd_image import SDImagePipeline
 from .wan_video import WanVideoPipeline
@@ -14,6 +15,7 @@ __all__ = [
     "BasePipeline",
     "LoRAStateDictConverter",
     "FluxImagePipeline",
+    "Flux2KleinPipeline",
     "SDXLImagePipeline",
     "SDImagePipeline",
     "WanVideoPipeline",
```
diffsynth_engine/pipelines/flux2_klein_image.py

```diff
@@ -0,0 +1,634 @@
+import torch
+import math
+import json
+import torchvision
+from typing import Callable, List, Dict, Tuple, Optional, Union
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange
+
+from diffsynth_engine.configs import (
+    Flux2KleinPipelineConfig,
+    Flux2StateDicts,
+)
+from diffsynth_engine.models.basic.lora import LoRAContext
+
+from diffsynth_engine.models.flux2 import (
+    Flux2DiT,
+    Flux2VAE,
+)
+from diffsynth_engine.models.z_image import (
+    Qwen3Model,
+    Qwen3Config,
+)
+from transformers import AutoTokenizer
+from diffsynth_engine.utils.constants import (
+    Z_IMAGE_TEXT_ENCODER_CONFIG_FILE,
+    Z_IMAGE_TOKENIZER_CONF_PATH,
+    FLUX2_TEXT_ENCODER_8B_CONF_PATH,
+)
+from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
+from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
+from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
+from diffsynth_engine.utils.parallel import ParallelWrapper
+from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
+from diffsynth_engine.utils.download import fetch_model
+
+logger = logging.get_logger(__name__)
+
+
+class Flux2LoRAConverter(LoRAStateDictConverter):
+    def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            if "lora_A.weight" in key:
+                lora_b_key = key.replace("lora_A.weight", "lora_B.weight")
+                target_key = key.replace(".lora_A.weight", "").replace("diffusion_model.", "")
+
+                up = lora_state_dict[lora_b_key]
+                rank = up.shape[1]
+
+                dit_dict[target_key] = {
+                    "down": param,
+                    "up": up,
+                    "rank": rank,
+                    "alpha": lora_state_dict.get(key.replace("lora_A.weight", "alpha"), rank),
+                }
+
+        return {"dit": dit_dict}
+
+    def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        dit_dict = {}
+        for key, param in lora_state_dict.items():
+            if "lora_A.default.weight" in key:
+                lora_b_key = key.replace("lora_A.default.weight", "lora_B.default.weight")
+                target_key = key.replace(".lora_A.default.weight", "")
+
+                up = lora_state_dict[lora_b_key]
+                rank = up.shape[1]
+
+                dit_dict[target_key] = {
+                    "down": param,
+                    "up": up,
+                    "rank": rank,
+                    "alpha": lora_state_dict.get(key.replace("lora_A.default.weight", "alpha"), rank),
+                }
+
+        return {"dit": dit_dict}
+
+    def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+        key = list(lora_state_dict.keys())[0]
+        if key.startswith("diffusion_model."):
+            return self._from_diffusers(lora_state_dict)
+        else:
+            return self._from_diffsynth(lora_state_dict)
+
+
+def model_fn_flux2(
+    dit: Flux2DiT,
+    latents=None,
+    timestep=None,
+    embedded_guidance=None,
+    prompt_embeds=None,
+    text_ids=None,
+    image_ids=None,
+    edit_latents=None,
+    edit_image_ids=None,
+    use_gradient_checkpointing=False,
+    use_gradient_checkpointing_offload=False,
+    **kwargs,
+):
+    image_seq_len = latents.shape[1]
+    if edit_latents is not None:
+        latents = torch.concat([latents, edit_latents], dim=1)
+        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
+    embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
+    model_output = dit(
+        hidden_states=latents,
+        timestep=timestep / 1000,
+        guidance=embedded_guidance,
+        encoder_hidden_states=prompt_embeds,
+        txt_ids=text_ids,
+        img_ids=image_ids,
+        use_gradient_checkpointing=use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+    )
+    model_output = model_output[:, :image_seq_len]
+    return model_output
+
+
+class Flux2KleinPipeline(BasePipeline):
+    lora_converter = Flux2LoRAConverter()
+
+    def __init__(
+        self,
+        config: Flux2KleinPipelineConfig,
+        tokenizer: AutoTokenizer,
+        text_encoder: Qwen3Model,
+        dit: Flux2DiT,
+        vae: Flux2VAE,
+    ):
+        super().__init__(
+            vae_tiled=config.vae_tiled,
+            vae_tile_size=config.vae_tile_size,
+            vae_tile_stride=config.vae_tile_stride,
+            device=config.device,
+            dtype=config.model_dtype,
+        )
+        self.config = config
+
+        # Scheduler
+        self.noise_scheduler = RecifitedFlowScheduler(shift=1.0, use_dynamic_shifting=True, exponential_shift_mu=None)
+        self.sampler = FlowMatchEulerSampler()
+        self.tokenizer = tokenizer
+        # Models
+        self.text_encoder = text_encoder
+        self.dit = dit
+        self.vae = vae
+
+        self.model_names = ["text_encoder", "dit", "vae"]
+
+    @classmethod
+    def from_pretrained(cls, model_path_or_config: str | Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+        if isinstance(model_path_or_config, str):
+            config = Flux2KleinPipelineConfig(model_path=model_path_or_config)
+        else:
+            config = model_path_or_config
+
+        logger.info(f"Loading state dict from {config.model_path} ...")
+
+        model_state_dict = cls.load_model_checkpoint(
+            config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False
+        )
+
+        if config.vae_path is None:
+            config.vae_path = fetch_model("black-forest-labs/FLUX.2-klein-4B", path="vae/*.safetensors")
+        logger.info(f"Loading VAE from {config.vae_path} ...")
+        vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
+
+        if config.encoder_path is None:
+            if config.model_size == "4B":
+                config.encoder_path = fetch_model("black-forest-labs/FLUX.2-klein-4B", path="text_encoder/*.safetensors")
+            else:
+                config.encoder_path = fetch_model("black-forest-labs/FLUX.2-klein-9B", path="text_encoder/*.safetensors")
+        logger.info(f"Loading Text Encoder from {config.encoder_path} ...")
+        text_encoder_state_dict = cls.load_model_checkpoint(
+            config.encoder_path, device="cpu", dtype=config.encoder_dtype
+        )
+
+        state_dicts = Flux2StateDicts(
+            model=model_state_dict,
+            vae=vae_state_dict,
+            encoder=text_encoder_state_dict,
+        )
+        return cls.from_state_dict(state_dicts, config)
+
+    @classmethod
+    def from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+        assert config.parallelism <= 1, "Flux2 doesn't support parallelism > 1"
+        pipe = cls._from_state_dict(state_dicts, config)
+        return pipe
+
+    @classmethod
+    def _from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipelineConfig) -> "Flux2KleinPipeline":
+        init_device = "cpu" if config.offload_mode is not None else config.device
+        if config.model_size == "4B":
+            with open(Z_IMAGE_TEXT_ENCODER_CONFIG_FILE, "r", encoding="utf-8") as f:
+                qwen3_config = Qwen3Config(**json.load(f))
+            dit_config = {}
+        else:
+            with open(FLUX2_TEXT_ENCODER_8B_CONF_PATH, "r", encoding="utf-8") as f:
+                qwen3_config = Qwen3Config(**json.load(f))
+            state_dicts.encoder.pop("lm_head.weight")
+            dit_config = {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
+        text_encoder = Qwen3Model.from_state_dict(
+            state_dicts.encoder, config=qwen3_config, device=init_device, dtype=config.encoder_dtype
+        )
+        tokenizer = AutoTokenizer.from_pretrained(Z_IMAGE_TOKENIZER_CONF_PATH, local_files_only=True)
+        vae = Flux2VAE.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
+
+        with LoRAContext():
+            dit = Flux2DiT.from_state_dict(
+                state_dicts.model,
+                device=("cpu" if config.use_fsdp else init_device),
+                dtype=config.model_dtype,
+                **dit_config,
+            )
+            if config.use_fp8_linear:
+                enable_fp8_linear(dit)
+
+        pipe = cls(
+            config=config,
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            dit=dit,
+            vae=vae,
+        )
+        pipe.eval()
+
+        if config.offload_mode is not None:
+            pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
+
+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.use_torch_compile:
+            pipe.compile()
+
+        return pipe
+
+    def update_weights(self, state_dicts: Flux2StateDicts) -> None:
+        self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+        self.update_component(
+            self.text_encoder, state_dicts.encoder, self.config.device, self.config.encoder_dtype
+        )
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+
+    def compile(self):
+        if hasattr(self.dit, "compile_repeated_blocks"):
+            self.dit.compile_repeated_blocks()
+
+    def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
+        assert self.config.tp_degree is None or self.config.tp_degree == 1, (
+            "load LoRA is not allowed when tensor parallel is enabled; "
+            "set tp_degree=None or tp_degree=1 during pipeline initialization"
+        )
+        assert not (self.config.use_fsdp and fused), (
+            "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
+            "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
+        )
+        super().load_loras(lora_list, fused, save_original_weight)
+
+    def unload_loras(self):
+        if hasattr(self.dit, "unload_loras"):
+            self.dit.unload_loras()
+        self.noise_scheduler.restore_config()
+
+    def apply_scheduler_config(self, scheduler_config: Dict):
+        self.noise_scheduler.update_config(scheduler_config)
+
+    def prepare_latents(
+        self,
+        latents: torch.Tensor,
+        num_inference_steps: int,
+        denoising_strength: float = 1.0,
+        height: int = 1024,
+        width: int = 1024,
+    ):
+        # Compute dynamic shift length for FLUX.2 scheduler
+        dynamic_shift_len = (height // 16) * (width // 16)
+
+        # Match original FLUX.2 scheduler parameters
+        sigma_min = 1.0 / num_inference_steps
+        sigma_max = 1.0
+
+        sigmas, timesteps = self.noise_scheduler.schedule(
+            num_inference_steps,
+            sigma_min=sigma_min,
+            sigma_max=sigma_max,
+            mu=self._compute_empirical_mu(dynamic_shift_len, num_inference_steps)
+        )
+
+        # Apply denoising strength by truncating the schedule
+        if denoising_strength < 1.0:
+            num_actual_steps = max(1, int(num_inference_steps * denoising_strength))
+            sigmas = sigmas[:num_actual_steps + 1]
+            timesteps = timesteps[:num_actual_steps]
+
+        sigmas = sigmas.to(device=self.device, dtype=self.dtype)
+        timesteps = timesteps.to(device=self.device, dtype=self.dtype)
+        latents = latents.to(device=self.device, dtype=self.dtype)
+
+        return latents, sigmas, timesteps
+
+    def _compute_empirical_mu(self, image_seq_len: int, num_steps: int) -> float:
+        """Compute empirical mu for FLUX.2 scheduler (matching original implementation)"""
+        a1, b1 = 8.73809524e-05, 1.89833333
+        a2, b2 = 0.00016927, 0.45666666
+
+        if image_seq_len > 4300:
+            mu = a2 * image_seq_len + b2
+            return float(mu)
+
+        m_200 = a2 * image_seq_len + b2
+        m_10 = a1 * image_seq_len + b1
+
+        a = (m_200 - m_10) / 190.0
+        b = m_200 - 200.0 * a
+        mu = a * num_steps + b
+
+        return float(mu)
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        max_sequence_length: int = 512,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        all_input_ids = []
+        all_attention_masks = []
+
+        for single_prompt in prompt:
+            messages = [{"role": "user", "content": single_prompt}]
+            text = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                enable_thinking=False,
+            )
+            inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=max_sequence_length,
+            )
+
+            all_input_ids.append(inputs["input_ids"])
+            all_attention_masks.append(inputs["attention_mask"])
+
+        input_ids = torch.cat(all_input_ids, dim=0).to(self.device)
+        attention_mask = torch.cat(all_attention_masks, dim=0).to(self.device)
+
+        # Forward pass through the model
+        with torch.inference_mode():
+            output = self.text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                use_cache=False,
+            )
+
+        # Use outputs from intermediate layers (9, 18, 27) for Qwen3 (matching original behavior)
+        hidden_states = output["hidden_states"] if isinstance(output, dict) else output.hidden_states
+        out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)
+        out = out.to(dtype=self.dtype, device=self.device)
+
+        batch_size, num_channels, seq_len, hidden_dim = out.shape
+        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+        # Prepare text IDs
+        text_ids = self.prepare_text_ids(prompt_embeds)
+        text_ids = text_ids.to(self.device)
+
+        return prompt_embeds, text_ids
+
+    def prepare_text_ids(
+        self,
+        x: torch.Tensor,  # (B, L, D) or (L, D)
+        t_coord: Optional[torch.Tensor] = None,
+    ):
+        B, L, _ = x.shape
+        out_ids = []
+
+        for i in range(B):
+            t = torch.arange(1) if t_coord is None else t_coord[i]
+            h = torch.arange(1)
+            w = torch.arange(1)
+            l = torch.arange(L)
+
+            coords = torch.cartesian_prod(t, h, w, l)
+            out_ids.append(coords)
+
+        return torch.stack(out_ids)
+
+    def calculate_dimensions(self, target_area, ratio):
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
+    def prepare_image_ids(self, height, width):
+        t = torch.arange(1)  # [0] - time dimension
+        h = torch.arange(height)
+        w = torch.arange(width)
+        l = torch.arange(1)  # [0] - layer dimension
+
+        # Create position IDs: (H*W, 4)
+        image_ids = torch.cartesian_prod(t, h, w, l)
+
+        # Expand to batch: (B, H*W, 4)
+        image_ids = image_ids.unsqueeze(0).expand(1, -1, -1)
+
+        return image_ids
+
+    def predict_noise(
+        self,
+        latents: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_embeds: torch.Tensor,
+        text_ids: torch.Tensor,
+        image_ids: torch.Tensor,
+        embedded_guidance: float = 4.0,
+        edit_latents: torch.Tensor = None,
+        edit_image_ids: torch.Tensor = None,
+    ):
+        self.load_models_to_device(["dit"])
+
+        # Handle edit images by concatenating latents and image IDs
+        if edit_latents is not None and edit_image_ids is not None:
+            latents = torch.concat([latents, edit_latents], dim=1)
+            image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
+
+        embedded_guidance_tensor = torch.tensor([embedded_guidance], device=latents.device)
+
+        noise_pred = self.dit(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=embedded_guidance_tensor,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=image_ids,
+        )
+
+        # Return only the original image sequence length if edit images were used
+        if edit_latents is not None:
+            noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]
+
+        return noise_pred
+
+    def encode_edit_image(
+        self,
+        edit_image: Union[Image.Image, List[Image.Image]],
+        edit_image_auto_resize: bool = True,
+    ):
+        """Encode edit image(s) to latents for FLUX.2 pipeline"""
+        if edit_image is None:
+            return None, None
+
+        self.load_models_to_device(["vae"])
+
+        if isinstance(edit_image, Image.Image):
+            edit_image = [edit_image]
+
+        resized_edit_image, edit_latents = [], []
+        for image in edit_image:
+            # Preprocess
+            if edit_image_auto_resize:
+                image = self.edit_image_auto_resize(image)
+            resized_edit_image.append(image)
+            # Encode
+            image_tensor = self.preprocess_image(image).to(dtype=self.dtype, device=self.device)
+            latents = self.vae.encode(image_tensor)
+            edit_latents.append(latents)
+
+        edit_image_ids = self.process_edit_image_ids(edit_latents)
+        edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
+
+        return edit_latents, edit_image_ids
+
+    def edit_image_auto_resize(self, edit_image):
+        """Auto resize edit image to optimal dimensions"""
+        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
+        return self.crop_and_resize(edit_image, calculated_height, calculated_width)
+
+    def crop_and_resize(self, image, target_height, target_width):
+        """Crop and resize image to target dimensions"""
+        width, height = image.size
+        scale = max(target_width / width, target_height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
+        return image
+
+    def process_edit_image_ids(self, image_latents, scale=10):
+        """Process image IDs for edit images"""
+        t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+        t_coords = [t.view(-1) for t in t_coords]
+
+        image_latent_ids = []
+        for x, t in zip(image_latents, t_coords):
+            x = x.squeeze(0)
+            _, height, width = x.shape
+
+            x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+            image_latent_ids.append(x_ids)
+
+        image_latent_ids = torch.cat(image_latent_ids, dim=0)
+        image_latent_ids = image_latent_ids.unsqueeze(0)
+
+        return image_latent_ids
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 1024,
+        width: int = 1024,
+        num_inference_steps: int = 30,
+        cfg_scale: float = 1.0,
+        embedded_guidance: float = 4.0,
+        denoising_strength: float = 1.0,
+        seed: Optional[int] = None,
+        progress_callback: Optional[Callable] = None,
+        # Edit image parameters
+        edit_image: Union[Image.Image, List[Image.Image]] = None,
+        edit_image_auto_resize: bool = True,
+    ):
+        self.validate_image_size(height, width, multiple_of=16)
+
+        # Encode prompts
+        self.load_models_to_device(["text_encoder"])
+        prompt_embeds, text_ids = self.encode_prompt(prompt)
+        if negative_prompt is not None:
+            negative_prompt_embeds, negative_text_ids = self.encode_prompt(negative_prompt)
+        else:
+            negative_prompt_embeds, negative_text_ids = None, None
+        self.model_lifecycle_finish(["text_encoder"])
+
+        # Encode edit images if provided
+        edit_latents, edit_image_ids = None, None
+        if edit_image is not None:
+            edit_latents, edit_image_ids = self.encode_edit_image(edit_image, edit_image_auto_resize)
+            if edit_latents is not None:
+                edit_latents = edit_latents.to(device=self.device, dtype=self.dtype)
+                edit_image_ids = edit_image_ids.to(device=self.device, dtype=self.dtype)
+
+        # Generate initial noise
+        noise = self.generate_noise((1, 128, height // 16, width // 16), seed=seed, device="cpu", dtype=self.dtype).to(
+            device=self.device
+        )
+        noise = noise.reshape(1, 128, height//16 * width//16).permute(0, 2, 1)
+
+        # Prepare latents with noise scheduling
+        latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps, denoising_strength, height, width)
+
+        self.sampler.initialize(sigmas=sigmas)
+
+        # Prepare image IDs
+        image_ids = self.prepare_image_ids(height // 16, width // 16).to(self.device)
+
+        # Denoising loop
+        self.load_models_to_device(["dit"])
+        for i, timestep in enumerate(tqdm(timesteps)):
+            timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
+
+            if cfg_scale > 1.0 and negative_prompt_embeds is not None:
+                # CFG prediction
+                latents_input = torch.cat([latents] * 2, dim=0)
+                timestep_input = torch.cat([timestep] * 2, dim=0)
+                prompt_embeds_input = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+                text_ids_input = torch.cat([text_ids, negative_text_ids], dim=0)
+                image_ids_input = torch.cat([image_ids] * 2, dim=0)
+
+                # Handle edit images for CFG
+                edit_latents_input = None
+                edit_image_ids_input = None
+                if edit_latents is not None:
+                    edit_latents_input = torch.cat([edit_latents] * 2, dim=0)
+                    edit_image_ids_input = torch.cat([edit_image_ids] * 2, dim=0)
+
+                noise_pred = self.predict_noise(
+                    latents=latents_input,
+                    timestep=timestep_input,
+                    prompt_embeds=prompt_embeds_input,
+                    text_ids=text_ids_input,
+                    image_ids=image_ids_input,
+                    embedded_guidance=embedded_guidance,
+                    edit_latents=edit_latents_input,
+                    edit_image_ids=edit_image_ids_input,
+                )
+
+                # Split predictions and apply CFG
+                noise_pred_positive, noise_pred_negative = noise_pred.chunk(2)
+                noise_pred = noise_pred_negative + cfg_scale * (noise_pred_positive - noise_pred_negative)
+            else:
+                # Non-CFG prediction
+                noise_pred = self.predict_noise(
+                    latents=latents,
+                    timestep=timestep,
+                    prompt_embeds=prompt_embeds,
+                    text_ids=text_ids,
+                    image_ids=image_ids,
+                    embedded_guidance=embedded_guidance,
+                    edit_latents=edit_latents,
+                    edit_image_ids=edit_image_ids,
+                )
+
+            latents = self.sampler.step(latents, noise_pred, i)
+            if progress_callback is not None:
+                progress_callback(i, len(timesteps), "DENOISING")
+
+        self.model_lifecycle_finish(["dit"])
+
+        # Decode final latents
+        self.load_models_to_device(["vae"])
+        latents = rearrange(latents, "B (H W) C -> B C H W", H=height//16, W=width//16)
+        vae_output = self.vae.decode(latents)
+        image = self.vae_output_to_image(vae_output)
+
+        # Offload all models
+        self.load_models_to_device([])
+        return image
```
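The `_compute_empirical_mu` helper above appears to interpolate the scheduler's shift parameter mu linearly in the step count between two empirical linear fits of the image sequence length (the `m_10` / `m_200` values, anchored at roughly 10 and 200 steps). A standalone sketch of the same arithmetic, using the constants from the diff, for a 1024×1024 image (64×64 latent patches, so `image_seq_len = 4096`) at the default 30 steps:

```python
# Standalone reproduction of the mu interpolation used by _compute_empirical_mu above.
def empirical_mu(image_seq_len: int, num_steps: int) -> float:
    a1, b1 = 8.73809524e-05, 1.89833333   # fit used for the ~10-step anchor
    a2, b2 = 0.00016927, 0.45666666       # fit used for the ~200-step anchor
    if image_seq_len > 4300:
        return a2 * image_seq_len + b2    # long sequences skip the interpolation
    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1
    a = (m_200 - m_10) / 190.0            # slope between the 10- and 200-step anchors
    b = m_200 - 200.0 * a
    return a * num_steps + b

# 1024x1024 image -> 64x64 latent patches -> image_seq_len = 4096; 30 steps gives mu ≈ 2.14
print(round(empirical_mu(64 * 64, 30), 2))
```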
diffsynth_engine/utils/constants.py

```diff
@@ -21,6 +21,7 @@ VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json")
 FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json")
 FLUX_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_text_encoder.json")
 FLUX_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_vae.json")
+FLUX2_TEXT_ENCODER_8B_CONF_PATH = os.path.join(CONF_PATH, "models", "flux2", "qwen3_8B_config.json")
 SD_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_text_encoder.json")
 SD_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd", "sd_unet.json")
 SD3_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sd3", "sd3_dit.json")
```