mps-flash-attn 0.1.13.tar.gz → 0.1.15.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn has been flagged as potentially problematic.

Files changed (39)
  1. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/__init__.py +4 -3
  3. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  4. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/SOURCES.txt +2 -1
  5. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/pyproject.toml +1 -1
  6. mps_flash_attn-0.1.15/tests/test_flash_attn.py +255 -0
  7. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/LICENSE +0 -0
  8. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/README.md +0 -0
  9. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
  10. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  11. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  12. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  13. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  14. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  15. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  16. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  17. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  18. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  19. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  20. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  21. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  22. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  23. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  24. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  25. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  26. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  27. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  28. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  29. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  30. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  31. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  32. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/manifest.json +0 -0
  33. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  34. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  35. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/requires.txt +0 -0
  36. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/top_level.txt +0 -0
  37. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/setup.cfg +0 -0
  38. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/setup.py +0 -0
  39. {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/tests/test_attention.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.13
+Version: 0.1.15
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.13"
+__version__ = "0.1.14"
 
 import torch
 from typing import Optional
@@ -200,12 +200,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # Only use MFA for seq_len >= 1536 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
+        # Benchmark (BF16, heads=30, head_dim=128): crossover is ~1200-1500
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >= 1024):
+            query.shape[2] >= 1536):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
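For readers skimming the diff: the hunk above only moves the sequence-length cutoff at which the patched SDPA hands work to the Metal Flash Attention kernel (1024 → 1536). A minimal sketch of that dispatch pattern follows; it is illustrative only, and the `mfa_attention` parameter is a hypothetical stand-in for the package's internal kernel entry point, which in the real code lives inside `replace_sdpa()` as shown in the hunk.

import torch
import torch.nn.functional as F

# Crossover point vs. PyTorch's math backend, per the comment in the hunk above.
_MFA_MIN_SEQ_LEN = 1536

def dispatch_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, mfa_attention=None):
    """Illustrative router: use the flash kernel only where it is expected to win."""
    # query is laid out as (batch, heads, seq_len, head_dim), so shape[2] is the sequence length.
    if (mfa_attention is not None
            and query.device.type == 'mps'
            and dropout_p == 0.0
            and query.shape[2] >= _MFA_MIN_SEQ_LEN):
        try:
            return mfa_attention(query, key, value, attn_mask=attn_mask, is_causal=is_causal)
        except Exception:
            pass  # fall back to the built-in implementation on any kernel failure
    # Short sequences (or unsupported cases) stay on PyTorch's own SDPA.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask,
                                          dropout_p=dropout_p, is_causal=is_causal)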
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.13
+Version: 0.1.15
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
mps_flash_attn.egg-info/SOURCES.txt
@@ -33,4 +33,5 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
 mps_flash_attn/kernels/manifest.json
 mps_flash_attn/lib/libMFABridge.dylib
-tests/test_attention.py
+tests/test_attention.py
+tests/test_flash_attn.py
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mps-flash-attn"
-version = "0.1.13"
+version = "0.1.15"
 description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
 readme = "README.md"
 license = "MIT"
tests/test_flash_attn.py (new file)
@@ -0,0 +1,255 @@
+"""Test mps-flash-attn: FP32, FP16, BF16 support with benchmarks"""
+
+import torch
+import torch.nn.functional as F
+import math
+import time
+
+# Build and load the extension
+print("Loading mps_flash_attn...")
+from mps_flash_attn import flash_attention, is_available
+
+print(f"MPS available: {is_available()}")
+
+def reference_attention(q, k, v, is_causal=False, attn_mask=None):
+    """Reference implementation using PyTorch ops."""
+    scale = 1.0 / math.sqrt(q.size(-1))
+    attn = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale
+
+    if is_causal:
+        seq_len = q.size(-2)
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), diagonal=1)
+        attn = attn.masked_fill(causal_mask, float('-inf'))
+
+    if attn_mask is not None:
+        attn = attn.masked_fill(attn_mask.bool(), float('-inf'))
+
+    attn = F.softmax(attn, dim=-1)
+    out = torch.matmul(attn, v.float())
+    return out.to(q.dtype)
+
+def test_forward(dtype, name):
+    """Test forward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 8, 128, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v)
+
+    has_nan = torch.isnan(output).any().item()
+    has_inf = torch.isinf(output).any().item()
+    ok = not has_nan and not has_inf
+
+    shape_ok = output.shape == (B, H, N, D)
+
+    print(f" {name} forward: shape={output.shape}, dtype={output.dtype}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def test_backward(dtype, name):
+    """Test backward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 4, 64, 32
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+
+    output = flash_attention(q, k, v)
+    loss = output.sum()
+    loss.backward()
+
+    grad_q_ok = q.grad is not None and not torch.isnan(q.grad).any()
+    grad_k_ok = k.grad is not None and not torch.isnan(k.grad).any()
+    grad_v_ok = v.grad is not None and not torch.isnan(v.grad).any()
+    ok = grad_q_ok and grad_k_ok and grad_v_ok
+
+    print(f" {name} backward: q={grad_q_ok}, k={grad_k_ok}, v={grad_v_ok}")
+    return ok
+
+def test_causal(dtype, name):
+    """Test causal attention."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 8, 128, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v, is_causal=True)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    print(f" {name} causal: shape={output.shape}, ok={ok}")
+    return ok
+
+def test_gqa(dtype, name):
+    """Test Grouped Query Attention (GQA)."""
+    torch.manual_seed(42)
+
+    B, H_q, H_kv, N, D = 2, 8, 2, 128, 64  # 8 query heads, 2 KV heads
+
+    q = torch.randn(B, H_q, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H_kv, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H_kv, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    shape_ok = output.shape == (B, H_q, N, D)
+    print(f" {name} GQA ({H_q}q/{H_kv}kv): shape={output.shape}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def test_correctness(dtype, name):
+    """Test correctness against reference implementation."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 4, 64, 32
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output_flash = flash_attention(q, k, v)
+    output_ref = reference_attention(q, k, v)
+
+    diff = (output_flash.float() - output_ref.float()).abs()
+    max_diff = diff.max().item()
+    rel_diff = (diff / (output_ref.float().abs() + 1e-6)).mean().item() * 100
+
+    # Tolerance depends on dtype
+    if dtype == torch.float32:
+        ok = rel_diff < 1.0
+    elif dtype == torch.float16:
+        ok = rel_diff < 2.0
+    else:  # BF16
+        ok = rel_diff < 5.0
+
+    print(f" {name} correctness: max_diff={max_diff:.6f}, rel_diff={rel_diff:.2f}%, ok={ok}")
+    return ok
+
+def compare_fp32_bf16():
+    """Compare FP32 vs BF16 outputs."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 4, 64, 32
+
+    q_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+    k_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+    v_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+
+    output_fp32 = flash_attention(q_fp32, k_fp32, v_fp32)
+
+    # BF16
+    q_bf16 = q_fp32.to(torch.bfloat16)
+    k_bf16 = k_fp32.to(torch.bfloat16)
+    v_bf16 = v_fp32.to(torch.bfloat16)
+
+    output_bf16 = flash_attention(q_bf16, k_bf16, v_bf16)
+
+    diff = (output_fp32 - output_bf16.to(torch.float32)).abs()
+    max_diff = diff.max().item()
+    rel_diff = (diff / (output_fp32.abs() + 1e-6)).mean().item() * 100
+
+    # Attention has softmax which amplifies small differences, 10% tolerance is acceptable
+    ok = rel_diff < 10.0
+    print(f" FP32 vs BF16: max_diff={max_diff:.6f}, rel_diff={rel_diff:.2f}%, ok={ok}")
+    return ok
+
+def test_large_sequence():
+    """Test with large sequence length (memory efficiency test)."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 8, 4096, 64  # Large sequence
+    dtype = torch.float16
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    # This should NOT OOM with flash attention
+    output = flash_attention(q, k, v)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    shape_ok = output.shape == (B, H, N, D)
+
+    print(f" Large seq (N={N}): shape={output.shape}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def benchmark(dtype, name, warmup=5, runs=20):
+    """Benchmark forward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 4, 8, 512, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    for _ in range(warmup):
+        _ = flash_attention(q, k, v)
+    torch.mps.synchronize()
+
+    start = time.time()
+    for _ in range(runs):
+        _ = flash_attention(q, k, v)
+    torch.mps.synchronize()
+    elapsed = time.time() - start
+
+    ms = (elapsed / runs) * 1000
+    print(f" {name}: {ms:.2f} ms")
+    return ms
+
+if __name__ == "__main__":
+    print("\n" + "=" * 50)
+    print("Testing mps-flash-attn")
+    print("=" * 50)
+
+    all_ok = True
+
+    print("\n1. Forward pass:")
+    all_ok &= test_forward(torch.float32, "FP32")
+    all_ok &= test_forward(torch.float16, "FP16")
+    all_ok &= test_forward(torch.bfloat16, "BF16")
+
+    print("\n2. Backward pass:")
+    all_ok &= test_backward(torch.float32, "FP32")
+    all_ok &= test_backward(torch.float16, "FP16")
+    all_ok &= test_backward(torch.bfloat16, "BF16")
+
+    print("\n3. Causal attention:")
+    all_ok &= test_causal(torch.float32, "FP32")
+    all_ok &= test_causal(torch.float16, "FP16")
+    all_ok &= test_causal(torch.bfloat16, "BF16")
+
+    print("\n4. GQA (Grouped Query Attention):")
+    all_ok &= test_gqa(torch.float32, "FP32")
+    all_ok &= test_gqa(torch.float16, "FP16")
+    all_ok &= test_gqa(torch.bfloat16, "BF16")
+
+    print("\n5. Correctness vs reference:")
+    all_ok &= test_correctness(torch.float32, "FP32")
+    all_ok &= test_correctness(torch.float16, "FP16")
+    all_ok &= test_correctness(torch.bfloat16, "BF16")
+
+    print("\n6. FP32 vs BF16 comparison:")
+    all_ok &= compare_fp32_bf16()
+
+    print("\n7. Large sequence (memory efficiency):")
+    all_ok &= test_large_sequence()
+
+    print("\n8. Benchmarks (B=4, H=8, N=512, D=64):")
+    benchmark(torch.float32, "FP32")
+    benchmark(torch.float16, "FP16")
+    benchmark(torch.bfloat16, "BF16")
+
+    print("\n" + "=" * 50)
+    if all_ok:
+        print("ALL TESTS PASSED!")
+    else:
+        print("SOME TESTS FAILED!")
+    print("=" * 50)