diffsynth-engine 0.6.1.dev14__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +6 -2
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +10 -6
- diffsynth_engine/configs/pipeline.py +17 -10
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +59 -20
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +7 -19
- diffsynth_engine/models/flux/flux_dit.py +27 -38
- diffsynth_engine/models/flux/flux_dit_fbcache.py +9 -7
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +28 -34
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/wan/wan_audio_encoder.py +0 -1
- diffsynth_engine/models/wan/wan_dit.py +64 -27
- diffsynth_engine/pipelines/base.py +36 -4
- diffsynth_engine/pipelines/flux_image.py +19 -17
- diffsynth_engine/pipelines/qwen_image.py +45 -36
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +4 -9
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/parallel.py +62 -29
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +45 -43
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev14.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from diffsynth_engine.utils.constants import (
     WAN2_2_DIT_TI2V_5B_CONFIG_FILE,
     WAN2_2_DIT_I2V_A14B_CONFIG_FILE,
     WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
+    WAN_DIT_KEYMAP_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -30,6 +31,9 @@ from diffsynth_engine.utils.parallel import (
 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
 
+with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f:
+    config = json.load(f)
+
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
     return x * (1 + scale) + shift
@@ -73,7 +77,7 @@ class SelfAttention(nn.Module):
         dim: int,
         num_heads: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -86,19 +90,25 @@ class SelfAttention(nn.Module):
         self.o = nn.Linear(dim, dim, device=device, dtype=dtype)
         self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
         self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None
 
-    def forward(self, x, freqs):
+    def forward(self, x, freqs, attn_kwargs=None):
         q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), self.v(x)
+        g = self.gate_compress(x) if self.gate_compress is not None else None
+
         num_heads = q.shape[2] // self.head_dim
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
+        g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         x = attention_ops.attention(
             q=rope_apply(q, freqs),
             k=rope_apply(k, freqs),
             v=v,
-            **self.attn_kwargs,
+            g=g,
+            **attn_kwargs,
         )
         x = x.flatten(2)
         return self.o(x)
@@ -111,7 +121,6 @@ class CrossAttention(nn.Module):
         num_heads: int,
         eps: float = 1e-6,
         has_image_input: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -130,9 +139,8 @@ class CrossAttention(nn.Module):
             self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype)
             self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype)
-        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None):
         if self.has_image_input:
             img = y[:, :-T5_TOKEN_NUM]
             ctx = y[:, -T5_TOKEN_NUM:]
@@ -144,12 +152,16 @@ class CrossAttention(nn.Module):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
 
-        x = attention(q, k, v, **self.attn_kwargs).flatten(2)
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        if attn_kwargs.get("attn_impl", None) == "vsa":
+            attn_kwargs = attn_kwargs.copy()
+            attn_kwargs["attn_impl"] = "sdpa"
+        x = attention(q, k, v, **attn_kwargs).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
+            y = attention(q, k_img, v_img, **attn_kwargs).flatten(2)
         x = x + y
         return self.o(x)
 
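
The `attn_kwargs` plumbing above can be read in isolation. Below is a small standalone sketch (not part of the package) of the fallback that `CrossAttention.forward` now applies when a caller requests the VSA implementation; the logic mirrors the hunk above.

```python
def cross_attn_kwargs(attn_kwargs=None):
    # Mirrors the new fallback: cross attention does not run under video sparse
    # attention, so a requested "vsa" impl is downgraded to "sdpa" on a copy,
    # leaving the caller's dict untouched.
    attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
    if attn_kwargs.get("attn_impl", None) == "vsa":
        attn_kwargs = attn_kwargs.copy()
        attn_kwargs["attn_impl"] = "sdpa"
    return attn_kwargs

assert cross_attn_kwargs({"attn_impl": "vsa"}) == {"attn_impl": "sdpa"}
assert cross_attn_kwargs(None) == {}
```
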
@@ -162,7 +174,7 @@ class DiTBlock(nn.Module):
         num_heads: int,
         ffn_dim: int,
         eps: float = 1e-6,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -170,9 +182,9 @@ class DiTBlock(nn.Module):
         self.dim = dim
         self.num_heads = num_heads
         self.ffn_dim = ffn_dim
-        self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
+        self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype)
         self.cross_attn = CrossAttention(
-            dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype
+            dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype
         )
         self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
         self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype)
@@ -184,14 +196,14 @@ class DiTBlock(nn.Module):
         )
         self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5)
 
-    def forward(self, x, context, t_mod, freqs):
+    def forward(self, x, context, t_mod, freqs, attn_kwargs=None):
         # msa: multi-head self-attention  mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1)
         ]
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
-        x = x + gate_msa * self.self_attn(input_x, freqs)
-        x = x + self.cross_attn(self.norm3(x), context)
+        x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs)
+        x = x + self.cross_attn(self.norm3(x), context, attn_kwargs)
         input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
         x = x + gate_mlp * self.ffn(input_x)
         return x
@@ -249,7 +261,26 @@ class Head(nn.Module):
 
 
 class WanDiTStateDictConverter(StateDictConverter):
+    def _from_diffusers(self, state_dict):
+        global_rename_dict = config["diffusers"]["global_rename_dict"]
+        rename_dict = config["diffusers"]["rename_dict"]
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            suffix = ""
+            suffix = ".weight" if name.endswith(".weight") else suffix
+            suffix = ".bias" if name.endswith(".bias") else suffix
+            prefix = name[: -len(suffix)] if suffix else name
+            if prefix in global_rename_dict:
+                state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param
+            if prefix.startswith("blocks."):
+                _, idx, middle = prefix.split(".", 2)
+                if middle in rename_dict:
+                    state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param
+        return state_dict_
+
     def convert(self, state_dict):
+        if "condition_embedder.time_proj.weight" in state_dict:
+            return self._from_diffusers(state_dict)
         return state_dict
 
 
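
The converter above is driven by the new `wan_dit_keymap.json` (loaded into the module-level `config` near the top of the file). The sketch below uses hypothetical keymap entries purely to illustrate the lookup; the real mapping lives in `diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json` and its contents are not shown in this diff.

```python
# Hypothetical entries; the actual wan_dit_keymap.json contents are not shown here.
keymap = {
    "diffusers": {
        "global_rename_dict": {"condition_embedder.time_proj": "time_projection.1"},
        "rename_dict": {"attn1.to_q": "self_attn.q"},
    }
}

# With entries like these, _from_diffusers would rewrite, for example,
#   "condition_embedder.time_proj.weight" -> "time_projection.1.weight"
#   "blocks.3.attn1.to_q.weight"          -> "blocks.3.self_attn.q.weight"
# and convert() takes this path whenever the diffusers-style key
# "condition_embedder.time_proj.weight" is present in the incoming state dict.
```
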
@@ -273,7 +304,7 @@ class WanDiT(PreTrainedModel):
         has_vae_feature: bool = False,
         fuse_image_latents: bool = False,
         flf_pos_emb: bool = False,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
+        use_vsa: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -307,7 +338,16 @@ class WanDiT(PreTrainedModel):
         )
         self.blocks = nn.ModuleList(
             [
-                DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype)
+                DiTBlock(
+                    has_clip_feature,
+                    dim,
+                    num_heads,
+                    ffn_dim,
+                    eps,
+                    use_vsa,
+                    device=device,
+                    dtype=dtype,
+                )
                 for _ in range(num_layers)
             ]
         )
@@ -344,6 +384,7 @@ class WanDiT(PreTrainedModel):
         timestep: torch.Tensor,
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         use_cfg = x.shape[0] > 1
@@ -376,7 +417,7 @@ class WanDiT(PreTrainedModel):
 
         with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)):
             for block in self.blocks:
-                x = block(x, context, t_mod, freqs)
+                x = block(x, context, t_mod, freqs, attn_kwargs)
             x = self.head(x, t)
         (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
@@ -409,12 +450,11 @@ class WanDiT(PreTrainedModel):
         config: Dict[str, Any],
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-        assign: bool = True,
+        use_vsa: bool = False,
     ):
-        model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs)
+        model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa)
         model = model.requires_grad_(False)
-        model.load_state_dict(state_dict, assign=assign)
+        model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
@@ -499,8 +539,5 @@ class WanDiT(PreTrainedModel):
         for block in self.blocks:
             block.compile(*args, **kwargs)
 
-
-
-
-    def get_fsdp_modules(self):
-        return ["blocks"]
+    def get_fsdp_module_cls(self):
+        return {DiTBlock}
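
For context on `get_fsdp_modules` -> `get_fsdp_module_cls`: returning a set of module classes matches the shape expected by PyTorch's class-based FSDP auto-wrap policy. The sketch below uses only the standard PyTorch API; how `diffsynth_engine.utils.parallel` actually consumes the new method is outside this hunk.

```python
# Sketch only: standard PyTorch class-keyed auto-wrapping for FSDP.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_auto_wrap_policy(dit):
    # dit.get_fsdp_module_cls() returns a set of block classes, e.g. {DiTBlock}.
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(dit.get_fsdp_module_cls()),
    )

# Inside an initialized process group the model would then be wrapped roughly as:
# sharded_dit = FSDP(dit, auto_wrap_policy=build_auto_wrap_policy(dit))
```
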
@@ -2,10 +2,18 @@ import os
 import torch
 import numpy as np
 from einops import rearrange
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional
 from PIL import Image
 
-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    BaseConfig,
+    BaseStateDicts,
+    LoraConfig,
+    AttnImpl,
+    SpargeAttentionParams,
+    VideoSparseAttentionParams,
+)
+from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
 from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
@@ -33,6 +41,7 @@ class BasePipeline:
         dtype=torch.float16,
     ):
         super().__init__()
+        self.config = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
@@ -48,7 +57,7 @@ class BasePipeline:
         raise NotImplementedError()
 
     @classmethod
-    def from_state_dict(cls, state_dicts: BaseStateDicts,
+    def from_state_dict(cls, state_dicts: BaseStateDicts, config: BaseConfig) -> "BasePipeline":
         raise NotImplementedError()
 
     def update_weights(self, state_dicts: BaseStateDicts) -> None:
@@ -70,7 +79,11 @@ class BasePipeline:
         lora_list: List[Tuple[str, Union[float, LoraConfig]]],
         fused: bool = True,
         save_original_weight: bool = False,
+        lora_converter: Optional[LoRAStateDictConverter] = None,
     ):
+        if not lora_converter:
+            lora_converter = self.lora_converter
+
         for lora_path, lora_item in lora_list:
             if isinstance(lora_item, float):
                 lora_scale = lora_item
@@ -86,7 +99,7 @@ class BasePipeline:
                 self.apply_scheduler_config(scheduler_config)
                 logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
 
-            lora_state_dict = self.lora_converter.convert(state_dict)
+            lora_state_dict = lora_converter.convert(state_dict)
             for model_name, state_dict in lora_state_dict.items():
                 model = getattr(self, model_name)
                 lora_args = []
@@ -256,6 +269,25 @@ class BasePipeline:
         )
         return init_latents, latents, sigmas, timesteps
 
+    def get_attn_kwargs(self, latents: torch.Tensor) -> Dict:
+        attn_kwargs = {"attn_impl": self.config.dit_attn_impl.value}
+        if isinstance(self.config.attn_params, SpargeAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.SPARGE
+            attn_kwargs.update(
+                {
+                    "smooth_k": self.config.attn_params.smooth_k,
+                    "simthreshd1": self.config.attn_params.simthreshd1,
+                    "cdfthreshd": self.config.attn_params.cdfthreshd,
+                    "pvthreshd": self.config.attn_params.pvthreshd,
+                }
+            )
+        elif isinstance(self.config.attn_params, VideoSparseAttentionParams):
+            assert self.config.dit_attn_impl == AttnImpl.VSA
+            attn_kwargs.update(
+                get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.config.attn_params.sparsity, device=self.device)
+            )
+        return attn_kwargs
+
     def eval(self):
         for model_name in self.model_names:
             model = getattr(self, model_name)
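
The helper above turns the pipeline config into the per-call `attn_kwargs` dict that the pipelines below forward into their DiT. A self-contained sketch of the SpargeAttention branch follows, using a stand-in dataclass rather than `diffsynth_engine.configs.SpargeAttentionParams`; the field values and the `"sparge"` string are placeholders, not library defaults.

```python
from dataclasses import dataclass


@dataclass
class SpargeParamsStub:  # hypothetical stand-in for SpargeAttentionParams
    smooth_k: bool = True
    simthreshd1: float = 0.5
    cdfthreshd: float = 0.95
    pvthreshd: int = 20


p = SpargeParamsStub()
attn_kwargs = {
    "attn_impl": "sparge",  # assumed value of AttnImpl.SPARGE
    "smooth_k": p.smooth_k,
    "simthreshd1": p.simthreshd1,
    "cdfthreshd": p.cdfthreshd,
    "pvthreshd": p.pvthreshd,
}
# For VideoSparseAttentionParams, get_attn_kwargs instead starts from
# {"attn_impl": "vsa"} and merges in the output of get_vsa_kwargs(...).
print(attn_kwargs)
```
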
@@ -17,7 +17,12 @@ from diffsynth_engine.models.flux import (
     flux_dit_config,
     flux_text_encoder_config,
 )
-from diffsynth_engine.configs import
+from diffsynth_engine.configs import (
+    FluxPipelineConfig,
+    FluxStateDicts,
+    ControlType,
+    ControlNetParams,
+)
 from diffsynth_engine.models.basic.lora import LoRAContext
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import accumulate, calculate_shift
@@ -143,7 +148,7 @@ class FluxLoRAConverter(LoRAStateDictConverter):
                 layer_id, layer_type = name.split("_", 1)
                 layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
                 rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
-
+
             lora_args = {}
             lora_args["alpha"] = param
             lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
@@ -507,29 +512,20 @@ class FluxImagePipeline(BasePipeline):
         vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype)
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = FluxDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = FluxDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
                     in_channel=config.control_type.get_in_channel(),
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -573,7 +569,7 @@ class FluxImagePipeline(BasePipeline):
         self.update_component(self.vae_encoder, state_dicts.vae, self.config.device, self.config.vae_dtype)
 
     def compile(self):
-        self.dit.compile_repeated_blocks(
+        self.dit.compile_repeated_blocks()
 
     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -755,6 +751,7 @@ class FluxImagePipeline(BasePipeline):
         latents = latents.to(self.dtype)
         self.load_models_to_device(["dit"])
 
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -766,6 +763,7 @@ class FluxImagePipeline(BasePipeline):
             image_ids=image_ids,
             controlnet_double_block_output=double_block_output,
             controlnet_single_block_output=single_block_output,
+            attn_kwargs=attn_kwargs,
         )
         noise_pred = noise_pred[:, :image_seq_len]
         noise_pred = self.dit.unpatchify(noise_pred, height, width)
@@ -830,7 +828,7 @@ class FluxImagePipeline(BasePipeline):
             masked_image = image.clone()
             masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
             latent = self.encode_image(masked_image)
-            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3])).to(latent.dtype)
             mask = 1 - mask
             latent = torch.cat([latent, mask], dim=1)
         elif self.config.control_type == ControlType.bfl_fill:
@@ -887,6 +885,8 @@ class FluxImagePipeline(BasePipeline):
             if self.offload_mode is not None:
                 empty_cache()
                 param.model.to(self.device)
+
+            attn_kwargs = self.get_attn_kwargs(latents)
             double_block_output, single_block_output = param.model(
                 hidden_states=latents,
                 control_condition=control_condition,
@@ -897,6 +897,7 @@ class FluxImagePipeline(BasePipeline):
                 image_ids=image_ids,
                 text_ids=text_ids,
                 guidance=guidance,
+                attn_kwargs=attn_kwargs,
             )
             if self.offload_mode is not None:
                 param.model.to("cpu")
@@ -983,8 +984,9 @@ class FluxImagePipeline(BasePipeline):
         elif self.ip_adapter is not None:
             image_emb = self.ip_adapter.encode_image(ref_image)
         elif self.redux is not None:
-
-
+            ref_prompt_embeds = self.redux(ref_image)
+            flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1))
+            positive_prompt_emb = torch.cat([positive_prompt_emb, flattened_ref_emb], dim=1)
 
         # Extra input
         image_ids, text_ids, guidance = self.prepare_extra_input(
@@ -24,7 +24,7 @@ from diffsynth_engine.models.qwen_image import (
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -91,7 +91,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             if "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
                 lora_b_suffix = "lora_B.weight"
-
+
             if lora_a_suffix is None:
                 continue
 
@@ -147,9 +147,18 @@ class QwenImagePipeline(BasePipeline):
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
-        self.
+        self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode =
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )
 
         self.edit_prompt_template_encode_start_idx = 64
 
@@ -185,6 +194,7 @@ class QwenImagePipeline(BasePipeline):
         logger.info(f"loading state dict from {config.vae_path} ...")
         vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
 
+        encoder_state_dict = None
         if config.encoder_path is None:
            config.encoder_path = fetch_model(
                 "MusePublic/Qwen-image",
@@ -196,8 +206,11 @@ class QwenImagePipeline(BasePipeline):
                     "text_encoder/model-00004-of-00004.safetensors",
                 ],
             )
-        logger.info(f"loading state dict from {config.encoder_path} ...")
-        encoder_state_dict = cls.load_model_checkpoint(config.encoder_path, device="cpu", dtype=config.encoder_dtype)
+        if config.load_encoder:
+            logger.info(f"loading state dict from {config.encoder_path} ...")
+            encoder_state_dict = cls.load_model_checkpoint(
+                config.encoder_path, device="cpu", dtype=config.encoder_dtype
+            )
 
         state_dicts = QwenImageStateDicts(
             model=model_state_dict,
|
@@ -224,22 +237,25 @@ class QwenImagePipeline(BasePipeline):
|
|
|
224
237
|
@classmethod
|
|
225
238
|
def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
|
|
226
239
|
init_device = "cpu" if config.offload_mode is not None else config.device
|
|
227
|
-
tokenizer =
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
240
|
+
tokenizer, processor, encoder = None, None, None
|
|
241
|
+
if config.load_encoder:
|
|
242
|
+
tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
|
|
243
|
+
processor = Qwen2VLProcessor.from_pretrained(
|
|
244
|
+
tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
|
|
245
|
+
image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
|
|
246
|
+
)
|
|
247
|
+
with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r", encoding="utf-8") as f:
|
|
248
|
+
vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
|
|
249
|
+
with open(QWEN_IMAGE_CONFIG_FILE, "r", encoding="utf-8") as f:
|
|
250
|
+
text_config = Qwen2_5_VLConfig(**json.load(f))
|
|
251
|
+
encoder = Qwen2_5_VLForConditionalGeneration.from_state_dict(
|
|
252
|
+
state_dicts.encoder,
|
|
253
|
+
vision_config=vision_config,
|
|
254
|
+
config=text_config,
|
|
255
|
+
device=("cpu" if config.use_fsdp else init_device),
|
|
256
|
+
dtype=config.encoder_dtype,
|
|
257
|
+
)
|
|
258
|
+
|
|
243
259
|
with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
|
|
244
260
|
vae_config = json.load(f)
|
|
245
261
|
vae = QwenImageVAE.from_state_dict(
|
|
@@ -247,27 +263,18 @@ class QwenImagePipeline(BasePipeline):
         )
 
         with LoRAContext():
-            attn_kwargs = {
-                "attn_impl": config.dit_attn_impl.value,
-                "sparge_smooth_k": config.sparge_smooth_k,
-                "sparge_cdfthreshd": config.sparge_cdfthreshd,
-                "sparge_simthreshd1": config.sparge_simthreshd1,
-                "sparge_pvthreshd": config.sparge_pvthreshd,
-            }
             if config.use_fbcache:
                 dit = QwenImageDiTFBCache.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                     relative_l1_threshold=config.fbcache_relative_l1_threshold,
                 )
             else:
                 dit = QwenImageDiT.from_state_dict(
                     state_dicts.model,
-                    device=init_device,
+                    device=("cpu" if config.use_fsdp else init_device),
                     dtype=config.model_dtype,
-                    attn_kwargs=attn_kwargs,
                 )
             if config.use_fp8_linear:
                 enable_fp8_linear(dit)
@@ -307,7 +314,7 @@ class QwenImagePipeline(BasePipeline):
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
 
     def compile(self):
-        self.dit.compile_repeated_blocks(
+        self.dit.compile_repeated_blocks()
 
     def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
         assert self.config.tp_degree is None or self.config.tp_degree == 1, (
@@ -493,8 +500,8 @@ class QwenImagePipeline(BasePipeline):
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_emb_mask = torch.cat([prompt_emb_mask, negative_prompt_emb_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)
@@ -542,6 +549,7 @@ class QwenImagePipeline(BasePipeline):
         entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         self.load_models_to_device(["dit"])
+        attn_kwargs = self.get_attn_kwargs(latents)
         noise_pred = self.dit(
             image=latents,
             edit=image_latents,
@@ -552,6 +560,7 @@ class QwenImagePipeline(BasePipeline):
             entity_text=entity_prompt_embs,
             entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None,
             entity_masks=entity_masks,
+            attn_kwargs=attn_kwargs,
         )
         return noise_pred
 
@@ -181,7 +181,7 @@ class SDXLImagePipeline(BasePipeline):
 
     @classmethod
     def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
-        init_device = "cpu" if config.offload_mode else config.device
+        init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+        - b: batch dimension
+        - s: sequence dimension (may differ)
+        - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)