diffsynth-engine 0.6.1.dev25__py3-none-any.whl → 0.6.1.dev26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diffsynth_engine/configs/pipeline.py

@@ -26,6 +26,8 @@ class AttnImpl(Enum):
     FA2 = "fa2" # Flash Attention 2
     FA3 = "fa3" # Flash Attention 3
     FA3_FP8 = "fa3_fp8" # Flash Attention 3 with FP8
+    AITER = "aiter" # Aiter Flash Attention
+    AITER_FP8 = "aiter_fp8" # Aiter Flash Attention with FP8
     XFORMERS = "xformers" # XFormers
     SDPA = "sdpa" # Scaled Dot Product Attention
     SAGE = "sage" # Sage Attention
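The two new members follow the same string-backed pattern as the rest of the enum, and their values ("aiter", "aiter_fp8") are exactly the strings that the dispatch code in attention.py matches on further down in this diff. A self-contained sketch (not part of the package source) showing how the new members round-trip to and from those strings:

from enum import Enum

class AttnImpl(Enum):
    FA3 = "fa3"
    FA3_FP8 = "fa3_fp8"
    AITER = "aiter"          # added in 0.6.1.dev26
    AITER_FP8 = "aiter_fp8"  # added in 0.6.1.dev26

assert AttnImpl("aiter") is AttnImpl.AITER      # string -> member
assert AttnImpl.AITER_FP8.value == "aiter_fp8"  # member -> string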
diffsynth_engine/models/basic/attention.py

@@ -13,6 +13,7 @@ from diffsynth_engine.utils.flag import (
     SAGE_ATTN_AVAILABLE,
     SPARGE_ATTN_AVAILABLE,
     VIDEO_SPARSE_ATTN_AVAILABLE,
+    AITER_AVAILABLE,
 )
 from diffsynth_engine.utils.platform import DTYPE_FP8
 
@@ -93,6 +94,9 @@ if SPARGE_ATTN_AVAILABLE:
         )
         return out.transpose(1, 2)
 
+if AITER_AVAILABLE:
+    from aiter import flash_attn_func as aiter_flash_attn
+    from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
 
 if VIDEO_SPARSE_ATTN_AVAILABLE:
     from diffsynth_engine.models.basic.video_sparse_attention import (
@@ -137,6 +141,8 @@ def attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "xformers",
         "sdpa",
         "sage",
@@ -157,6 +163,13 @@ def attention(
             logger.debug(
                 "flash_attn_3 does not support attention mask, will use fallback attention implementation"
             )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return aiter_flash_attn(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if XFORMERS_AVAILABLE:
             return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
         if SDPA_AVAILABLE:
@@ -183,6 +196,22 @@ def attention(
         v = v.to(dtype=DTYPE_FP8)
         out = flash_attn3(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_mask is not None:
+            raise RuntimeError("aiter_flash_attn does not support attention mask")
+        if attn_impl == "aiter" :
+            return aiter_flash_attn(q, k, v, softmax_scale=scale)
+        else:
+            origin_dtype = q.dtype
+            q = q.to(dtype=DTYPE_FP8)
+            k = k.to(dtype=DTYPE_FP8)
+            v = v.to(dtype=DTYPE_FP8)
+            out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
+            return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":
@@ -288,6 +317,8 @@ def long_context_attention(
         "fa2",
         "fa3",
         "fa3_fp8",
+        "aiter",
+        "aiter_fp8",
         "sdpa",
         "sage",
         "sparge",
@@ -303,6 +334,13 @@ def long_context_attention(
             logger.warning(
                 f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
            )
+        if AITER_AVAILABLE:
+            if flash_attn3_compatible:
+                return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+            else:
+                logger.warning(
+                    f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
+                )
         if SDPA_AVAILABLE:
             return LongContextAttention(attn_type=AttnType.TORCH)(q, k, v, softmax_scale=scale)
         if FLASH_ATTN_2_AVAILABLE:
@@ -323,6 +361,20 @@ def long_context_attention(
         v = v.to(dtype=DTYPE_FP8)
         out = LongContextAttention(attn_type=AttnType.FA3)(q, k, v, softmax_scale=scale)
         return out.to(dtype=origin_dtype)
+    if attn_impl == "aiter" or attn_impl == "aiter_fp8":
+        if not flash_attn3_compatible:
+            raise RuntimeError(
+                f"head_dim={q.shape[-1]}, but aiter_flash_attn only supports head dimension at most {FA3_MAX_HEADDIM}"
+            )
+        if attn_impl == "aiter":
+            return LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+
+        origin_dtype = q.dtype
+        q = q.to(dtype=DTYPE_FP8)
+        k = k.to(dtype=DTYPE_FP8)
+        v = v.to(dtype=DTYPE_FP8)
+        out = LongContextAttention(attn_type=AttnType.AITER)(q, k, v, softmax_scale=scale)
+        return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return LongContextAttention(attn_type=AttnType.FA)(q, k, v, softmax_scale=scale)
     if attn_impl == "sdpa":
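Both new FP8 branches above ("aiter_fp8", mirroring the existing "fa3_fp8" path) use the same cast-run-restore pattern: the activations are cast to the platform FP8 dtype, the FP8 kernel is called, and the output is cast back to the original dtype. A minimal sketch of that pattern, with torch.float8_e4m3fn standing in for DTYPE_FP8 and the kernel passed in as a parameter rather than imported from aiter (both are assumptions for illustration, not the package's code):

import torch

def fp8_attention(q, k, v, scale, fp8_kernel, fp8_dtype=torch.float8_e4m3fn):
    # Cast activations down to FP8, run the FP8 kernel, restore the original dtype.
    origin_dtype = q.dtype
    q, k, v = [t.to(dtype=fp8_dtype) for t in (q, k, v)]
    out = fp8_kernel(q, k, v, softmax_scale=scale)
    return out.to(dtype=origin_dtype)

With aiter installed, fp8_kernel would correspond to the flash_attn_fp8_pertensor_func imported under AITER_AVAILABLE in the hunks above.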
diffsynth_engine/utils/flag.py

@@ -31,6 +31,11 @@ if SDPA_AVAILABLE:
 else:
     logger.info("Torch SDPA is not available")
 
+AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None
+if AITER_AVAILABLE:
+    logger.info("Aiter is available")
+else:
+    logger.info("Aiter is not available")
 
 # 有损
 SAGE_ATTN_AVAILABLE = importlib.util.find_spec("sageattention") is not None
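The new flag mirrors the existing ones: importlib.util.find_spec probes for the aiter package without importing it, so the module still loads cleanly on machines where AIter is absent, and the guarded import in attention.py only runs when the probe succeeds. A small sketch of that consumption pattern, with a hypothetical require_aiter helper (not part of the package):

import importlib.util

# find_spec returns None when "aiter" cannot be located, so no ImportError is raised here.
AITER_AVAILABLE = importlib.util.find_spec("aiter") is not None

if AITER_AVAILABLE:
    from aiter import flash_attn_func as aiter_flash_attn  # guarded import

def require_aiter():
    # Hypothetical helper: fail loudly if attn_impl="aiter" is requested without aiter installed.
    if not AITER_AVAILABLE:
        raise RuntimeError('attn_impl="aiter" requested but the aiter package is not installed')
    return aiter_flash_attn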
diffsynth_engine-0.6.1.dev25.dist-info/METADATA → diffsynth_engine-0.6.1.dev26.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev25
+Version: 0.6.1.dev26
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
diffsynth_engine-0.6.1.dev25.dist-info/RECORD → diffsynth_engine-0.6.1.dev26.dist-info/RECORD

@@ -81,12 +81,12 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
 diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
 diffsynth_engine/configs/__init__.py,sha256=vSjJToEdq3JX7t81_z4nwNwIdD4bYnFjxnMZH7PXMKo,1309
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=2tCcW3qndx5GdzYNvpbAsR6ZGnzY8q7EzJjWDIATBr0,13297
+diffsynth_engine/configs/pipeline.py,sha256=ADgWJa7bA3Z3Z1JtVLgmt4N3eS1KRp9yHu1QvTBzTm0,13404
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
 diffsynth_engine/models/base.py,sha256=BA5vgMqfy_cjuL2OtXbrFD-Qg5xQnaumHpj5TabwSy8,2559
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diffsynth_engine/models/basic/attention.py,sha256=iFxpvXdaEJZHddTTRuKL1grKb6beU53y-VuRPX8FpFw,13127
+diffsynth_engine/models/basic/attention.py,sha256=mvgk8LTqFwgtPdBeRv797IZNg9k7--X9wD92Hcr188c,15682
 diffsynth_engine/models/basic/lora.py,sha256=PT-A3pwIuUrW2w3TnNlBPb1KRj70QYiBaoCvLnkR5cs,10652
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
 diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
@@ -171,7 +171,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
 diffsynth_engine/utils/constants.py,sha256=sJio3Vy8i0-PWYRnqquYt6ez9k6Tc9JdjCv6pn2BU_4,3551
 diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
 diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
-diffsynth_engine/utils/flag.py,sha256=wODDbMMLTGOl7yoLMZDKGyqXSYANPaDQdZGXOJryGeI,1597
+diffsynth_engine/utils/flag.py,sha256=v9GcRFYiNMonD9qmDLWdbXONuF-AcQ_KABPFtRZd0Tc,1767
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
 diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -187,8 +187,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.dev25.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
-diffsynth_engine-0.6.1.dev25.dist-info/METADATA,sha256=hbm3Xm8GajphVodptdo1vPnvB098xLQk8B1ORFoUQ8k,1164
-diffsynth_engine-0.6.1.dev25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-diffsynth_engine-0.6.1.dev25.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
-diffsynth_engine-0.6.1.dev25.dist-info/RECORD,,
+diffsynth_engine-0.6.1.dev26.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev26.dist-info/METADATA,sha256=z6sjXpooZoFJJGqqdE_DFtsi2f3aqhjLBbyXPX0RdgE,1164
+diffsynth_engine-0.6.1.dev26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev26.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev26.dist-info/RECORD,,