mps-flash-attn 0.1.13__tar.gz → 0.1.15__tar.gz
This diff compares the contents of two publicly released package versions as they appear in their public registries and is provided for informational purposes only.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/PKG-INFO +1 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/__init__.py +4 -3
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/SOURCES.txt +2 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/pyproject.toml +1 -1
- mps_flash_attn-0.1.15/tests/test_flash_attn.py +255 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/LICENSE +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/README.md +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/setup.cfg +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/setup.py +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/tests/test_attention.py +0 -0

{mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn/__init__.py (removed lines appear truncated in the registry's diff view):

```diff
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.
+__version__ = "0.1.14"
 
 import torch
 from typing import Optional
@@ -200,12 +200,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >=
+        # Only use MFA for seq_len >= 1536 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
+        # Benchmark (BF16, heads=30, head_dim=128): crossover is ~1200-1500
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >=
+            query.shape[2] >= 1536):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
```
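
In practice, the raised threshold means `patched_sdpa` keeps PyTorch's own SDPA for sequences shorter than 1536 tokens and only routes to the MFA kernels above that. The sketch below is one way to reproduce the crossover measurement on your own hardware; it is not part of the package. It assumes an Apple Silicon machine with PyTorch MPS support and mps-flash-attn installed, mirrors the shapes in the new benchmark comment (BF16, 30 heads, head_dim 128), and the `time_fn` helper is ours.

```python
# Minimal sketch (not from the package): locate the MFA vs. PyTorch-SDPA crossover.
# Run before calling mps_flash_attn.replace_sdpa(), otherwise F.scaled_dot_product_attention
# may already be patched to dispatch to MFA for long sequences.
import time
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def time_fn(fn, *args, warmup=5, runs=20):
    """Average wall-clock time per call in milliseconds, synchronizing the MPS queue."""
    for _ in range(warmup):
        fn(*args)
    torch.mps.synchronize()
    start = time.time()
    for _ in range(runs):
        fn(*args)
    torch.mps.synchronize()
    return (time.time() - start) / runs * 1000

if __name__ == "__main__":
    B, H, D = 1, 30, 128  # shapes from the benchmark note in the diff
    for N in (512, 1024, 1536, 2048, 4096):
        q, k, v = (torch.randn(B, H, N, D, device="mps", dtype=torch.bfloat16)
                   for _ in range(3))
        mfa_ms = time_fn(flash_attention, q, k, v)
        sdpa_ms = time_fn(F.scaled_dot_product_attention, q, k, v)
        print(f"N={N:5d}  MFA {mfa_ms:6.2f} ms  SDPA {sdpa_ms:6.2f} ms")
```

Timings vary with chip and PyTorch version, which is presumably why the new comment gives a crossover range (~1200-1500) rather than a single number.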

{mps_flash_attn-0.1.13 → mps_flash_attn-0.1.15}/mps_flash_attn.egg-info/SOURCES.txt:

```diff
@@ -33,4 +33,5 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
 mps_flash_attn/kernels/manifest.json
 mps_flash_attn/lib/libMFABridge.dylib
-tests/test_attention.py
+tests/test_attention.py
+tests/test_flash_attn.py
```
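
Registering tests/test_flash_attn.py in SOURCES.txt is what makes the new test module ship in the 0.1.15 sdist. A quick way to confirm that from a downloaded archive (the local filename is an assumption):

```python
# Hypothetical check: list the test files inside a locally downloaded sdist.
import tarfile

with tarfile.open("mps_flash_attn-0.1.15.tar.gz") as sdist:
    for name in sdist.getnames():
        if "/tests/" in name:
            print(name)  # expect tests/test_attention.py and tests/test_flash_attn.py
```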

mps_flash_attn-0.1.15/tests/test_flash_attn.py (new file):

```diff
@@ -0,0 +1,255 @@
+"""Test mps-flash-attn: FP32, FP16, BF16 support with benchmarks"""
+
+import torch
+import torch.nn.functional as F
+import math
+import time
+
+# Build and load the extension
+print("Loading mps_flash_attn...")
+from mps_flash_attn import flash_attention, is_available
+
+print(f"MPS available: {is_available()}")
+
+def reference_attention(q, k, v, is_causal=False, attn_mask=None):
+    """Reference implementation using PyTorch ops."""
+    scale = 1.0 / math.sqrt(q.size(-1))
+    attn = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale
+
+    if is_causal:
+        seq_len = q.size(-2)
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), diagonal=1)
+        attn = attn.masked_fill(causal_mask, float('-inf'))
+
+    if attn_mask is not None:
+        attn = attn.masked_fill(attn_mask.bool(), float('-inf'))
+
+    attn = F.softmax(attn, dim=-1)
+    out = torch.matmul(attn, v.float())
+    return out.to(q.dtype)
+
+def test_forward(dtype, name):
+    """Test forward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 8, 128, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v)
+
+    has_nan = torch.isnan(output).any().item()
+    has_inf = torch.isinf(output).any().item()
+    ok = not has_nan and not has_inf
+
+    shape_ok = output.shape == (B, H, N, D)
+
+    print(f" {name} forward: shape={output.shape}, dtype={output.dtype}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def test_backward(dtype, name):
+    """Test backward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 4, 64, 32
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype, requires_grad=True)
+
+    output = flash_attention(q, k, v)
+    loss = output.sum()
+    loss.backward()
+
+    grad_q_ok = q.grad is not None and not torch.isnan(q.grad).any()
+    grad_k_ok = k.grad is not None and not torch.isnan(k.grad).any()
+    grad_v_ok = v.grad is not None and not torch.isnan(v.grad).any()
+    ok = grad_q_ok and grad_k_ok and grad_v_ok
+
+    print(f" {name} backward: q={grad_q_ok}, k={grad_k_ok}, v={grad_v_ok}")
+    return ok
+
+def test_causal(dtype, name):
+    """Test causal attention."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 2, 8, 128, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v, is_causal=True)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    print(f" {name} causal: shape={output.shape}, ok={ok}")
+    return ok
+
+def test_gqa(dtype, name):
+    """Test Grouped Query Attention (GQA)."""
+    torch.manual_seed(42)
+
+    B, H_q, H_kv, N, D = 2, 8, 2, 128, 64 # 8 query heads, 2 KV heads
+
+    q = torch.randn(B, H_q, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H_kv, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H_kv, N, D, device='mps', dtype=dtype)
+
+    output = flash_attention(q, k, v)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    shape_ok = output.shape == (B, H_q, N, D)
+    print(f" {name} GQA ({H_q}q/{H_kv}kv): shape={output.shape}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def test_correctness(dtype, name):
+    """Test correctness against reference implementation."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 4, 64, 32
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    output_flash = flash_attention(q, k, v)
+    output_ref = reference_attention(q, k, v)
+
+    diff = (output_flash.float() - output_ref.float()).abs()
+    max_diff = diff.max().item()
+    rel_diff = (diff / (output_ref.float().abs() + 1e-6)).mean().item() * 100
+
+    # Tolerance depends on dtype
+    if dtype == torch.float32:
+        ok = rel_diff < 1.0
+    elif dtype == torch.float16:
+        ok = rel_diff < 2.0
+    else: # BF16
+        ok = rel_diff < 5.0
+
+    print(f" {name} correctness: max_diff={max_diff:.6f}, rel_diff={rel_diff:.2f}%, ok={ok}")
+    return ok
+
+def compare_fp32_bf16():
+    """Compare FP32 vs BF16 outputs."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 4, 64, 32
+
+    q_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+    k_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+    v_fp32 = torch.randn(B, H, N, D, device='mps', dtype=torch.float32)
+
+    output_fp32 = flash_attention(q_fp32, k_fp32, v_fp32)
+
+    # BF16
+    q_bf16 = q_fp32.to(torch.bfloat16)
+    k_bf16 = k_fp32.to(torch.bfloat16)
+    v_bf16 = v_fp32.to(torch.bfloat16)
+
+    output_bf16 = flash_attention(q_bf16, k_bf16, v_bf16)
+
+    diff = (output_fp32 - output_bf16.to(torch.float32)).abs()
+    max_diff = diff.max().item()
+    rel_diff = (diff / (output_fp32.abs() + 1e-6)).mean().item() * 100
+
+    # Attention has softmax which amplifies small differences, 10% tolerance is acceptable
+    ok = rel_diff < 10.0
+    print(f" FP32 vs BF16: max_diff={max_diff:.6f}, rel_diff={rel_diff:.2f}%, ok={ok}")
+    return ok
+
+def test_large_sequence():
+    """Test with large sequence length (memory efficiency test)."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 1, 8, 4096, 64 # Large sequence
+    dtype = torch.float16
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    # This should NOT OOM with flash attention
+    output = flash_attention(q, k, v)
+
+    ok = not torch.isnan(output).any() and not torch.isinf(output).any()
+    shape_ok = output.shape == (B, H, N, D)
+
+    print(f" Large seq (N={N}): shape={output.shape}, ok={ok and shape_ok}")
+    return ok and shape_ok
+
+def benchmark(dtype, name, warmup=5, runs=20):
+    """Benchmark forward pass."""
+    torch.manual_seed(42)
+
+    B, H, N, D = 4, 8, 512, 64
+
+    q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+    v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+    for _ in range(warmup):
+        _ = flash_attention(q, k, v)
+    torch.mps.synchronize()
+
+    start = time.time()
+    for _ in range(runs):
+        _ = flash_attention(q, k, v)
+    torch.mps.synchronize()
+    elapsed = time.time() - start
+
+    ms = (elapsed / runs) * 1000
+    print(f" {name}: {ms:.2f} ms")
+    return ms
+
+if __name__ == "__main__":
+    print("\n" + "=" * 50)
+    print("Testing mps-flash-attn")
+    print("=" * 50)
+
+    all_ok = True
+
+    print("\n1. Forward pass:")
+    all_ok &= test_forward(torch.float32, "FP32")
+    all_ok &= test_forward(torch.float16, "FP16")
+    all_ok &= test_forward(torch.bfloat16, "BF16")
+
+    print("\n2. Backward pass:")
+    all_ok &= test_backward(torch.float32, "FP32")
+    all_ok &= test_backward(torch.float16, "FP16")
+    all_ok &= test_backward(torch.bfloat16, "BF16")
+
+    print("\n3. Causal attention:")
+    all_ok &= test_causal(torch.float32, "FP32")
+    all_ok &= test_causal(torch.float16, "FP16")
+    all_ok &= test_causal(torch.bfloat16, "BF16")
+
+    print("\n4. GQA (Grouped Query Attention):")
+    all_ok &= test_gqa(torch.float32, "FP32")
+    all_ok &= test_gqa(torch.float16, "FP16")
+    all_ok &= test_gqa(torch.bfloat16, "BF16")
+
+    print("\n5. Correctness vs reference:")
+    all_ok &= test_correctness(torch.float32, "FP32")
+    all_ok &= test_correctness(torch.float16, "FP16")
+    all_ok &= test_correctness(torch.bfloat16, "BF16")
+
+    print("\n6. FP32 vs BF16 comparison:")
+    all_ok &= compare_fp32_bf16()
+
+    print("\n7. Large sequence (memory efficiency):")
+    all_ok &= test_large_sequence()
+
+    print("\n8. Benchmarks (B=4, H=8, N=512, D=64):")
+    benchmark(torch.float32, "FP32")
+    benchmark(torch.float16, "FP16")
+    benchmark(torch.bfloat16, "BF16")
+
+    print("\n" + "=" * 50)
+    if all_ok:
+        print("ALL TESTS PASSED!")
+    else:
+        print("SOME TESTS FAILED!")
+    print("=" * 50)
```
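
The new test file is a plain script rather than a pytest module; on an Apple Silicon machine it runs directly with `python tests/test_flash_attn.py`. For anyone who prefers pytest, the correctness check above translates roughly as follows. This is a sketch of our own, not shipped with the package; the shapes and tolerances are copied from `test_correctness()`.

```python
# Sketch (not part of the package): a pytest-style version of the correctness check.
import math
import pytest
import torch
from mps_flash_attn import flash_attention, is_available

# Mean relative difference tolerances (%), mirroring test_correctness() above.
TOL = {torch.float32: 1.0, torch.float16: 2.0, torch.bfloat16: 5.0}

@pytest.mark.skipif(not is_available(), reason="MPS / MFA not available")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_matches_reference(dtype):
    torch.manual_seed(42)
    B, H, N, D = 1, 4, 64, 32
    q, k, v = (torch.randn(B, H, N, D, device="mps", dtype=dtype) for _ in range(3))

    out = flash_attention(q, k, v).float()

    # FP32 reference: plain matmul + softmax, as in reference_attention() above.
    scale = 1.0 / math.sqrt(D)
    attn = torch.softmax(q.float() @ k.float().transpose(-2, -1) * scale, dim=-1)
    ref = attn @ v.float()

    rel_diff = ((out - ref).abs() / (ref.abs() + 1e-6)).mean().item() * 100
    assert rel_diff < TOL[dtype]
```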