diffsynth-engine 0.6.1.dev39__py3-none-any.whl → 0.6.1.dev41__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
@@ -26,6 +26,7 @@ class AttnImpl(Enum):
     FA2 = "fa2"  # Flash Attention 2
     FA3 = "fa3"  # Flash Attention 3
     FA3_FP8 = "fa3_fp8"  # Flash Attention 3 with FP8
+    FA4 = "fa4"  # Flash Attention 4
     AITER = "aiter"  # Aiter Flash Attention
     AITER_FP8 = "aiter_fp8"  # Aiter Flash Attention with FP8
     XFORMERS = "xformers"  # XFormers
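This hunk only adds the FA4 backend identifier. The diff does not name the file, but per the RECORD section at the bottom only configs/pipeline.py, models/basic/attention.py, and utils/flag.py changed, and pipeline.py grew by exactly the size of one added enum line (+37 bytes), so the AttnImpl enum most likely lives in diffsynth_engine/configs/pipeline.py. A minimal usage sketch under that assumption:

```python
# Sketch only: the import path for AttnImpl is inferred from the RECORD
# changes below, not stated anywhere in this diff.
from diffsynth_engine.configs.pipeline import AttnImpl

impl = AttnImpl.FA4
assert impl.value == "fa4"  # the same string accepted by attention(..., attn_impl="fa4")
```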
@@ -6,6 +6,7 @@ from typing import Optional
 
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
+    FLASH_ATTN_4_AVAILABLE,
     FLASH_ATTN_3_AVAILABLE,
     FLASH_ATTN_2_AVAILABLE,
     XFORMERS_AVAILABLE,
@@ -21,7 +22,8 @@ FA3_MAX_HEADDIM = 256
 
 logger = logging.get_logger(__name__)
 
-
+if FLASH_ATTN_4_AVAILABLE:
+    from flash_attn.cute.interface import flash_attn_func as flash_attn4
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
@@ -142,6 +144,7 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "fa4",
         "aiter",
         "aiter_fp8",
         "xformers",
@@ -152,6 +155,22 @@ def attention(
     ]
     flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
     if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_4_AVAILABLE:
+            # FA4 also has the same max-head-256 limitation as FA3
+            if flash_attn3_compatible and attn_mask is None:
+                attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+                if isinstance(attn_out, tuple):
+                    attn_out = attn_out[0]
+                return attn_out
+            else:
+                if not flash_attn3_compatible:
+                    logger.warning(
+                        f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                    )
+                else:
+                    logger.debug(
+                        "flash_attn_4 does not support attention mask, will use fallback attention implementation"
+                    )
         if FLASH_ATTN_3_AVAILABLE:
             if flash_attn3_compatible and attn_mask is None:
                 return flash_attn3(q, k, v, softmax_scale=scale)
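Two details of the new auto-dispatch branch are worth noting: FA4 is gated by the same head-dimension limit as FA3 (FA3_MAX_HEADDIM = 256) and is skipped whenever an attention mask is supplied; and, unlike the flash_attn3 call that follows, the flash_attn4 result is unwrapped before returning, since the cute interface can evidently return a tuple rather than a bare tensor. A minimal sketch of that unwrap idiom; the exact tuple layout (output first, auxiliary values after) is an assumption here, the hunk only keeps element 0:

```python
def unwrap_attn_output(attn_out):
    # Mirrors the hunk above: flash_attn4 may return a tuple such as
    # (out, aux) instead of a bare tensor; keep only the first element.
    if isinstance(attn_out, tuple):
        attn_out = attn_out[0]
    return attn_out
```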
@@ -213,6 +232,17 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "fa4":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but flash_attn_4 only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("flash_attn_4 does not support attention mask")
+        attn_out = flash_attn4(q, k, v, softmax_scale=scale)
+        if isinstance(attn_out, tuple):
+            attn_out = attn_out[0]
+        return attn_out
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
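Where the auto path falls back silently, explicitly requesting "fa4" is strict: an oversized head dimension or a non-None attn_mask raises RuntimeError, and the branch also presumes the flash_attn cute interface is installed, since flash_attn4 is only bound when FLASH_ATTN_4_AVAILABLE is true. A hedged calling sketch; the attention() signature is only partially visible in this diff, so the positional q, k, v and the keyword names are taken from the lines above rather than from the full source:

```python
import torch
from diffsynth_engine.models.basic.attention import attention  # path from RECORD; assumed

# Hypothetical shapes: the check in the hunk above reads q.shape[-1], so
# head_dim is the last axis and must be <= 256 for "fa4".
q = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")

out = attention(q, k, v, attn_impl="fa4")  # raises RuntimeError if an attn_mask is passed
```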
@@ -6,24 +6,27 @@ from diffsynth_engine.utils import logging
 logger = logging.get_logger(__name__)
 
 
-# 无损
-FLASH_ATTN_3_AVAILABLE = importlib.util.find_spec("flash_attn_interface") is not None
-if FLASH_ATTN_3_AVAILABLE:
-    logger.info("Flash attention 3 is available")
-else:
-    logger.info("Flash attention 3 is not available")
+def check_module_available(module_path: str, module_name: str = None) -> bool:
+    try:
+        available = importlib.util.find_spec(module_path) is not None
+    except (ModuleNotFoundError, AttributeError, ValueError):
+        available = False
 
-FLASH_ATTN_2_AVAILABLE = importlib.util.find_spec("flash_attn") is not None
-if FLASH_ATTN_2_AVAILABLE:
-    logger.info("Flash attention 2 is available")
-else:
-    logger.info("Flash attention 2 is not available")
+    if module_name:
+        if available:
+            logger.info(f"{module_name} is available")
+        else:
+            logger.info(f"{module_name} is not available")
 
-XFORMERS_AVAILABLE = importlib.util.find_spec("xformers") is not None
-if XFORMERS_AVAILABLE:
-    logger.info("XFormers is available")
-else:
-    logger.info("XFormers is not available")
+    return available
+
+
+# 无损 (lossless)
+FLASH_ATTN_4_AVAILABLE = check_module_available("flash_attn.cute.interface", "Flash attention 4")
+FLASH_ATTN_3_AVAILABLE = check_module_available("flash_attn_interface", "Flash attention 3")
+FLASH_ATTN_2_AVAILABLE = check_module_available("flash_attn", "Flash attention 2")
+XFORMERS_AVAILABLE = check_module_available("xformers", "XFormers")
+AITER_AVAILABLE = check_module_available("aiter", "Aiter")
 
 SDPA_AVAILABLE = hasattr(torch.nn.functional, "scaled_dot_product_attention")
 if SDPA_AVAILABLE:
@@ -31,37 +34,15 @@ if SDPA_AVAILABLE:
 else:
     logger.info("Torch SDPA is not available")
 
-AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None
-if AITER_AVAILABLE:
-    logger.info("Aiter is available")
-else:
-    logger.info("Aiter is not available")
 
 # 有损 (lossy)
-SAGE_ATTN_AVAILABLE = importlib.util.find_spec("sageattention") is not None
-if SAGE_ATTN_AVAILABLE:
-    logger.info("Sage attention is available")
-else:
-    logger.info("Sage attention is not available")
-
-SPARGE_ATTN_AVAILABLE = importlib.util.find_spec("spas_sage_attn") is not None
-if SPARGE_ATTN_AVAILABLE:
-    logger.info("Sparge attention is available")
-else:
-    logger.info("Sparge attention is not available")
+SAGE_ATTN_AVAILABLE = check_module_available("sageattention", "Sage attention")
+SPARGE_ATTN_AVAILABLE = check_module_available("spas_sage_attn", "Sparge attention")
+VIDEO_SPARSE_ATTN_AVAILABLE = check_module_available("vsa", "Video sparse attention")
 
-VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
-if VIDEO_SPARSE_ATTN_AVAILABLE:
-    logger.info("Video sparse attention is available")
-else:
-    logger.info("Video sparse attention is not available")
-
-NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None
+NUNCHAKU_AVAILABLE = check_module_available("nunchaku", "Nunchaku")
 NUNCHAKU_IMPORT_ERROR = None
-if NUNCHAKU_AVAILABLE:
-    logger.info("Nunchaku is available")
-else:
-    logger.info("Nunchaku is not available")
+if not NUNCHAKU_AVAILABLE:
     import sys
     torch_version = getattr(torch, "__version__", "unknown")
     python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
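The flag.py refactor collapses six copies of the find_spec-plus-logging pattern into one helper, and the try/except around find_spec is not cosmetic: find_spec raises ModuleNotFoundError when a parent package of a dotted path (e.g. flash_attn in "flash_attn.cute.interface") is absent, rather than returning None, so the FA4 probe needs it. A standalone, runnable sketch of the same probing behavior, with print standing in for the package logger and the module names below serving only as examples:

```python
import importlib.util

def check_module_available(module_path, module_name=None):
    # find_spec returns None for a missing top-level module, but raises
    # ModuleNotFoundError when a parent package of a dotted path is absent.
    try:
        available = importlib.util.find_spec(module_path) is not None
    except (ModuleNotFoundError, AttributeError, ValueError):
        available = False
    if module_name:
        print(f"{module_name} is {'available' if available else 'not available'}")
    return available

check_module_available("flash_attn.cute.interface", "Flash attention 4")
check_module_available("xformers", "XFormers")
```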
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev39
+Version: 0.6.1.dev41
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
@@ -86,12 +86,12 @@ diffsynth_engine/conf/tokenizers/z_image/tokenizer/tokenizer_config.json,sha256=
 diffsynth_engine/conf/tokenizers/z_image/tokenizer/vocab.json,sha256=yhDX6fs-0YV13R4neiV5wW0QjjLydDloSvoOELFECRA,2776833
 diffsynth_engine/configs/__init__.py,sha256=biluGSEw78PPwO7XFlms16iuWXDiM0Eg_qsOMMTY0NQ,1409
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=RqhPAZOCpIMkFk-OsfiNYlqpqM-7B52ny0Zcr9Ix7wY,15310
+diffsynth_engine/configs/pipeline.py,sha256=tcnhLGdQgvEibWBZVFH3uOS1pwB6WEnHgCFSer2bT0E,15347
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diffsynth_engine/models/basic/attention.py,sha256=YrIxkYoekC3I7-sMTw60CL4GIKMLOTrn-eCk-iHT7E4,15701
+diffsynth_engine/models/basic/attention.py,sha256=sYoTOlloRA5KFNV18cZSkXuPY-Ck-13-kCM41AaBMu0,17225
 diffsynth_engine/models/basic/lora.py,sha256=Y6cBgrBsuDAP9FZz_fgK8vBi_EMg23saFIUSAsPIG-M,10670
 diffsynth_engine/models/basic/lora_nunchaku.py,sha256=7qhzGCzUIfDrwtWG0nspwdyZ7YUkaM4vMqzxZby2Zds,7510
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
@@ -183,7 +183,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
 diffsynth_engine/utils/constants.py,sha256=x0-bsPRplW-KkRpLVajuC9Yv6f3QbdHgSr3XZ-eBCsQ,3745
 diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
 diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
-diffsynth_engine/utils/flag.py,sha256=KSzjnzRe7sleNCJm8IpbJQbmBY4KNV2kDrijxi27Jek,2928
+diffsynth_engine/utils/flag.py,sha256=Ubm7FF0vHG197bmJGEplp4XauBlUaQVv-zr-w6VyEIM,2493
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
 diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -200,8 +200,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.dev39.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
-diffsynth_engine-0.6.1.dev39.dist-info/METADATA,sha256=f_qU_vp4RcHSOgW3Agm428engf8v7TKRCt8DuxAOEi8,1164
-diffsynth_engine-0.6.1.dev39.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-diffsynth_engine-0.6.1.dev39.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
-diffsynth_engine-0.6.1.dev39.dist-info/RECORD,,
+diffsynth_engine-0.6.1.dev41.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev41.dist-info/METADATA,sha256=dygz72s8iSZlS1JE54FlMvYiSmMy8uB2B0Gd_WLCSws,1164
+diffsynth_engine-0.6.1.dev41.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev41.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev41.dist-info/RECORD,,