diffsynth-engine 0.6.1.dev39__py3-none-any.whl → 0.6.1.dev41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/configs/pipeline.py +1 -0
- diffsynth_engine/models/basic/attention.py +31 -1
- diffsynth_engine/utils/flag.py +24 -43
- {diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/RECORD +8 -8
- {diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/pipeline.py
CHANGED

```diff
@@ -26,6 +26,7 @@ class AttnImpl(Enum):
     FA2 = "fa2"  # Flash Attention 2
     FA3 = "fa3"  # Flash Attention 3
     FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    FA4 = "fa4"  # Flash Attention 4
     AITER = "aiter"  # Aiter Flash Attention
     AITER_FP8 = "aiter_fp8"  # Aiter Flash Attention with FP8
     XFORMERS = "xformers"  # XFormers
```
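The new member carries the string "fa4", which is the value the attention dispatcher in attention.py (next section) matches on. A minimal selection sketch, assuming AttnImpl is importable from diffsynth_engine.configs.pipeline as this file list suggests; the exact pipeline-config field that consumes it is not shown in this diff:

```python
from diffsynth_engine.configs.pipeline import AttnImpl

# Illustrative only: pass the enum member (or its string value) wherever the
# pipeline config accepts an attention implementation.
impl = AttnImpl.FA4
assert impl.value == "fa4"
```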
diffsynth_engine/models/basic/attention.py
CHANGED

```diff
@@ -6,6 +6,7 @@ from typing import Optional
 
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_4_AVAILABLE,
     FLASH_ATTN_3_AVAILABLE,
     FLASH_ATTN_2_AVAILABLE,
     XFORMERS_AVAILABLE,
@@ -21,7 +22,8 @@ FA3_MAX_HEADDIM = 256
 
 logger = logging.get_logger(__name__)
 
-
+if FLASH_ATTN_4_AVAILABLE:
+    from flash_attn.cute.interface import flash_attn_func as flash_attn4
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -142,6 +144,7 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "fa4",
         "aiter",
         "aiter_fp8",
         "xformers",
@@ -152,6 +155,22 @@ def attention(
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_4_AVAILABLE:
+            # FA4 also has the same max-head-256 limitation as FA3
+            if flash_attn3_compatible and attn_mask is None:
+                attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+                if isinstance(attn_out, tuple):
+                    attn_out = attn_out[0]
+                return attn_out
+            else:
+                if not flash_attn3_compatible:
+                    logger.warning(
+                        f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                    )
+                else:
+                    logger.debug(
+                        "flash_attn_4 does not support attention mask, will use fallback attention implementation"
+                    )
         if FLASH_ATTN_3_AVAILABLE:
             if flash_attn3_compatible and attn_mask is None:
                 return flash_attn3(q, k, v, softmax_scale=scale)
```
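Unlike the FA2/FA3 branches, the FA4 branch unwraps its result before returning: some flash-attention entry points return a bare tensor, while others return a tuple whose first element is the attention output (typically followed by auxiliary values such as softmax log-sum-exp statistics). A standalone sketch of that unwrapping pattern, written here for illustration rather than taken from the package:

```python
from typing import Tuple, Union
import torch

def unwrap_attn_output(out: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> torch.Tensor:
    # Keep only the attention output; drop any auxiliary tensors the kernel
    # may return alongside it.
    return out[0] if isinstance(out, tuple) else out
```

The same pattern appears again in the explicit "fa4" branch below.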
```diff
@@ -213,6 +232,17 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "fa4":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("flash_attn_4 does not support attention mask")
+        attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+        if isinstance(attn_out, tuple):
+            attn_out = attn_out[0]
+        return attn_out
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
```
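A hedged usage sketch of the explicit path above. The q/k/v layout (batch, sequence, heads, head_dim) and keyword names follow the calls visible in this diff, but the full attention() signature is not shown here, so treat the exact argument list as an assumption. Requesting "fa4" raises instead of falling back, so it needs a flash-attn build that exposes flash_attn.cute.interface, head_dim at most 256, and no attention mask:

```python
import torch
from diffsynth_engine.models.basic.attention import attention

# Illustrative only: requires a CUDA device and FLASH_ATTN_4_AVAILABLE;
# use attn_impl="auto" to let the dispatcher fall back to FA3/FA2/SDPA instead.
q = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)

out = attention(q, k, v, attn_impl="fa4")  # RuntimeError if head_dim > 256 or a mask is passed
print(out.shape)
```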
diffsynth_engine/utils/flag.py
CHANGED

```diff
@@ -6,24 +6,27 @@ from diffsynth_engine.utils import logging
 logger = logging.get_logger(__name__)
 
 
-# 无损
-FLASH_ATTN_3_AVAILABLE = importlib.util.find_spec("flash_attn_interface") is not None
-if FLASH_ATTN_3_AVAILABLE:
-    logger.info("Flash attention 3 is available")
-else:
-    logger.info("Flash attention 3 is not available")
+def check_module_available(module_path: str, module_name: str = None) -> bool:
+    try:
+        available = importlib.util.find_spec(module_path) is not None
+    except (ModuleNotFoundError, AttributeError, ValueError):
+        available = False
 
-FLASH_ATTN_2_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
-if FLASH_ATTN_2_AVAILABLE:
-    logger.info("Flash attention 2 is available")
-else:
-    logger.info("Flash attention 2 is not available")
+    if module_name:
+        if available:
+            logger.info(f"{module_name} is available")
+        else:
+            logger.info(f"{module_name} is not available")
 
-XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None
-if XFORMERS_AVAILABLE:
-    logger.info("XFormers is available")
-else:
-    logger.info("XFormers is not available")
+    return available
+
+
+# 无损
+FLASH_ATTN_4_AVAILABLE = check_module_available("flash_attn.cute.interface", "Flash attention 4")
+FLASH_ATTN_3_AVAILABLE = check_module_available("flash_attn_interface", "Flash attention 3")
+FLASH_ATTN_2_AVAILABLE = check_module_available("flash_attn", "Flash attention 2")
+XFORMERS_AVAILABLE = check_module_available("xformers", "XFormers")
+AITER_AVAILABLE = check_module_available("aiter", "Aiter")
 
 SDPA_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")
 if SDPA_AVAILABLE:
```
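The refactor replaces the repeated find_spec plus if/else logging blocks with a single check_module_available() helper (the # 无损 / # 有损 comments group the flags into lossless and lossy attention backends, respectively). The try/except matters for the new dotted path: importlib.util.find_spec("flash_attn.cute.interface") imports the parent package first, so it raises ModuleNotFoundError when flash_attn itself is absent instead of returning None. A small illustrative probe, not part of the package:

```python
import importlib.util

# find_spec only returns None when the parent package imports but the submodule
# is missing; a missing parent package raises ModuleNotFoundError instead,
# which is why check_module_available treats the exception as "not available".
def probe(module_path: str) -> bool:
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ModuleNotFoundError, AttributeError, ValueError):
        return False

print(probe("flash_attn.cute.interface"))  # False unless a flash-attn build with this interface is installed
```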
```diff
@@ -31,37 +34,15 @@ if SDPA_AVAILABLE:
 else:
     logger.info("Torch SDPA is not available")
 
-AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None
-if AITER_AVAILABLE:
-    logger.info("Aiter is available")
-else:
-    logger.info("Aiter is not available")
 
 # 有损
-SAGE_ATTN_AVAILABLE = importlib.util.find_spec("sageattention") is not None
-if SAGE_ATTN_AVAILABLE:
-    logger.info("Sage attention is available")
-else:
-    logger.info("Sage attention is not available")
-
-SPARGE_ATTN_AVAILABLE = importlib.util.find_spec("spas_sage_attn") is not None
-if SPARGE_ATTN_AVAILABLE:
-    logger.info("Sparge attention is available")
-else:
-    logger.info("Sparge attention is not available")
+SAGE_ATTN_AVAILABLE = check_module_available("sageattention", "Sage attention")
+SPARGE_ATTN_AVAILABLE = check_module_available("spas_sage_attn", "Sparge attention")
+VIDEO_SPARSE_ATTN_AVAILABLE = check_module_available("vsa", "Video sparse attention")
 
-VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
-if VIDEO_SPARSE_ATTN_AVAILABLE:
-    logger.info("Video sparse attention is available")
-else:
-    logger.info("Video sparse attention is not available")
-
-NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None
+NUNCHAKU_AVAILABLE = check_module_available("nunchaku", "Nunchaku")
 NUNCHAKU_IMPORT_ERROR = None
-if NUNCHAKU_AVAILABLE:
-    logger.info("Nunchaku is available")
-else:
-    logger.info("Nunchaku is not available")
+if not NUNCHAKU_AVAILABLE:
     import sys
     torch_version = getattr(torch, "__version__", "unknown")
     python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
```
{diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/RECORD
RENAMED

```diff
@@ -86,12 +86,12 @@ diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json,sha256=
 diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json,sha256=yhDX6fs-0YV13R4neiV5wW0QjjLydDloSvoOELFECRA,2776833
 diffsynth_engine/configs/__init__.py,sha256=biluGSEw78PPwO7XFlms16iuWXDiM0Eg_qsOMMTY0NQ,1409
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=
+diffsynth_engine/configs/pipeline.py,sha256=tcnhLGdQgvEibWBZVFH3uOS1pwB6WEnHgCFSer2bT0E,15347
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diffsynth_engine/models/basic/attention.py,sha256=
+diffsynth_engine/models/basic/attention.py,sha256=sYoTOlloRA5KFNV18cZSkXuPY-Ck-13-kCM41AaBMu0,17225
 diffsynth_engine/models/basic/lora.py,sha256=Y6cBgrBsuDAP9FZz_fgK8vBi_EMg23saFIUSAsPIG-M,10670
 diffsynth_engine/models/basic/lora_nunchaku.py,sha256=7qhzGCzUIfDrwtWG0nspwdyZ7YUkaM4vMqzxZby2Zds,7510
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
@@ -183,7 +183,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
 diffsynth_engine/utils/constants.py,sha256=x0-bsPRplW-KkRpLVajuC9Yv6f3QbdHgSr3XZ-eBCsQ,3745
 diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
 diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
-diffsynth_engine/utils/flag.py,sha256=
+diffsynth_engine/utils/flag.py,sha256=Ubm7FF0vHG197bmJGEplp4XauBlUaQVv-zr-w6VyEIM,2493
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
 diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -200,8 +200,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev41.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev41.dist-info/METADATA,sha256=dygz72s8iSZlS1JE54FlMvYiSmMy8uB2B0Gd_WLCSws,1164
+diffsynth_engine-0.6.1.dev41.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev41.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev41.dist-info/RECORD,,
```
{diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/WHEEL
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/licenses/LICENSE
RENAMED
File without changes

{diffsynth_engine-0.6.1.dev39.dist-info → diffsynth_engine-0.6.1.dev41.dist-info}/top_level.txt
RENAMED
File without changes