diffsynth-engine 0.5.1.dev4__py3-none-any.whl → 0.6.1.dev25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +12 -0
- diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py +19 -0
- diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py +22 -6
- diffsynth_engine/conf/models/flux/flux_dit.json +20 -1
- diffsynth_engine/conf/models/flux/flux_vae.json +253 -5
- diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json +41 -0
- diffsynth_engine/configs/__init__.py +16 -1
- diffsynth_engine/configs/controlnet.py +13 -0
- diffsynth_engine/configs/pipeline.py +37 -11
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/attention.py +105 -43
- diffsynth_engine/models/basic/transformer_helper.py +36 -2
- diffsynth_engine/models/basic/video_sparse_attention.py +238 -0
- diffsynth_engine/models/flux/flux_controlnet.py +16 -30
- diffsynth_engine/models/flux/flux_dit.py +49 -62
- diffsynth_engine/models/flux/flux_dit_fbcache.py +26 -28
- diffsynth_engine/models/flux/flux_ipadapter.py +5 -5
- diffsynth_engine/models/flux/flux_text_encoder.py +1 -1
- diffsynth_engine/models/flux/flux_vae.py +20 -2
- diffsynth_engine/models/hunyuan3d/dino_image_encoder.py +4 -2
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +5 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +151 -58
- diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +14 -6
- diffsynth_engine/models/qwen_image/qwen_image_vae.py +1 -1
- diffsynth_engine/models/sd/sd_text_encoder.py +1 -1
- diffsynth_engine/models/sd/sd_unet.py +1 -1
- diffsynth_engine/models/sd3/sd3_dit.py +1 -1
- diffsynth_engine/models/sd3/sd3_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_text_encoder.py +1 -1
- diffsynth_engine/models/sdxl/sdxl_unet.py +1 -1
- diffsynth_engine/models/vae/vae.py +1 -1
- diffsynth_engine/models/wan/wan_audio_encoder.py +6 -3
- diffsynth_engine/models/wan/wan_dit.py +65 -28
- diffsynth_engine/models/wan/wan_s2v_dit.py +1 -1
- diffsynth_engine/models/wan/wan_text_encoder.py +13 -13
- diffsynth_engine/models/wan/wan_vae.py +2 -2
- diffsynth_engine/pipelines/base.py +73 -7
- diffsynth_engine/pipelines/flux_image.py +139 -120
- diffsynth_engine/pipelines/hunyuan3d_shape.py +4 -0
- diffsynth_engine/pipelines/qwen_image.py +272 -87
- diffsynth_engine/pipelines/sdxl_image.py +1 -1
- diffsynth_engine/pipelines/utils.py +52 -0
- diffsynth_engine/pipelines/wan_s2v.py +25 -14
- diffsynth_engine/pipelines/wan_video.py +43 -19
- diffsynth_engine/tokenizers/base.py +6 -0
- diffsynth_engine/tokenizers/qwen2.py +12 -4
- diffsynth_engine/utils/constants.py +13 -12
- diffsynth_engine/utils/download.py +4 -2
- diffsynth_engine/utils/env.py +2 -0
- diffsynth_engine/utils/flag.py +6 -0
- diffsynth_engine/utils/loader.py +25 -6
- diffsynth_engine/utils/parallel.py +62 -29
- diffsynth_engine/utils/video.py +3 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/RECORD +69 -67
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-flf2v-14b.json → wan2.1_flf2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-i2v-14b.json → wan2.1_i2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-1.3b.json → wan2.1_t2v_1.3b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.1-t2v-14b.json → wan2.1_t2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-i2v-a14b.json → wan2.2_i2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-s2v-14b.json → wan2.2_s2v_14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-t2v-a14b.json → wan2.2_t2v_a14b.json} +0 -0
- /diffsynth_engine/conf/models/wan/dit/{wan2.2-ti2v-5b.json → wan2.2_ti2v_5b.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.1-vae.json → wan2.1_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan2.2-vae.json → wan2.2_vae.json} +0 -0
- /diffsynth_engine/conf/models/wan/vae/{wan-vae-keymap.json → wan_vae_keymap.json} +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev4.dist-info → diffsynth_engine-0.6.1.dev25.dist-info}/top_level.txt +0 -0
diffsynth_engine/models/basic/attention.py

@@ -12,19 +12,15 @@ from diffsynth_engine.utils.flag import (
     SDPA_AVAILABLE,
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
+    VIDEO_SPARSE_ATTN_AVAILABLE,
 )
+from diffsynth_engine.utils.platform import DTYPE_FP8
 
 FA3_MAX_HEADDIM = 256
 
 logger = logging.get_logger(__name__)
 
 
-def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
-    padding_size = (alignment - x.shape[dim] % alignment) % alignment
-    padded_x = F.pad(x, (0, padding_size), "constant", 0)
-    return padded_x[..., : x.shape[dim]]
-
-
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:

@@ -32,6 +28,11 @@ if FLASH_ATTN_2_AVAILABLE:
 if XFORMERS_AVAILABLE:
     from xformers.ops import memory_efficient_attention
 
+def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
+    padding_size = (alignment - x.shape[dim] % alignment) % alignment
+    padded_x = F.pad(x, (0, padding_size), "constant", 0)
+    return padded_x[..., : x.shape[dim]]
+
 def xformers_attn(q, k, v, attn_mask=None, scale=None):
     if attn_mask is not None:
         if attn_mask.ndim == 2:
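For context, `memory_align` (relocated below the backend imports, behavior unchanged) pads the last dimension up to a multiple of `alignment` and then slices back to the original width, so the result keeps the original shape but sits in alignment-friendly padded storage. A quick self-contained check (the function body is copied from the hunk above; the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8):
    padding_size = (alignment - x.shape[dim] % alignment) % alignment
    padded_x = F.pad(x, (0, padding_size), "constant", 0)
    return padded_x[..., : x.shape[dim]]

x = torch.randn(2, 5, 70)         # last dim 70 is not a multiple of 8
y = memory_align(x)
print(y.shape)                    # torch.Size([2, 5, 70]) -- logical shape unchanged
print(x.stride(), y.stride())     # (350, 70, 1) vs (360, 72, 1): rows now start on 8-element boundaries
print(torch.equal(x, y))          # True -- values are untouched
```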
@@ -93,6 +94,13 @@ if SPARGE_ATTN_AVAILABLE:
     return out.transpose(1, 2)
 
 
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from diffsynth_engine.models.basic.video_sparse_attention import (
+        video_sparse_attn,
+        distributed_video_sparse_attn,
+    )
+
+
 def eager_attn(q, k, v, attn_mask=None, scale=None):
     q = q.transpose(1, 2)
     k = k.transpose(1, 2)

@@ -108,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None):
 
 
 def attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = "auto",
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,

@@ -125,12 +134,14 @@ def attention(
         None,
         "auto",
         "eager",
-        "
-        "
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "xformers",
         "sdpa",
-        "
-        "
+        "sage",
+        "sparge",
+        "vsa",
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":

@@ -139,9 +150,13 @@
             return flash_attn3(q, k, v, softmax_scale=scale)
         else:
             if not flash_attn3_compatible:
-                logger.warning(
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
             else:
-                logger.debug(
+                logger.debug(
+                    "flash_attn_3 does not support attention mask, will use fallback attention implementation"
+                )
     if XFORMERS_AVAILABLE:
         return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     if SDPA_AVAILABLE:

@@ -152,33 +167,55 @@
     else:
         if attn_impl == "eager":
             return eager_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
             if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
             if attn_mask is not None:
                 raise RuntimeError("flash_attn_3 does not support attention mask")
-
-
+            if attn_impl == "fa3":
+                return flash_attn3(q, k, v, softmax_scale=scale)
+            else:
+                origin_dtype = q.dtype
+                q = q.to(dtype=DTYPE_FP8)
+                k = k.to(dtype=DTYPE_FP8)
+                v = v.to(dtype=DTYPE_FP8)
+                out = flash_attn3(q, k, v, softmax_scale=scale)
+                return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return flash_attn2(q, k, v, softmax_scale=scale)
         if attn_impl == "xformers":
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if attn_impl == "sdpa":
             return sdpa_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "
+        if attn_impl == "sage":
             return sage_attn(q, k, v, attn_mask=attn_mask, scale=scale)
-        if attn_impl == "
+        if attn_impl == "sparge":
             return sparge_attn(
                 q,
                 k,
                 v,
                 attn_mask=attn_mask,
                 scale=scale,
-                smooth_k=kwargs.get("
-                simthreshd1=kwargs.get("
-                cdfthreshd=kwargs.get("
-                pvthreshd=kwargs.get("
+                smooth_k=kwargs.get("smooth_k", True),
+                simthreshd1=kwargs.get("simthreshd1", 0.6),
+                cdfthreshd=kwargs.get("cdfthreshd", 0.98),
+                pvthreshd=kwargs.get("pvthreshd", 50),
+            )
+        if attn_impl == "vsa":
+            return video_sparse_attn(
+                q,
+                k,
+                v,
+                g,
+                sparsity=kwargs.get("sparsity"),
+                num_tiles=kwargs.get("num_tiles"),
+                total_seq_length=kwargs.get("total_seq_length"),
+                tile_partition_indices=kwargs.get("tile_partition_indices"),
+                reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+                variable_block_sizes=kwargs.get("variable_block_sizes"),
+                non_pad_index=kwargs.get("non_pad_index"),
             )
     raise ValueError(f"Invalid attention implementation: {attn_impl}")
 
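The new `fa3_fp8` branch is a cast-and-restore wrapper around FlashAttention-3: inputs are downcast to `DTYPE_FP8` for the kernel call and the output is cast back to the caller's dtype. A minimal sketch of the same pattern, assuming `torch.float8_e4m3fn` as the FP8 dtype and using a toy SDPA-based stand-in for the real `flash_attn3` kernel (the helper names here are illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F


def fp8_cast_attention(q, k, v, attn_fn, scale=None, fp8_dtype=torch.float8_e4m3fn):
    """Cast-and-restore wrapper, mirroring the new fa3_fp8 branch above."""
    origin_dtype = q.dtype                                 # remember the caller's dtype (e.g. bf16)
    q, k, v = (t.to(dtype=fp8_dtype) for t in (q, k, v))   # FP8 inputs for the kernel
    out = attn_fn(q, k, v, softmax_scale=scale)
    return out.to(dtype=origin_dtype)                      # hand activations back in the original dtype


def _toy_kernel(q, k, v, softmax_scale=None):
    """Stand-in for flash_attn3: upcasts and runs SDPA so the sketch runs anywhere."""
    q, k, v = (t.to(torch.bfloat16).transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v, scale=softmax_scale).transpose(1, 2)


q = torch.randn(1, 16, 4, 64, dtype=torch.bfloat16)  # [batch, seq, heads, head_dim]
out = fp8_cast_attention(q, q, q, _toy_kernel)
assert out.dtype == torch.bfloat16
```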
@@ -228,9 +265,10 @@ class Attention(nn.Module):
 
 
 def long_context_attention(
-    q,
-    k,
-    v,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
     attn_impl: Optional[str] = None,
     attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,

@@ -247,12 +285,15 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "
-        "
+        "fa2",
+        "fa3",
+        "fa3_fp8",
         "sdpa",
-        "
-        "
+        "sage",
+        "sparge",
+        "vsa",
     ]
+    assert attn_mask is None, "long context attention does not support attention mask"
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
         if FLASH_ATTN_3_AVAILABLE:

@@ -268,27 +309,48 @@ def long_context_attention(
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         raise ValueError("No available long context attention implementation")
     else:
-        if attn_impl == "
-            if flash_attn3_compatible:
-                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
-            else:
+        if attn_impl == "fa3" or attn_impl == "fa3_fp8":
+            if not flash_attn3_compatible:
                 raise RuntimeError(
                     f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
                 )
-
+            if attn_impl == "fa3":
+                return LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
+        if attn_impl == "fa2":
             return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
         if attn_impl == "sdpa":
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
-        if attn_impl == "
-            return LongContextAttention(attn_type=AttnType.
-        if attn_impl == "
+        if attn_impl == "sage":
+            return LongContextAttention(attn_type=AttnType.SAGE_AUTO)(q, k, v, softmax_scale=scale)
+        if attn_impl == "sparge":
             attn_processor = SparseAttentionMeansim()
             # default args from spas_sage2_attn_meansim_cuda
-            attn_processor.smooth_k = torch.tensor(kwargs.get("
-            attn_processor.simthreshd1 = torch.tensor(kwargs.get("
-            attn_processor.cdfthreshd = torch.tensor(kwargs.get("
-            attn_processor.pvthreshd = torch.tensor(kwargs.get("
+            attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True))
+            attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6))
+            attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98))
+            attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50))
             return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)(
                 q, k, v, softmax_scale=scale
             )
+        if attn_impl == "vsa":
+            return distributed_video_sparse_attn(
+                q,
+                k,
+                v,
+                g,
+                sparsity=kwargs.get("sparsity"),
+                num_tiles=kwargs.get("num_tiles"),
+                total_seq_length=kwargs.get("total_seq_length"),
+                tile_partition_indices=kwargs.get("tile_partition_indices"),
+                reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"),
+                variable_block_sizes=kwargs.get("variable_block_sizes"),
+                non_pad_index=kwargs.get("non_pad_index"),
+            )
     raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
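Callers now select backends by the short names above and pass tuning parameters through `**kwargs`. A hedged usage sketch, assuming the `[batch, seq_len, num_heads, head_dim]` layout used by the kernels in this file, a CUDA device, and that the optional kernels (e.g. SpargeAttn) are installed:

```python
import torch
from diffsynth_engine.models.basic.attention import attention

# Assumed layout: [batch, seq_len, num_heads, head_dim]
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# "auto" picks the best available kernel; explicit choices are now
# "eager", "fa2", "fa3", "fa3_fp8", "xformers", "sdpa", "sage", "sparge", "vsa".
out = attention(q, k, v, attn_impl="sdpa")

# Sparse-attention tuning knobs travel through **kwargs.
out = attention(q, k, v, attn_impl="sparge", simthreshd1=0.6, cdfthreshd=0.98)
```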
diffsynth_engine/models/basic/transformer_helper.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math
 
 

@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
 
-    def forward(self,
-        return 0.5 *
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
 
 class ApproximateGELU(nn.Module):

@@ -115,3 +116,36 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x
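The added `GELU` module is a linear projection followed by `F.gelu`, with `approximate="tanh"` switching to the tanh approximation. A small usage sketch (the dimensions, CPU device, and float32 dtype are illustrative overrides of the defaults):

```python
import torch
from diffsynth_engine.models.basic.transformer_helper import GELU

# Project 3072 -> 12288 channels, then apply tanh-approximated GELU.
ff_in = GELU(dim_in=3072, dim_out=12288, approximate="tanh", device="cpu", dtype=torch.float32)
x = torch.randn(1, 16, 3072)
print(ff_in(x).shape)  # torch.Size([1, 16, 12288])
```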
diffsynth_engine/models/basic/video_sparse_attention.py

@@ -0,0 +1,238 @@
+import torch
+import math
+import functools
+
+from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
+from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    from vsa import video_sparse_attn as vsa_core
+
+VSA_TILE_SIZE = (4, 4, 4)
+
+
+@functools.lru_cache(maxsize=10)
+def get_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    T, H, W = dit_seq_shape
+    ts, hs, ws = tile_size
+    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
+    ls = []
+    for t in range(math.ceil(T / ts)):
+        for h in range(math.ceil(H / hs)):
+            for w in range(math.ceil(W / ws)):
+                ls.append(
+                    indices[
+                        t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W)
+                    ].flatten()
+                )
+    index = torch.cat(ls, dim=0)
+    return index
+
+
+@functools.lru_cache(maxsize=10)
+def get_reverse_tile_partition_indices(
+    dit_seq_shape: tuple[int, int, int],
+    tile_size: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
+
+
+@functools.lru_cache(maxsize=10)
+def construct_variable_block_sizes(
+    dit_seq_shape: tuple[int, int, int],
+    num_tiles: tuple[int, int, int],
+    device: torch.device,
+) -> torch.LongTensor:
+    """
+    Compute the number of valid (non-padded) tokens inside every
+    (ts_t x ts_h x ts_w) tile after padding -- flattened in the order
+    (t-tile, h-tile, w-tile) that `rearrange` uses.
+
+    Returns
+    -------
+    torch.LongTensor  # shape: [∏ full_window_size]
+    """
+    # unpack
+    t, h, w = dit_seq_shape
+    ts_t, ts_h, ts_w = VSA_TILE_SIZE
+    n_t, n_h, n_w = num_tiles
+
+    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
+        """Vector with the size of each tile along one dimension."""
+        sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device)
+        # size of last (possibly partial) tile
+        remainder = dim_len - (n_tiles - 1) * tile
+        sizes[-1] = remainder if remainder > 0 else tile
+        return sizes
+
+    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
+    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
+    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]
+
+    # broadcast‑multiply to get voxels per tile, then flatten
+    block_sizes = (
+        t_sizes[:, None, None]  # [n_t, 1, 1]
+        * h_sizes[None, :, None]  # [1, n_h, 1]
+        * w_sizes[None, None, :]  # [1, 1, n_w]
+    ).reshape(-1)  # [n_t * n_h * n_w]
+
+    return block_sizes
+
+
+@functools.lru_cache(maxsize=10)
+def get_non_pad_index(
+    variable_block_sizes: torch.LongTensor,
+    max_block_size: int,
+):
+    n_win = variable_block_sizes.shape[0]
+    device = variable_block_sizes.device
+    starts_pad = torch.arange(n_win, device=device) * max_block_size
+    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
+    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
+    return index_pad[index_mask]
+
+
+def get_vsa_kwargs(
+    latent_shape: tuple[int, int, int],
+    patch_size: tuple[int, int, int],
+    sparsity: float,
+    device: torch.device,
+):
+    dit_seq_shape = (
+        latent_shape[0] // patch_size[0],
+        latent_shape[1] // patch_size[1],
+        latent_shape[2] // patch_size[2],
+    )
+
+    num_tiles = (
+        math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
+        math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
+        math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]),
+    )
+    total_seq_length = math.prod(dit_seq_shape)
+
+    tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
+    variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
+    non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
+
+    return {
+        "sparsity": sparsity,
+        "num_tiles": num_tiles,
+        "total_seq_length": total_seq_length,
+        "tile_partition_indices": tile_partition_indices,
+        "reverse_tile_partition_indices": reverse_tile_partition_indices,
+        "variable_block_sizes": variable_block_sizes,
+        "non_pad_index": non_pad_index,
+    }
+
+
+def tile(
+    x: torch.Tensor,
+    num_tiles: tuple[int, int, int],
+    tile_partition_indices: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+) -> torch.Tensor:
+    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
+    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
+    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
+
+    x_padded = torch.zeros(
+        (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
+        device=x.device,
+        dtype=x.dtype,
+    )
+    x_padded[:, non_pad_index] = x[:, tile_partition_indices]
+    return x_padded
+
+
+def untile(
+    x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor
+) -> torch.Tensor:
+    x = x[:, non_pad_index][:, reverse_tile_partition_indices]
+    return x
+
+
+def video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+):
+    q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
+    k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
+    v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
+    g = tile(g, num_tiles, tile_partition_indices, non_pad_index)
+
+    q = q.transpose(1, 2).contiguous()
+    k = k.transpose(1, 2).contiguous()
+    v = v.transpose(1, 2).contiguous()
+    g = g.transpose(1, 2).contiguous()
+
+    topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE)))
+    out = vsa_core(
+        q,
+        k,
+        v,
+        variable_block_sizes=variable_block_sizes,
+        topk=topk,
+        block_size=VSA_TILE_SIZE,
+        compress_attn_weight=g,
+    ).transpose(1, 2)
+    out = untile(out, reverse_tile_partition_indices, non_pad_index)
+    return out
+
+
+def distributed_video_sparse_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    sparsity: float,
+    num_tiles: tuple[int, int, int],
+    total_seq_length: int,
+    tile_partition_indices: torch.LongTensor,
+    reverse_tile_partition_indices: torch.LongTensor,
+    variable_block_sizes: torch.LongTensor,
+    non_pad_index: torch.LongTensor,
+    scatter_idx: int = 2,
+    gather_idx: int = 1,
+):
+    from yunchang.comm.all_to_all import SeqAllToAll4D
+
+    assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+    sp_ulysses_group = get_sp_ulysses_group()
+
+    q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
+    k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx)
+    v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx)
+    g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx)
+
+    out = video_sparse_attn(
+        q,
+        k,
+        v,
+        g,
+        sparsity,
+        num_tiles,
+        total_seq_length,
+        tile_partition_indices,
+        reverse_tile_partition_indices,
+        variable_block_sizes,
+        non_pad_index,
+    )
+
+    out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx)
+    return out
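Tying the new module to the attention entry points above: `get_vsa_kwargs` precomputes the tile bookkeeping for a latent grid once, and the resulting dict is what the `attn_impl="vsa"` branches forward into `video_sparse_attn`. A hedged sketch (latent/patch sizes, head count, and the random gate tensor `g` are illustrative; it needs a CUDA device and the `vsa` kernel package installed):

```python
import torch
from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs

device = torch.device("cuda")
latent_shape, patch_size = (16, 60, 104), (1, 2, 2)      # illustrative Wan-style sizes
vsa_kwargs = get_vsa_kwargs(latent_shape, patch_size, sparsity=0.9, device=device)

seq_len = vsa_kwargs["total_seq_length"]                 # 16 * 30 * 52 tokens after patchify
q = torch.randn(1, seq_len, 12, 128, device=device, dtype=torch.bfloat16)
k, v, g = torch.randn_like(q), torch.randn_like(q), torch.randn_like(q)

# g feeds vsa_core's compress_attn_weight; the tiling metadata rides along as kwargs.
out = attention(q, k, v, g, attn_impl="vsa", **vsa_kwargs)
```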
diffsynth_engine/models/flux/flux_controlnet.py

@@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel):
     def __init__(
         self,
         condition_channels: int = 64,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):

@@ -103,10 +102,7 @@ class FluxControlNet(PreTrainedModel):
         self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
         self.controlnet_x_embedder = nn.Linear(condition_channels, 3072)
         self.blocks = nn.ModuleList(
-            [
-                FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
-                for _ in range(6)
-            ]
+            [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)]
         )
         # controlnet projection
         self.blocks_proj = nn.ModuleList(

@@ -119,18 +115,17 @@ class FluxControlNet(PreTrainedModel):
 
     def forward(
         self,
-        hidden_states,
-        control_condition,
-        control_scale,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-
-
-
+        hidden_states: torch.Tensor,
+        control_condition: torch.Tensor,
+        control_scale: float,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        hidden_states = self.patchify(hidden_states)
-        control_condition = self.patchify(control_condition)
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
             self.time_embedder(timestep, hidden_states.dtype)

@@ -143,7 +138,9 @@ class FluxControlNet(PreTrainedModel):
         # double block
         double_block_outputs = []
         for i, block in enumerate(self.blocks):
-            hidden_states, prompt_emb = block(
+            hidden_states, prompt_emb = block(
+                hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs
+            )
             double_block_outputs.append(self.blocks_proj[i](hidden_states))
 
         # apply control scale

@@ -151,24 +148,13 @@ class FluxControlNet(PreTrainedModel):
         return double_block_outputs, None
 
     @classmethod
-    def from_state_dict(
-        cls,
-        state_dict: Dict[str, torch.Tensor],
-        device: str,
-        dtype: torch.dtype,
-        attn_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
         if "controlnet_x_embedder.weight" in state_dict:
             condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1]
         else:
             condition_channels = 64
 
-        model = cls(
-            condition_channels=condition_channels,
-            attn_kwargs=attn_kwargs,
-            device="meta",
-            dtype=dtype,
-        )
+        model = cls(condition_channels=condition_channels, device="meta", dtype=dtype)
         model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)