mps-flash-attn 0.3.1__cp314-cp314-macosx_15_0_arm64.whl → 0.3.2__cp314-cp314-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


mps_flash_attn/_C.cpython-314-darwin.so: binary file changed

mps_flash_attn/__init__.py:
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.3.1"
+__version__ = "0.3.2"
 
 __all__ = [
     # Core functions
@@ -97,6 +97,56 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     return attn_mask <= -1e3
 
 
+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
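For reference, a minimal sketch of the broadcast expansion this helper performs. The sizes and the key-padding interpretation below are illustrative assumptions, not taken from the package; torch.Tensor.expand only creates a view, so no data is copied:

    import torch

    # Hypothetical sizes for illustration only.
    B, H, N_q, N_kv = 2, 8, 128, 128

    # A key-padding style mask supplied in broadcastable form (B, 1, 1, N_kv).
    mask = torch.zeros(B, 1, 1, N_kv, dtype=torch.bool)

    # The two expand() branches in _validate_and_expand_mask turn it into a
    # full-rank (B, 1, N_q, N_kv) view for the Metal kernel.
    expanded = mask.expand(B, 1, N_q, N_kv)
    assert expanded.shape == (B, 1, N_q, N_kv)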
@@ -247,32 +297,19 @@ def flash_attention(
     query = _ensure_contiguous(query, "query")
     key = _ensure_contiguous(key, "key")
     value = _ensure_contiguous(value, "value")
-    if attn_mask is not None:
-        attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
-        # Validate mask shape
-        B, H, N_q, D = query.shape
-        N_kv = key.shape[2]
-        if attn_mask.dim() != 4:
-            raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
-        mb, mh, mq, mk = attn_mask.shape
-        # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
-        if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
-            raise ValueError(
-                f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
-            )
-        # Expand broadcast mask to full shape for Metal kernel
-        if mq == 1 and N_q > 1:
-            attn_mask = attn_mask.expand(mb, mh, N_q, mk)
-        if mk == 1 and N_kv > 1:
-            attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
-        if mb != 1 and mb != B:
-            raise ValueError(
-                f"attn_mask batch size must be 1 or {B}, got {mb}"
-            )
-        if mh != 1 and mh != H:
-            raise ValueError(
-                f"attn_mask head count must be 1 or {H}, got {mh}"
-            )
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
 
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
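With this refactor, flash_attention routes all mask handling through _validate_and_expand_mask. A hedged usage sketch follows; the sizes, MPS device, float16 dtype, keyword-argument call style, and the boolean True-means-masked convention are assumptions for illustration and are not documented in this diff:

    import torch
    from mps_flash_attn import flash_attention

    # Assumed sizes; query/key/value use the (B, H, N, D) layout checked above.
    B, H, N, D = 2, 8, 128, 64
    q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

    # Per-batch padding mask, broadcast over heads and query positions;
    # _validate_and_expand_mask expands it to (B, 1, N, N) before the kernel runs.
    # Whether True marks a position as masked out is an assumption here.
    attn_mask = torch.zeros(B, 1, 1, N, device="mps", dtype=torch.bool)
    attn_mask[:, :, :, N // 2:] = True

    out = flash_attention(q, k, v, attn_mask=attn_mask)  # expected shape (B, H, N, D)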
@@ -931,6 +968,11 @@ def flash_attention_fp8(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
@@ -984,6 +1026,11 @@ def flash_attention_int8(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
@@ -1040,6 +1087,11 @@ def flash_attention_nf4(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size
@@ -1094,6 +1146,11 @@ def flash_attention_quantized(
     scale_factor = scale / default_scale
     query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size

mps_flash_attn-0.3.1.dist-info/METADATA → mps_flash_attn-0.3.2.dist-info/METADATA:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.3.1
+Version: 0.3.2
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT

mps_flash_attn-0.3.1.dist-info/RECORD → mps_flash_attn-0.3.2.dist-info/RECORD:
@@ -1,5 +1,5 @@
-mps_flash_attn/_C.cpython-314-darwin.so,sha256=V9bjj53KRFmbMSslzTf7YV8N2l9NPa9_Ia2dORgRjqA,313448
-mps_flash_attn/__init__.py,sha256=Esm5wd3As4es3ne1GjUtlQGfBtj0LB05UuND-SaIRXo,47730
+mps_flash_attn/_C.cpython-314-darwin.so,sha256=GtWa4KIcynqjbCQYw-uTBpkX5NTcxyKuK0APoGnRlQM,313448
+mps_flash_attn/__init__.py,sha256=u6B_WenZOTk1WMe9u4-PbyPy7pt8NValR5i8Oz6bI-U,49252
 mps_flash_attn/benchmark.py,sha256=qHhvb8Dmh07OEa_iXuPuJSEnRJlrjVF5nKzVwbWypWE,24141
 mps_flash_attn/csrc/mps_flash_attn.mm,sha256=mR4S8SHLtRiksrmoFH6s2118q662SMNlFU8HmxAE3YY,51204
 mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
@@ -26,8 +26,8 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
 mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
 mps_flash_attn/lib/libMFABridge.dylib,sha256=iKgfYISSKMSNt_iXnljjUr_hZZHyCAg2tdS3_ZjmLkc,605696
-mps_flash_attn-0.3.1.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
-mps_flash_attn-0.3.1.dist-info/METADATA,sha256=hp_w8UG_IpMF6BfS7STV69sM0Ss01-n6nWz9s1S2JzM,5834
-mps_flash_attn-0.3.1.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
-mps_flash_attn-0.3.1.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
-mps_flash_attn-0.3.1.dist-info/RECORD,,
+mps_flash_attn-0.3.2.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
+mps_flash_attn-0.3.2.dist-info/METADATA,sha256=vhcu8d8NdzmuQbOqVUpzacXJF__Eu-BW1C7Em_CNoyg,5834
+mps_flash_attn-0.3.2.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
+mps_flash_attn-0.3.2.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
+mps_flash_attn-0.3.2.dist-info/RECORD,,