mps-flash-attn 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/PKG-INFO +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/__init__.py +84 -21
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/pyproject.toml +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/setup.py +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/LICENSE +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/README.md +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/setup.cfg +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/tests/test_issues.py +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.2}/tests/test_mfa_v2.py +0 -0

mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.3.0"
+__version__ = "0.3.2"
 
 __all__ = [
     # Core functions

@@ -97,6 +97,56 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     return attn_mask <= -1e3
 
 
+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
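
For illustration, a minimal sketch of what the new helper does, assuming it can be imported from the mps_flash_attn package (it is a private helper, so importing it directly is only for demonstration): a broadcastable (B, 1, 1, N_kv) mask is expanded along the query dimension, and a key-length mismatch is rejected with a ValueError.

import torch
from mps_flash_attn import _validate_and_expand_mask  # private helper; assumed importable

B, H, N_q, N_kv = 2, 8, 128, 128

# Padding-style mask shared across heads and query positions.
# Boolean dtype with True meaning "masked out" is an assumption for this sketch.
mask = torch.zeros(B, 1, 1, N_kv, dtype=torch.bool)
mask[..., 100:] = True

expanded = _validate_and_expand_mask(mask, B, H, N_q, N_kv)
print(expanded.shape)  # torch.Size([2, 1, 128, 128]): expanded over N_q, still broadcast over H

# A mismatched key length is rejected with a clear ValueError by the new validation.
bad = torch.zeros(B, 1, 1, N_kv + 1, dtype=torch.bool)
try:
    _validate_and_expand_mask(bad, B, H, N_q, N_kv)
except ValueError as exc:
    print(exc)
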

@@ -247,26 +297,19 @@ def flash_attention(
     query = _ensure_contiguous(query, "query")
     key = _ensure_contiguous(key, "key")
     value = _ensure_contiguous(value, "value")
-    (removed: the former inline mask validation; most of its lines were not captured in this extract apart from the fragment "B, H,")
-            raise ValueError(
-                f"attn_mask batch size must be 1 or {B}, got {mb}"
-            )
-        if mh != 1 and mh != H:
-            raise ValueError(
-                f"attn_mask head count must be 1 or {H}, got {mh}"
-            )
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
 
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
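
A hedged usage sketch of the public flash_attention entry point with the new mask handling. The import path and keyword names are assumptions (the full signature is not shown in this diff), as is the convention that True marks positions to be masked out.

import torch
from mps_flash_attn import flash_attention  # assumed import path

B, H, N, D = 2, 8, 256, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

# Broadcastable (B, 1, 1, N) mask: as of 0.3.2 it is validated and expanded
# internally by _validate_and_expand_mask before dispatch to the Metal kernel.
mask = torch.zeros(B, 1, 1, N, device="mps", dtype=torch.bool)
mask[..., 200:] = True  # mask out trailing key positions (semantics assumed)

out = flash_attention(q, k, v, attn_mask=mask)
print(out.shape)  # expected: torch.Size([2, 8, 256, 64])
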

@@ -925,6 +968,11 @@ def flash_attention_fp8(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,

@@ -978,6 +1026,11 @@ def flash_attention_int8(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size

@@ -1034,6 +1087,11 @@ def flash_attention_nf4(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size

@@ -1088,6 +1146,11 @@ def flash_attention_quantized(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size
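
All four quantized entry points (flash_attention_fp8, flash_attention_int8, flash_attention_nf4, flash_attention_quantized) now add the same validation immediately before calling _C.forward_quantized. The sketch below combines that helper with convert_mask, whose body (return attn_mask <= -1e3) appears earlier in this diff; it is for illustration only, the imports are assumed, and whether the quantized wrappers call convert_mask themselves is not shown here.

import torch
from mps_flash_attn import convert_mask, _validate_and_expand_mask  # assumed importable

B, H, N_q, N_kv = 2, 4, 32, 48

# Additive float mask: 0.0 keeps a position, a large negative value masks it out.
additive = torch.zeros(B, 1, 1, N_kv)
additive[..., 40:] = float("-inf")

boolean = convert_mask(additive)                       # True where value <= -1e3
expanded = _validate_and_expand_mask(boolean, B, H, N_q, N_kv)
print(expanded.shape)                                  # torch.Size([2, 1, 32, 48])
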
All other files listed above are unchanged between 0.3.0 and 0.3.2.