diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/flux/flux_dit.py

@@ -2,7 +2,7 @@ import json
 import torch
 import torch.nn as nn
 import numpy as np
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from einops import rearrange

 from diffsynth_engine.models.basic.transformer_helper import (
@@ -28,7 +28,7 @@ from diffsynth_engine.utils import logging

 logger = logging.get_logger(__name__)

-with open(FLUX_DIT_CONFIG_FILE, "r") as f:
+with open(FLUX_DIT_CONFIG_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)


@@ -176,7 +176,6 @@ class FluxDoubleAttention(nn.Module):
         dim_b,
         num_heads,
         head_dim,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -194,19 +193,20 @@ class FluxDoubleAttention(nn.Module):

         self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype)
         self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb):
         return attn_out_a, attn_out_b

-    def forward(self, image, text, rope_emb, image_emb):
+    def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None):
         q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1)
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -231,19 +231,18 @@ class FluxDoubleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        self.attn = FluxDoubleAttention(
-            dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype
-        )
+        self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype)
         # Image
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.ff_a = nn.Sequential(
-            nn.Linear(dim, dim * 4
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )
         # Text
         self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -254,11 +253,11 @@ class FluxDoubleTransformerBlock(nn.Module):
             nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )

-    def forward(self, image, text, t_emb, rope_emb, image_emb=None):
+    def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         # AdaLayerNorm-Zero for Image and Text MSA
         image_in, gate_a = self.norm_msa_a(image, t_emb)
         text_in, gate_b = self.norm_msa_b(text, t_emb)
-        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb)
+        image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs)
         image = image + gate_a * image_out
         text = text + gate_b * text_out

@@ -277,7 +276,6 @@ class FluxSingleAttention(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -286,15 +284,16 @@ class FluxSingleAttention(nn.Module):
         self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype)
         self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
         self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

     def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
         return attn_out

-    def forward(self, x, rope_emb, image_emb):
+    def forward(self, x, rope_emb, image_emb, attn_kwargs=None):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention_ops.attention(q, k, v, **attn_kwargs)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -304,23 +303,22 @@ class FluxSingleTransformerBlock(nn.Module):
         self,
         dim,
         num_heads,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
         self.dim = dim
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
-        self.attn = FluxSingleAttention(dim, num_heads,
+        self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
-            nn.Linear(dim, dim * 4),
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
-        self.proj_out = nn.Linear(dim * 5, dim)
+        self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

-    def forward(self, x, t_emb, rope_emb, image_emb=None):
+    def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None):
         h, gate = self.norm(x, emb=t_emb)
-        attn_output = self.attn(h, rope_emb, image_emb)
+        attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs)
         mlp_output = self.mlp(h)
         return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2))

@@ -332,7 +330,6 @@ class FluxDiT(PreTrainedModel):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -350,16 +347,10 @@ class FluxDiT(PreTrainedModel):
         self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)

         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(19)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)]
         )
         self.single_blocks = nn.ModuleList(
-            [
-                FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(38)
-            ]
+            [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)]
         )
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
@@ -393,21 +384,20 @@ class FluxDiT(PreTrainedModel):

     def forward(
         self,
-        hidden_states,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-
-
-
-
-
-
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_double_block_output: List[torch.Tensor] | None = None,
+        controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
     ):
-
-        if image_ids is None:
-            image_ids = self.prepare_image_ids(hidden_states)
+        image_seq_len = hidden_states.shape[1]
         controlnet_double_block_output = (
             controlnet_double_block_output if controlnet_double_block_output is not None else ()
         )
@@ -426,10 +416,10 @@ class FluxDiT(PreTrainedModel):
                 timestep,
                 prompt_emb,
                 pooled_prompt_emb,
-                image_emb,
-                guidance,
-                text_ids,
                 image_ids,
+                text_ids,
+                guidance,
+                image_emb,
                 *controlnet_double_block_output,
                 *controlnet_single_block_output,
             ),
@@ -446,7 +436,6 @@ class FluxDiT(PreTrainedModel):
         rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
         text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
         image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
-        hidden_states = self.patchify(hidden_states)

         with sequence_parallel(
             (
@@ -471,14 +460,16 @@ class FluxDiT(PreTrainedModel):
             rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

             for i, block in enumerate(self.blocks):
-                hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+                )
                 if len(controlnet_double_block_output) > 0:
                     interval_control = len(self.blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
                     hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
             hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
             for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
                 if len(controlnet_single_block_output) > 0:
                     interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
@@ -487,9 +478,8 @@ class FluxDiT(PreTrainedModel):
             hidden_states = hidden_states[:, prompt_emb.shape[1] :]
             hidden_states = self.final_norm_out(hidden_states, conditioning)
             hidden_states = self.final_proj_out(hidden_states)
-            (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(
+            (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

-        hidden_states = self.unpatchify(hidden_states, h, w)
         (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)
         return hidden_states

@@ -500,14 +490,8 @@ class FluxDiT(PreTrainedModel):
        device: str,
        dtype: torch.dtype,
        in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
-        model = cls(
-            device="meta",
-            dtype=dtype,
-            in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
-        )
+        model = cls(device="meta", dtype=dtype, in_channel=in_channel)
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
@@ -517,5 +501,8 @@ class FluxDiT(PreTrainedModel):
        for block in self.blocks:
            block.compile(*args, **kwargs)

-
-
+        for block in self.single_blocks:
+            block.compile(*args, **kwargs)
+
+    def get_fsdp_module_cls(self):
+        return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
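The recurring pattern in the flux_dit.py hunks above: `attn_kwargs` is no longer captured in each attention module's `__init__` as `self.attn_kwargs`; it is now an optional argument of `forward()` that `FluxDiT.forward` threads through every double and single block and splats directly into `attention_ops.attention`. The rewritten `forward` also expects already-patchified latents with explicit `image_ids`/`text_ids`, since the internal `patchify`/`unpatchify` and `prepare_image_ids` calls were removed (presumably handled by the pipeline now). The snippet below is a minimal, self-contained sketch of that call-time plumbing pattern only; it uses torch's `scaled_dot_product_attention` as a stand-in for `attention_ops.attention`, and the class name and option key are illustrative, not part of diffsynth-engine.

```python
# Sketch: attention options travel with the call and are splatted into the
# backend, instead of being frozen into the module at construction time.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.to_qkv = nn.Linear(dim, dim * 3)

    def forward(self, x: torch.Tensor, attn_kwargs: dict | None = None) -> torch.Tensor:
        b, s, d = x.shape
        # (b, s, 3*d) -> q, k, v each shaped (b, heads, s, head_dim)
        q, k, v = self.to_qkv(x).reshape(b, s, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
        out = F.scaled_dot_product_attention(q, k, v, **attn_kwargs)  # per-call options
        return out.transpose(1, 2).reshape(b, s, d)


x = torch.randn(1, 16, 64)
attn = TinyAttention(dim=64, num_heads=4)
y_default = attn(x)                                  # no extra options
y_causal = attn(x, attn_kwargs={"is_causal": True})  # options decided at call time
```

The same convention appears again in flux_dit_fbcache.py and flux_ipadapter.py below.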
diffsynth_engine/models/flux/flux_dit_fbcache.py

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT):
     def __init__(
         self,
         in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         relative_l1_threshold: float = 0.05,
     ):
-        super().__init__(in_channel=in_channel,
+        super().__init__(in_channel=in_channel, device=device, dtype=dtype)
         self.relative_l1_threshold = relative_l1_threshold
         self.step_count = 0
         self.num_inference_steps = 0
@@ -48,21 +47,20 @@ class FluxDiTFBCache(FluxDiT):

     def forward(
         self,
-        hidden_states,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-
-
-
-
-
-
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        image_emb: torch.Tensor | None = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_double_block_output: List[torch.Tensor] | None = None,
+        controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
     ):
-
-        if image_ids is None:
-            image_ids = self.prepare_image_ids(hidden_states)
+        image_seq_len = hidden_states.shape[1]
         controlnet_double_block_output = (
             controlnet_double_block_output if controlnet_double_block_output is not None else ()
         )
@@ -81,10 +79,10 @@ class FluxDiTFBCache(FluxDiT):
                 timestep,
                 prompt_emb,
                 pooled_prompt_emb,
-                image_emb,
-                guidance,
-                text_ids,
                 image_ids,
+                text_ids,
+                guidance,
+                image_emb,
                 *controlnet_double_block_output,
                 *controlnet_single_block_output,
             ),
@@ -101,7 +99,6 @@ class FluxDiTFBCache(FluxDiT):
         rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
         text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
         image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
-        hidden_states = self.patchify(hidden_states)

         with sequence_parallel(
             (
@@ -127,11 +124,13 @@ class FluxDiTFBCache(FluxDiT):

             # first block
             original_hidden_states = hidden_states
-            hidden_states, prompt_emb = self.blocks[0](
+            hidden_states, prompt_emb = self.blocks[0](
+                hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+            )
             first_hidden_states_residual = hidden_states - original_hidden_states

             (first_hidden_states_residual,) = sequence_parallel_unshard(
-                (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(
+                (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(image_seq_len,)
             )

             if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
@@ -152,14 +151,16 @@ class FluxDiTFBCache(FluxDiT):

             first_hidden_states = hidden_states.clone()
             for i, block in enumerate(self.blocks[1:]):
-                hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb = block(
+                    hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs
+                )
                 if len(controlnet_double_block_output) > 0:
                     interval_control = len(self.blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
                     hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
             hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
             for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs)
                 if len(controlnet_single_block_output) > 0:
                     interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
                     interval_control = int(np.ceil(interval_control))
@@ -172,9 +173,8 @@ class FluxDiTFBCache(FluxDiT):

             hidden_states = self.final_norm_out(hidden_states, conditioning)
             hidden_states = self.final_proj_out(hidden_states)
-            (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(
+            (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

-        hidden_states = self.unpatchify(hidden_states, h, w)
         (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

         return hidden_states
@@ -186,14 +186,12 @@ class FluxDiTFBCache(FluxDiT):
        device: str,
        dtype: torch.dtype,
        in_channel: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
        relative_l1_threshold: float = 0.05,
    ):
        model = cls(
            device="meta",
            dtype=dtype,
            in_channel=in_channel,
-            attn_kwargs=attn_kwargs,
            relative_l1_threshold=relative_l1_threshold,
        )
        model = model.requires_grad_(False)
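Beyond the same `attn_kwargs` and patchify changes inherited from FluxDiT, the FBCache hunks keep the first-block-cache bookkeeping (`relative_l1_threshold`, `step_count`, `first_hidden_states_residual`) and now unshard the residual using the precomputed `image_seq_len`. As a rough illustration of the caching idea behind the class: the first transformer block's residual is compared against the previous step's residual with a relative L1 metric, and when the change falls below the threshold the remaining blocks are skipped and the cached output residual is reused. The helper below is a hedged sketch of that decision only, not the class's actual implementation.

```python
# Hedged sketch of the relative-L1 "should we recompute?" test used by
# first-block caching; function name and tensors are illustrative.
import torch


def is_relative_l1_below(prev: torch.Tensor, curr: torch.Tensor, threshold: float = 0.05) -> bool:
    # Mean absolute change of the first-block residual, normalized by its previous magnitude.
    rel_l1 = (curr - prev).abs().mean() / prev.abs().mean()
    return rel_l1.item() < threshold


prev_residual = torch.randn(1, 16, 64)
curr_residual = prev_residual + 0.01 * torch.randn(1, 16, 64)
print(is_relative_l1_below(prev_residual, curr_residual))  # likely True -> reuse the cache this step
```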
diffsynth_engine/models/flux/flux_ipadapter.py

@@ -2,7 +2,7 @@ import torch
 from einops import rearrange
 from torch import nn
 from PIL import Image
-from typing import
+from typing import Dict, List
 from functools import partial
 from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
@@ -18,7 +18,6 @@ class FluxIPAdapterAttention(nn.Module):
         dim: int = 3072,
         head_num: int = 24,
         scale: float = 1.0,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -28,12 +27,13 @@
         self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False)
         self.head_num = head_num
         self.scale = scale
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

-    def forward(self, query: torch.Tensor, image_emb: torch.Tensor):
+    def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None):
         key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num)
         value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num)
-        attn_out = attention(query, key, value, **self.attn_kwargs)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        attn_out = attention(query, key, value, **attn_kwargs)
         return self.scale * rearrange(attn_out, "b s h d -> b s (h d)")

     @classmethod
diffsynth_engine/models/flux/flux_vae.py

@@ -8,7 +8,7 @@ from diffsynth_engine.utils import logging

 logger = logging.get_logger(__name__)

-with open(FLUX_VAE_CONFIG_FILE, "r") as f:
+with open(FLUX_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
     config = json.load(f)


@@ -25,11 +25,29 @@ class FluxVAEStateDictConverter(VAEStateDictConverter):
             new_state_dict[name_] = param
         return new_state_dict

+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["rename_dict"]
+        new_state_dict = {}
+        for name, param in state_dict.items():
+            if name not in rename_dict:
+                continue
+            name_ = rename_dict[name]
+            if "transformer_blocks" in name_:
+                param = param.squeeze()
+            new_state_dict[name_] = param
+        return new_state_dict
+
     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         assert self.has_decoder or self.has_encoder, "Either decoder or encoder must be present"
-        if "decoder.
+        if "decoder.up.0.block.0.conv1.weight" in state_dict or "encoder.down.0.block.0.conv1.weight" in state_dict:
             state_dict = self._from_civitai(state_dict)
             logger.info("use civitai format state dict")
+        elif (
+            "decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict
+            or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict
+        ):
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
         else:
             logger.info("use diffsynth format state dict")
         return self._filter(state_dict)
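The new `convert` path sniffs the checkpoint layout from characteristic parameter names: keys like `decoder.up.0.block.0.conv1.weight` indicate what the converter logs as the civitai format, while `decoder.up_blocks.0.resnets.0.conv1.weight` style keys indicate a diffusers export, which is routed through the new `_from_diffusers` that maps names via `config["diffusers"]["rename_dict"]` and squeezes parameters landing in `transformer_blocks`. A hedged, standalone sketch of that key-sniffing idea (the marker keys mirror the diff; the helper itself is illustrative):

```python
# Sketch of format detection by probing for naming-convention marker keys.
from typing import Dict

import torch


def detect_format(state_dict: Dict[str, torch.Tensor]) -> str:
    # up.N.block.M naming -> what the converter logs as "civitai format"
    if "decoder.up.0.block.0.conv1.weight" in state_dict or "encoder.down.0.block.0.conv1.weight" in state_dict:
        return "civitai"
    # up_blocks.N.resnets.M naming -> diffusers export, handled by _from_diffusers
    if (
        "decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict
        or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict
    ):
        return "diffusers"
    # anything else falls through to the native diffsynth format
    return "diffsynth"
```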
diffsynth_engine/models/hunyuan3d/dino_image_encoder.py

@@ -2,7 +2,7 @@ import torch.nn as nn
 import torchvision.transforms as transforms
 import collections.abc
 import math
-from typing import Optional,
+from typing import Optional, Dict

 import torch
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
@@ -112,7 +112,9 @@ class Dinov2SelfAttention(nn.Module):
     def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool) -> None:
         super().__init__()
         if hidden_size % num_attention_heads != 0:
-            raise ValueError(
+            raise ValueError(
+                f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}."
+            )

         self.num_attention_heads = num_attention_heads
         self.attention_head_size = int(hidden_size / num_attention_heads)
diffsynth_engine/models/qwen_image/qwen2_5_vl.py

@@ -942,6 +942,8 @@ class Qwen2_5_VLModel(nn.Module):


 class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
+    _supports_parallelization = True
+
     def __init__(
         self,
         vision_config: Qwen2_5_VLVisionConfig,
@@ -1173,6 +1175,9 @@ class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):

         return position_ids, mrope_position_deltas

+    def get_fsdp_module_cls(self):
+        return {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer}
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
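`get_fsdp_module_cls` (added here and to `FluxDiT` above) returns the set of transformer-block classes that should act as FSDP wrapping units, and `_supports_parallelization` flags the model for the parallel utilities updated in `diffsynth_engine/utils/parallel.py`. Below is a hedged sketch of how such a per-class set can feed PyTorch's transformer auto-wrap policy; the `shard_model` helper is illustrative, is not taken from this package, and assumes `torch.distributed` has already been initialized.

```python
# Sketch: drive FSDP auto-wrapping from the block classes a model advertises.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def shard_model(model):
    # e.g. {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer} for the model above
    wrap_cls = model.get_fsdp_module_cls()
    policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=wrap_cls)
    # Each advertised block becomes its own FSDP unit; everything else stays in the root wrapper.
    return FSDP(model, auto_wrap_policy=policy)
```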
|