mps-flash-attn 0.3.1.tar.gz → 0.3.3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/PKG-INFO +1 -1
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/__init__.py +84 -27
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/pyproject.toml +1 -1
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/setup.py +1 -1
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/LICENSE +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/README.md +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/setup.cfg +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/tests/test_issues.py +0 -0
- {mps_flash_attn-0.3.1 → mps_flash_attn-0.3.3}/tests/test_mfa_v2.py +0 -0
mps_flash_attn/__init__.py

```diff
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.3.1"
+__version__ = "0.3.3"
 
 __all__ = [
     # Core functions
```
```diff
@@ -97,6 +97,56 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     return attn_mask <= -1e3
 
 
+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
```
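For orientation, here is a minimal sketch of what the new helper does with a broadcastable key-padding mask. It calls the private `_validate_and_expand_mask` directly, which is purely illustrative; the shapes and the convention that boolean True marks a blocked position (as suggested by `convert_mask` above) are assumptions, not part of the diff:

```python
import torch
from mps_flash_attn import _validate_and_expand_mask  # private helper, used here only for illustration

B, H, N_q, N_kv = 2, 8, 16, 16

# Key-padding mask of shape (B, 1, 1, N_kv): one row reused for every
# head and every query position (assuming True == blocked position).
mask = torch.zeros(B, 1, 1, N_kv, dtype=torch.bool)
mask[..., -4:] = True  # hide the last 4 keys

expanded = _validate_and_expand_mask(mask, B, H, N_q, N_kv)
print(expanded.shape)  # torch.Size([2, 1, 16, 16]) - mq expanded to N_q, mb/mh left as-is
```

Note that only the query-length axis is expanded eagerly here; batch and head stay at size 1 when broadcast, matching the `(mb, mh, N_q, N_kv)` return shape documented in the docstring.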
```diff
@@ -247,32 +297,19 @@ def flash_attention(
     query = _ensure_contiguous(query, "query")
     key = _ensure_contiguous(key, "key")
     value = _ensure_contiguous(value, "value")
-
-
-
-    B, H,
-
-
-
-
-
-
-
-
-    # Expand broadcast mask to full shape for Metal kernel
-    if mq == 1 and N_q > 1:
-        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
-    if mk == 1 and N_kv > 1:
-        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
-    if mb != 1 and mb != B:
-        raise ValueError(
-            f"attn_mask batch size must be 1 or {B}, got {mb}"
-        )
-    if mh != 1 and mh != H:
-        raise ValueError(
-            f"attn_mask head count must be 1 or {H}, got {mh}"
-        )
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
 
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
```
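A hedged usage sketch of the public entry point after this change. The tensor sizes, dtype, and keyword-style `attn_mask` argument are illustrative assumptions (the diff confirms `flash_attention` takes query, key, value, and an attention mask, but not the full signature); the fail-fast behavior on non-4D inputs is what the hunk above adds:

```python
import torch
from mps_flash_attn import flash_attention

B, H, N, D = 1, 4, 128, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

# A (1, 1, 1, N) mask is now expanded internally by _validate_and_expand_mask;
# anything that is not 4D raises before reaching the Metal kernel.
mask = torch.zeros(1, 1, 1, N, dtype=torch.bool, device="mps")
out = flash_attention(q, k, v, attn_mask=mask)  # out shape: (B, H, N, D), per usual attention semantics
```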
```diff
@@ -931,6 +968,11 @@ def flash_attention_fp8(
         scale_factor = scale / default_scale
         query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
```
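The `scale_factor` context lines fold a user-supplied softmax scale into the query before the quantized kernel runs. A short worked example of the arithmetic, assuming `default_scale` is the usual 1/sqrt(D) (the diff does not show its definition, so that is an assumption):

```python
import math

D = 64
default_scale = 1.0 / math.sqrt(D)  # assumed kernel default
scale = 0.1                         # hypothetical user-supplied softmax scale
scale_factor = scale / default_scale

# (query * scale_factor) @ key.T * default_scale
#   == query @ key.T * scale
# so the kernel can keep applying default_scale unchanged.
```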
```diff
@@ -984,6 +1026,11 @@ def flash_attention_int8(
         scale_factor = scale / default_scale
         query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
```
```diff
@@ -1040,6 +1087,11 @@ def flash_attention_nf4(
         scale_factor = scale / default_scale
         query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size
```
```diff
@@ -1094,6 +1146,11 @@ def flash_attention_quantized(
         scale_factor = scale / default_scale
         query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size
```
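All four quantized entry points (`flash_attention_fp8`, `flash_attention_int8`, `flash_attention_nf4`, `flash_attention_quantized`) now route masks through the same helper before calling `_C.forward_quantized`, so a malformed mask fails fast with a uniform error. A small illustrative probe of that behavior, again poking the private helper directly rather than any public API:

```python
import torch
from mps_flash_attn import _validate_and_expand_mask  # private, illustrative

# mq=7 is neither 1 nor N_q=16, so validation rejects it up front.
bad = torch.zeros(1, 1, 7, 16, dtype=torch.bool)
try:
    _validate_and_expand_mask(bad, B=1, H=4, N_q=16, N_kv=16)
except ValueError as err:
    print(err)  # attn_mask shape mismatch: mask is (7, 16) but expected (16, 16) or broadcastable (1, 16)
```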
mps_flash_attn/lib/libMFABridge.dylib: binary files differ.