diffsynth-engine 0.6.1.dev25__py3-none-any.whl → 0.6.1.dev26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/configs/pipeline.py +2 -0
- diffsynth_engine/models/basic/attention.py +52 -0
- diffsynth_engine/utils/flag.py +5 -0
- {diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/RECORD +8 -8
- {diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/pipeline.py
CHANGED
@@ -26,6 +26,8 @@ class AttnImpl(Enum):
     FA2 = "fa2"  # Flash Attention 2
     FA3 = "fa3"  # Flash Attention 3
     FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    AITER = "aiter"  # Aiter Flash Attention
+    AITER_FP8 = "aiter_fp8"  # Aiter Flash Attention with FP8
     XFORMERS = "xformers"  # XFormers
     SDPA = "sdpa"  # Scaled Dot Product Attention
     SAGE = "sage"  # Sage Attention
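Since AttnImpl is a plain value-backed Enum, the new backends can be selected straight from a config string; a minimal sketch (assuming the wheel is installed, everything beyond the import is hypothetical usage):

```python
from diffsynth_engine.configs.pipeline import AttnImpl

impl = AttnImpl("aiter_fp8")   # value-based lookup on the Enum
assert impl is AttnImpl.AITER_FP8
print(impl.value)              # -> "aiter_fp8"
```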
diffsynth_engine/models/basic/attention.py
CHANGED
@@ -13,6 +13,7 @@ from diffsynth_engine.utils.flag import (
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
     VIDEO_SPARSE_ATTN_AVAILABLE,
+    AITER_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -93,6 +94,9 @@ if SPARGE_ATTN_AVAILABLE:
     )
     return out.transpose(1, 2)
 
+if AITER_AVAILABLE:
+    from aiter import flash_attn_func as aiter_flash_attn
+    from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
 
 if VIDEO_SPARSE_ATTN_AVAILABLE:
     from diffsynth_engine.models.basic.video_sparse_attention import (
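A hedged sketch of exercising the guarded import above, assuming aiter mirrors the flash-attn tensor layout (batch, seq_len, num_heads, head_dim) and a ROCm-capable device; check the aiter documentation for the exact contract:

```python
import torch

if AITER_AVAILABLE:
    # Shapes are illustrative; softmax_scale follows the usual 1/sqrt(head_dim).
    q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = aiter_flash_attn(q, k, v, softmax_scale=q.shape[-1] ** -0.5)
    assert out.shape == q.shape
```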
@@ -137,6 +141,8 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "xformers",
         "sdpa",
         "sage",
@@ -157,6 +163,13 @@ def attention(
         logger.debug(
             "flash_attn_3 does not support attention mask, will use fallback attention implementation"
         )
+    if AITER_AVAILABLE:
+        if flash_attn3_compatible:
+            return aiter_flash_attn(q, k, v, softmax_scale=scale)
+        else:
+            logger.warning(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+            )
     if XFORMERS_AVAILABLE:
         return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
     if SDPA_AVAILABLE:
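Note that the aiter branch reuses the flash_attn3_compatible gate from the FA3 branches; judging from the warning text, that flag is a head-dimension check along these lines (a hypothetical reconstruction; the real definition lives earlier in attention.py, and the 256 bound is an assumption):

```python
FA3_MAX_HEADDIM = 256  # assumed value of the module constant
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
```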
@@ -183,6 +196,22 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = flash_attn3(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("aiter_flash_attn does not support attention mask")
+        if attn_impl == "aiter":
+            return aiter_flash_attn(q, k, v, softmax_scale=scale)
+        else:
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
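The aiter_fp8 branch mirrors the existing fa3_fp8 branch: cast the inputs to FP8, run the kernel, cast the output back. Distilled into a standalone helper (a sketch; attn_fn stands in for either FP8 kernel, and DTYPE_FP8 is assumed to resolve to an FP8 torch dtype such as torch.float8_e4m3fn):

```python
def fp8_attention(q, k, v, scale, attn_fn, fp8_dtype):
    # Cast-compute-restore: the kernel computes in FP8 while the caller
    # keeps its original precision (e.g. bf16) at the boundary.
    origin_dtype = q.dtype
    q, k, v = (t.to(dtype=fp8_dtype) for t in (q, k, v))
    out = attn_fn(q, k, v, softmax_scale=scale)
    return out.to(dtype=origin_dtype)
```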
@@ -288,6 +317,8 @@ def long_context_attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "sdpa",
         "sage",
         "sparge",
@@ -303,6 +334,13 @@ def long_context_attention(
         logger.warning(
             f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
         )
+    if AITER_AVAILABLE:
+        if flash_attn3_compatible:
+            return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+        else:
+            logger.warning(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+            )
     if SDPA_AVAILABLE:
         return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
     if FLASH_ATTN_2_AVAILABLE:
@@ -323,6 +361,20 @@ def long_context_attention(
         v = v.to(dtype=DTYPE_FP8)
         out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_impl == "aiter":
+            return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+
+        origin_dtype = q.dtype
+        q = q.to(dtype=DTYPE_FP8)
+        k = k.to(dtype=DTYPE_FP8)
+        v = v.to(dtype=DTYPE_FP8)
+        out = LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+        return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
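The long-context branches route through the same sequence-parallel wrapper as the existing FA2/FA3 paths; the call shape, lifted from the hunk above (a sketch: LongContextAttention and AttnType are the wrapper types already imported by this module, and a distributed process group must be initialized before construction):

```python
attn = LongContextAttention(attn_type=AttnType.AITER)
out = attn(q, k, v, softmax_scale=scale)
```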
diffsynth_engine/utils/flag.py
CHANGED
@@ -31,6 +31,11 @@ if SDPA_AVAILABLE:
 else:
     logger.info("Torch SDPA is not available")
 
+AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None
+if AITER_AVAILABLE:
+    logger.info("Aiter is available")
+else:
+    logger.info("Aiter is not available")
 
 # lossy
 SAGE_ATTN_AVAILABLE = importlib.util.find_spec("sageattention") is not None
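A quick way to confirm the new flag is wired up in an installed environment (hypothetical usage, not part of the package):

```python
from diffsynth_engine.utils.flag import AITER_AVAILABLE

print(f"aiter backend available: {AITER_AVAILABLE}")
```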
{diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/RECORD
RENAMED
@@ -81,12 +81,12 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
 diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
 diffsynth_engine/configs/__init__.py,sha256=vSjJToEdq3JX7t81_z4nwNwIdD4bYnFjxnMZH7PXMKo,1309
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=
+diffsynth_engine/configs/pipeline.py,sha256=ADgWJa7bA3Z3Z1JtVLgmt4N3eS1KRp9yHu1QvTBzTm0,13404
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=BA5vgMqfy_cjuL2OtXbrFD-Qg5xQnaumHpj5TabwSy8,2559
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diffsynth_engine/models/basic/attention.py,sha256=
+diffsynth_engine/models/basic/attention.py,sha256=mvgk8LTqFwgtPdBeRv797IZNg9k7--X9wD92Hcr188c,15682
 diffsynth_engine/models/basic/lora.py,sha256=PT-A3pwIuUrW2w3TnNlBPb1KRj70QYiBaoCvLnkR5cs,10652
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
 diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
@@ -171,7 +171,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
 diffsynth_engine/utils/constants.py,sha256=sJio3Vy8i0-PWYRnqquYt6ez9k6Tc9JdjCv6pn2BU_4,3551
 diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
 diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
-diffsynth_engine/utils/flag.py,sha256=
+diffsynth_engine/utils/flag.py,sha256=v9GcRFYiNMonD9qmDLWdbXONuF-AcQ_KABPFtRZd0Tc,1767
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
 diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -187,8 +187,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev26.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev26.dist-info/METADATA,sha256=z6sjXpooZoFJJGqqdE_DFtsi2f3aqhjLBbyXPX0RdgE,1164
+diffsynth_engine-0.6.1.dev26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev26.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev26.dist-info/RECORD,,
{diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/WHEEL
RENAMED
File without changes
{diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/licenses/LICENSE
RENAMED
File without changes
{diffsynth_engine-0.6.1.dev25.dist-info → diffsynth_engine-0.6.1.dev26.dist-info}/top_level.txt
RENAMED
File without changes