EvoScientist 0.1.0rc1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
- EvoScientist/EvoScientist.py +1 -1
- EvoScientist/cli.py +450 -178
- EvoScientist/middleware.py +5 -1
- EvoScientist/skills/accelerate/SKILL.md +332 -0
- EvoScientist/skills/accelerate/references/custom-plugins.md +453 -0
- EvoScientist/skills/accelerate/references/megatron-integration.md +489 -0
- EvoScientist/skills/accelerate/references/performance.md +525 -0
- EvoScientist/skills/bitsandbytes/SKILL.md +411 -0
- EvoScientist/skills/bitsandbytes/references/memory-optimization.md +521 -0
- EvoScientist/skills/bitsandbytes/references/qlora-training.md +521 -0
- EvoScientist/skills/bitsandbytes/references/quantization-formats.md +447 -0
- EvoScientist/skills/clip/SKILL.md +253 -0
- EvoScientist/skills/clip/references/applications.md +207 -0
- EvoScientist/skills/find-skills/SKILL.md +133 -0
- EvoScientist/skills/find-skills/scripts/install_skill.py +211 -0
- EvoScientist/skills/flash-attention/SKILL.md +367 -0
- EvoScientist/skills/flash-attention/references/benchmarks.md +215 -0
- EvoScientist/skills/flash-attention/references/transformers-integration.md +293 -0
- EvoScientist/skills/langgraph-docs/SKILL.md +36 -0
- EvoScientist/skills/llama-cpp/SKILL.md +258 -0
- EvoScientist/skills/llama-cpp/references/optimization.md +89 -0
- EvoScientist/skills/llama-cpp/references/quantization.md +213 -0
- EvoScientist/skills/llama-cpp/references/server.md +125 -0
- EvoScientist/skills/lm-evaluation-harness/SKILL.md +490 -0
- EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
- EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
- EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
- EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
- EvoScientist/skills/ml-paper-writing/SKILL.md +937 -0
- EvoScientist/skills/ml-paper-writing/references/checklists.md +361 -0
- EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +562 -0
- EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +367 -0
- EvoScientist/skills/ml-paper-writing/references/sources.md +159 -0
- EvoScientist/skills/ml-paper-writing/references/writing-guide.md +476 -0
- EvoScientist/skills/ml-paper-writing/templates/README.md +251 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +534 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +144 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +952 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +111 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +1493 -0
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +315 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/README.md +50 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +312 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +377 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +101 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +1940 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +26 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +70 -0
- EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +326 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +3 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +11 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +1440 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +218 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +305 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +485 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +508 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +1246 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +485 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +24 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +1440 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +246 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +414 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +508 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +1246 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +79 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +201 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +75 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +662 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +864 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +1443 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +767 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +36 -0
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +53 -0
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +38 -0
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +382 -0
- EvoScientist/skills/peft/SKILL.md +431 -0
- EvoScientist/skills/peft/references/advanced-usage.md +514 -0
- EvoScientist/skills/peft/references/troubleshooting.md +480 -0
- EvoScientist/skills/ray-data/SKILL.md +326 -0
- EvoScientist/skills/ray-data/references/integration.md +82 -0
- EvoScientist/skills/ray-data/references/transformations.md +83 -0
- EvoScientist/skills/skill-creator/LICENSE.txt +202 -0
- EvoScientist/skills/skill-creator/SKILL.md +356 -0
- EvoScientist/skills/skill-creator/references/output-patterns.md +82 -0
- EvoScientist/skills/skill-creator/references/workflows.md +28 -0
- EvoScientist/skills/skill-creator/scripts/init_skill.py +303 -0
- EvoScientist/skills/skill-creator/scripts/package_skill.py +110 -0
- EvoScientist/skills/skill-creator/scripts/quick_validate.py +95 -0
- EvoScientist/skills/tensorboard/SKILL.md +629 -0
- EvoScientist/skills/tensorboard/references/integrations.md +638 -0
- EvoScientist/skills/tensorboard/references/profiling.md +545 -0
- EvoScientist/skills/tensorboard/references/visualization.md +620 -0
- EvoScientist/skills/vllm/SKILL.md +364 -0
- EvoScientist/skills/vllm/references/optimization.md +226 -0
- EvoScientist/skills/vllm/references/quantization.md +284 -0
- EvoScientist/skills/vllm/references/server-deployment.md +255 -0
- EvoScientist/skills/vllm/references/troubleshooting.md +447 -0
- {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/METADATA +26 -3
- evoscientist-0.1.0rc2.dist-info/RECORD +119 -0
- evoscientist-0.1.0rc1.dist-info/RECORD +0 -21
- {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/WHEEL +0 -0
- {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/entry_points.txt +0 -0
- {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/licenses/LICENSE +0 -0
- {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/top_level.txt +0 -0

@@ -0,0 +1,367 @@
---
name: flash-attention
description: Optimizes transformer attention with Flash Attention, typically giving a 2-4x speedup and 10-20x lower attention memory on long sequences. Use when training or running transformers with long sequences (>512 tokens), when attention causes GPU memory issues, or when you need faster inference. Supports PyTorch native SDPA, the flash-attn library, H100 FP8 (FlashAttention-3), and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
dependencies: [flash-attn, torch, transformers]
---
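
# Flash Attention - Fast Memory-Efficient Attention

## Quick start

Flash Attention computes exact attention (not an approximation) while typically delivering a 2-4x speedup and a 10-20x reduction in attention memory on long sequences. It does this with IO-aware tiling, which keeps blocks of Q, K, and V in fast on-chip SRAM instead of writing the full attention matrix to GPU memory, and with recomputation, which rebuilds attention during the backward pass instead of storing it.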
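
The key trick behind "IO-aware tiling" is an online softmax. Keys and values are processed in blocks, and the only state carried between blocks is a running maximum, a running normalizer, and an unnormalized output. The pure-Python sketch below handles one query row on the CPU, and its function names are illustrative. It shows that the block-by-block computation reproduces ordinary softmax attention up to floating-point rounding. The real kernels apply the same idea to tiles held in on-chip SRAM, in parallel across rows, heads, and batch entries.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: softmax(q . k_j / sqrt(d)) weighted sum of values, for one query."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def tiled_attention_row(q, keys, values, block_size=2):
    """Same result, visiting K/V one block at a time with an online softmax."""
    d = len(q)
    running_max = float("-inf")   # max score seen so far
    running_sum = 0.0             # sum of exp(score - running_max) so far
    acc = [0.0] * len(values[0])  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results so they share the new maximum
        correction = math.exp(running_max - new_max)  # 0.0 on the first block
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]  # normalize once at the end


q = [0.1, -0.3, 0.5, 0.2]
keys = [[0.2, 0.1, -0.4, 0.3], [0.5, -0.2, 0.1, 0.0], [-0.3, 0.4, 0.2, 0.1],
        [0.0, 0.3, -0.1, 0.6], [0.7, -0.5, 0.3, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]

ref = naive_attention_row(q, keys, values)
tiled = tiled_attention_row(q, keys, values, block_size=2)
print(max(abs(a - b) for a, b in zip(ref, tiled)))  # only floating-point rounding
```

Because the full row of scores is never stored, the extra memory per query scales with the block size rather than with the sequence length.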

**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Dispatches to the Flash Attention kernel when the GPU, dtype, and inputs allow it;
# otherwise SDPA falls back to another backend
out = F.scaled_dot_product_attention(q, k, v)
```

**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```

```python
import torch
from flash_attn import flash_attn_func

# flash-attn expects [batch, seqlen, nheads, headdim] in fp16/bf16 on a CUDA device
q, k, v = [torch.randn(2, 512, 8, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

## Common workflows

### Workflow 1: Enable in an existing PyTorch model

Copy this checklist:

```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```

**Step 1: Check PyTorch version**

```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```

If <2.2, upgrade:
```bash
pip install --upgrade torch
```

**Step 2: Enable Flash Attention backend**

Replace standard attention:
```python
import math
import torch
import torch.nn.functional as F

# Before (standard attention): materializes the full [seq, seq] attention matrix
d_k = q.size(-1)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (fused SDPA kernel, same result). For causal models pass is_causal=True
# instead of an explicit mask: the Flash backend does not accept an arbitrary
# attn_mask, so passing one forces a fallback to a slower backend
out = F.scaled_dot_product_attention(q, k, v)
```

Force Flash Attention backend:
```python
# Only the Flash kernel is allowed; SDPA raises a RuntimeError if it cannot be used.
# Newer PyTorch releases deprecate this context manager in favor of
# torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```

**Step 3: Verify speedup with profiling**

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# Allocate inputs once so the timer measures attention, not tensor creation
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

def flash_attention(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        return F.scaled_dot_product_attention(q, k, v)

def standard_attention(q, k, v):
    attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)  # 8.0 = sqrt(head_dim=64)
    return attn @ v

# torch.utils.benchmark handles CUDA synchronization for GPU timings
t_flash = benchmark.Timer(stmt='flash_attention(q, k, v)', globals=globals())
t_standard = benchmark.Timer(stmt='standard_attention(q, k, v)', globals=globals())

print(f"Flash:    {t_flash.timeit(100).mean * 1e3:.3f} ms")
print(f"Standard: {t_standard.timeit(100).mean * 1e3:.3f} ms")
```

Expected: roughly 2-4x speedup for sequences longer than 512 tokens. The exact gain depends on the GPU, head dimension, and batch size.

**Step 4: Test accuracy matches baseline**

```python
import torch
import torch.nn.functional as F

# Compare outputs on identical inputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Fused SDPA (Flash backend when eligible)
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention (8.0 = sqrt(head_dim=64))
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max().item()
print(f"Max difference: {diff:.6f}")
# Expect a small value (around 1e-3 or below) for float16 inputs
```

### Workflow 2: Use flash-attn library for advanced features

Use the flash-attn library directly when you need multi-query/grouped-query attention or sliding window attention. FP8 attention on H100 comes from FlashAttention-3, which is installed separately (see Workflow 3).

Copy this checklist:

```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```

**Step 1: Install flash-attn library**

```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```

**Step 2: Modify attention code**

```python
import torch
from flash_attn import flash_attn_func

# Example inputs in the common [batch, heads, seq, dim] layout (fp16 on CUDA)
q, k, v = [torch.randn(2, 16, 1024, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# flash-attn expects [batch_size, seq_len, num_heads, head_dim]
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,         # Use 0.0 for evaluation/inference
    causal=True,           # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None     # Defaults to 1/sqrt(head_dim)
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```

**Step 3: Enable advanced features**

Multi-query / grouped-query attention (K/V heads shared across query heads):
```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 32, 64, device='cuda', dtype=torch.float16)  # [batch, seq, num_q_heads, dim]
k = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)   # [batch, seq, num_kv_heads, dim]
v = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)   # Fewer KV heads
# MQA (1 KV head) and GQA are handled natively; num_q_heads must be divisible by num_kv_heads
out = flash_attn_func(q, k, v)
```
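
The sharing rule can be stated precisely. The flash-attn README describes the grouping as consecutive: with 6 query heads and 2 K/V heads, query heads 0-2 use K/V head 0 and query heads 3-5 use K/V head 1. Here is a small pure-Python sketch of that index mapping. The helper name is hypothetical.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Index of the K/V head that query head `q_head` reads under MQA/GQA."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per shared K/V head
    return q_head // group_size


# 32 query heads sharing 8 K/V heads (GQA): heads 0-3 -> 0, heads 4-7 -> 1, ...
mapping = [kv_head_for_query_head(h, 32, 8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

You get the same mapping by expanding K and V with `repeat_interleave(group_size, dim=2)` before calling a kernel without GQA support, but that approach materializes the repeated tensors.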

Sliding window attention (local attention):
```python
# Each query attends only to itself and the previous 256 keys.
# window_size is (left, right); with causal=True the right side is effectively 0.
out = flash_attn_func(
    q, k, v,
    window_size=(256, 0),  # (left, right) window
    causal=True
)
```
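
You can check the `window_size` semantics without a GPU. According to the flash-attn docstring, when query and key lengths are equal, query `i` attends to keys in `[i - window_size[0], i + window_size[1]]` inclusive. A value of `-1` means no limit on that side, and `causal=True` also removes future keys. The sketch below lists the visible key positions under that assumption. The helper is hypothetical and ignores the offset flash-attn applies when `seqlen_q != seqlen_k`.

```python
def allowed_keys(i, seqlen, window_size=(-1, -1), causal=False):
    """Key positions visible to query i (assumes seqlen_q == seqlen_k).

    window_size=(left, right); -1 means unbounded on that side.
    """
    left, right = window_size
    lo = 0 if left < 0 else max(0, i - left)
    hi = seqlen - 1 if right < 0 else min(seqlen - 1, i + right)
    if causal:
        hi = min(hi, i)  # never look at future keys
    return list(range(lo, hi + 1))


print(allowed_keys(5, 10, window_size=(2, 0), causal=True))  # [3, 4, 5]
```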

**Step 4: Benchmark performance**

```python
import time
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup (kernel setup, allocator caching)
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark: synchronize before and after, because CUDA kernels launch asynchronously
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.perf_counter()

print(f"Time per iteration: {(end - start) / 100 * 1000:.2f}ms")
print(f"Peak memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
```
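
To compare a measured time with the TFLOPS figures in the benchmark reference, you can convert it with a common FLOP-count convention for attention. The convention counts two matrix multiplications (Q·Kᵀ and P·V) of 2·N²·d FLOPs each per head, and halves the total when causal masking skips the upper triangle. The function name in this sketch is illustrative.

```python
def attention_tflops(batch, seqlen, nheads, headdim, time_ms, causal=False):
    """Achieved forward-pass TFLOPS for one attention call.

    Counts Q @ K^T and P @ V, each 2 * seqlen^2 * headdim FLOPs per head;
    a causal mask skips about half of that work.
    """
    flops = 4 * batch * seqlen**2 * nheads * headdim
    if causal:
        flops /= 2
    return flops / (time_ms * 1e-3) / 1e12


# Shapes from the benchmark above: batch=4, seq=4096, heads=32, dim=64
print(f"{attention_tflops(4, 4096, 32, 64, time_ms=1.0):.1f} TFLOPS at 1 ms/iter")  # 549.8
```

Plug in the per-iteration time printed by the benchmark above. The default `flash_attn_func` call there is non-causal, so keep `causal=False`.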

### Workflow 3: H100 FP8 optimization (FlashAttention-3)

For maximum performance on H100/H800 (Hopper) GPUs. FlashAttention-3 is a separate beta build, and its FP8 kernels cover the forward pass, so this path mainly targets inference.

```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Build FlashAttention-3 (hopper/ directory)
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```

**Step 1: Verify H100 GPU**

```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```

**Step 2: Install FlashAttention-3 (FP8 support)**

```bash
# FA3 (beta, Hopper only, CUDA 12.3+) is built from hopper/, not shipped in the flash-attn wheel
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

**Step 3: Convert inputs to FP8**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Plain cast to float8_e4m3fn with no scaling: values must fit FP8's narrow range (max 448)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
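
A plain cast only works when the values already fit FP8's range. A common remedy is per-tensor scaling: divide by `amax / 448` before casting, and keep the scale so you can undo it afterwards. Here 448 is the largest finite `float8_e4m3fn` value, as reported by `torch.finfo(torch.float8_e4m3fn).max`. Recent FlashAttention-3 builds accept descale factors for this purpose, but check `flash_attn_interface.py` in your build for the exact arguments. Below is a pure-Python sketch of the scale computation. The helper name is hypothetical.

```python
E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn


def per_tensor_fp8_scale(values):
    """Scale s such that values / s fits the e4m3fn range; multiply by s to undo."""
    amax = max(abs(x) for x in values)
    return amax / E4M3_MAX if amax > 0 else 1.0


x = [0.5, -896.0, 12.0]
s = per_tensor_fp8_scale(x)
scaled = [xi / s for xi in x]          # these values are then safe to cast to FP8
print(s, max(abs(v) for v in scaled))  # 2.0 448.0
```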

**Step 4: Run with FP8 attention**

```python
# FlashAttention-3 is imported from flash_attn_interface, not flash_attn
from flash_attn_interface import flash_attn_func

# FP8 (e4m3) inputs use FA3's FP8 forward kernels on Hopper GPUs.
# Depending on the build, the call may return a tuple (out, softmax_lse).
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Reported peak: close to 1.2 PFLOPS (vs. up to ~740 TFLOPS for FA3 in FP16)
```

## When to use vs alternatives

**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory is constrained (OOM with standard attention)
- You need a 2-4x speedup without accuracy loss (the computation is exact)
- You are using PyTorch 2.2+ or can install flash-attn

**Use alternatives instead:**
- **Standard attention**: very short sequences (<256 tokens), where the gain is marginal
- **xFormers**: you need more attention variants (not just speed)
- **PyTorch SDPA on CPU**: CPU inference (the Flash Attention kernels are GPU-only, but `F.scaled_dot_product_attention` also runs on CPU)

## Common issues

**Issue: ImportError: cannot import flash_attn**

Install with the `--no-build-isolation` flag, because the build needs your already-installed PyTorch:
```bash
pip install flash-attn --no-build-isolation
```

Or install the CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

**Issue: Slower than expected (no speedup)**

Flash Attention's benefit grows with sequence length:
- <512 tokens: minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup

Check that sequences are long enough and that inputs are float16/bfloat16. Also confirm that SDPA is not silently falling back to another backend, which happens, for example, when you pass an explicit `attn_mask`.

**Issue: RuntimeError: CUDA error**

Verify the GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# FlashAttention-2 kernels (flash-attn 2.x and PyTorch's SDPA flash backend) need >= (8, 0), i.e. Ampere or newer
```

Flash Attention requires:
- Ampere (A100, A10), Ada, and Hopper (H100): ✅ Full support
- Turing (T4): ⚠️ flash-attn 1.x only (2.x does not support Turing)
- Volta (V100): ❌ Not supported
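
You can fold these requirements into a quick pre-flight check. The helper below is hypothetical and simplified. It follows the flash-attn 2.x README: Ampere or newer, float16/bfloat16 only, and head dimensions up to 256. It is not exhaustive; for example, older releases restricted the backward pass for large head dimensions on consumer GPUs.

```python
def flash_attn2_supported(capability, dtype, head_dim):
    """Rough check whether flash-attn 2.x kernels can run (hypothetical helper).

    capability: (major, minor) tuple, e.g. from torch.cuda.get_device_capability()
    dtype: dtype name such as "float16" or "bfloat16"
    """
    if tuple(capability) < (8, 0):  # Ampere (sm80) or newer required
        return False
    if dtype not in ("float16", "bfloat16"):  # no float32 kernels
        return False
    return head_dim <= 256  # head dimensions up to 256


print(flash_attn2_supported((8, 0), "float16", 64))   # A100 -> True
print(flash_attn2_supported((7, 5), "float16", 64))   # T4 -> False
print(flash_attn2_supported((7, 0), "bfloat16", 64))  # V100 -> False
```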
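
**Issue: Accuracy degradation**

Check that the dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

The flash-attn kernels only support float16/bfloat16 and reject float32 inputs. PyTorch SDPA does accept float32, but then uses a non-Flash backend. Small differences from a float32 baseline come from half-precision rounding, not from any approximation in the algorithm. bfloat16 has a wider range than float16, which helps avoid overflow.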

## Advanced topics

**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, and Llama models.

**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.

**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.

**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.

## Hardware requirements

- **GPU**: NVIDIA Ampere or newer (A100, A10, A30, H100) or AMD MI200+
- **VRAM**: No extra requirement; Flash Attention uses less memory than standard attention
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support

**Not supported**: V100 (Volta), CPU inference

## Resources

- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog (FlashAttention-3): https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -0,0 +1,215 @@
# Performance Benchmarks

## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
- Real-world model benchmarks
- Recommendations by use case

## Speed comparisons across GPUs

### A100 80GB (Ampere)

**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |

### H100 80GB (Hopper)

**Forward pass time** (milliseconds, same config):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup (FP8 vs Standard) |
|------------|----------|--------------|---------------------|--------------------|--------------------------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |

**Key insight**: On H100, FlashAttention-3 reaches up to ~740 TFLOPS in FP16 (about 75% of the GPU's theoretical peak) and close to 1.2 PFLOPS with FP8.

### A10G 24GB (Ampere)

**Forward pass time** (milliseconds, batch=4):

| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |

## Memory usage analysis

### GPU memory consumption (batch=8, heads=32, dim=64, float16)

**Standard attention memory**:

| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 128 MB | 32 MB | 160 MB | Manageable |
| 2048 | 2 GB | 128 MB | 2.1 GB | Growing |
| 8192 | 32 GB | 512 MB | 32.5 GB | OOM on 24GB GPUs |
| 32768 | 512 GB | 2 GB | 514 GB | OOM on any single GPU |

**Flash Attention 2 memory**:

| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 80% |
| 2048 | 0 MB | 128 MB | 128 MB | 94% |
| 8192 | 0 MB | 512 MB | 512 MB | 98.5% |
| 32768 | 0 MB | 2 GB | 2 GB | 99.6% |

**Key insight**: Flash Attention never materializes the N × N attention matrix, which saves O(N²) memory.
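
You can estimate the two memory terms for your own configuration. This sketch assumes 2-byte elements (float16/bfloat16). It counts only the N × N score matrix and the K and V tensors of a single attention call, and ignores Q, the output, softmax statistics, and activations saved for the backward pass. The function name is illustrative.

```python
def attention_memory_mib(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    """Rough memory (MiB) of one attention call's score matrix and K/V tensors."""
    mib = 2**20
    score_matrix = batch * heads * seq_len * seq_len * bytes_per_elem  # N x N per head
    k_and_v = 2 * batch * heads * seq_len * head_dim * bytes_per_elem  # K and V
    return {
        "standard_total": (score_matrix + k_and_v) / mib,  # scores materialized in GPU memory
        "flash_total": k_and_v / mib,                       # scores only exist tile by tile on-chip
    }


for n in (512, 2048, 8192, 32768):
    m = attention_memory_mib(batch=8, heads=32, seq_len=n, head_dim=64)
    print(n, round(m["standard_total"]), round(m["flash_total"]))
# 512 160 32
# 2048 2176 128
# 8192 33280 512
# 32768 526336 2048
```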

### Memory scaling comparison

**Llama 2 7B model memory** (float16, batch=1):

| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|--------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |

### Training memory (Llama 2 7B, batch=4)

| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |

## Scaling with sequence length

### Computational complexity

**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)

**Flash Attention**:
- Time: O(N² × d) (same FLOPs, but far fewer reads and writes to GPU high-bandwidth memory)
- Memory: O(N × d) (linear in N; the N × N matrix is never stored)
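
The original FlashAttention paper adds an IO analysis to these bounds. It counts reads and writes to GPU high-bandwidth memory (HBM) for sequence length $N$, head dimension $d$, and on-chip SRAM size $M$, with $d \le M \le Nd$:

```latex
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention, HBM accesses}}
\qquad\text{vs.}\qquad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{FlashAttention, HBM accesses}}
```

For typical head dimensions ($d$ = 64 or 128) and SRAM sizes of around 100 KB, $d^2$ is many times smaller than $M$. FlashAttention therefore performs far fewer HBM accesses, even though its FLOP count is the same.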

### Empirical scaling (A100, batch=1, heads=32, dim=64)

**Time per token (milliseconds) by sequence length**:

| Method | 512 | 1K | 2K | 4K | 8K | 16K |
|--------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |

**Observation**: The speedup grows roughly linearly with sequence length, about doubling with each doubling of length from 4K onward.

### Memory per token (MB) by sequence length

| Method | 512 | 1K | 2K | 4K | 8K | 16K |
|--------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |

**Observation**: Flash Attention's memory per token is constant, so its total attention memory grows only linearly with sequence length.

## Training vs inference performance

### Training (forward + backward, Llama 2 7B, A100)

| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |

### Inference (generation, Llama 2 7B, A100)

| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|-----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |

**Note**: Inference speedups are smaller than training speedups because token-by-token generation is memory-bandwidth-bound. Each decoding step is dominated by reading the KV cache, not by attention FLOPs.
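
A quick calculation shows why decoding is bandwidth-bound. Each generated token must read the entire KV cache, which holds one key and one value vector per layer for every previous token. The sketch below assumes Llama 2 7B's published shape (32 layers, hidden size 4096, no grouped-query attention) in float16. The helper name is illustrative.

```python
def kv_cache_bytes(num_layers, hidden_size, context_len, batch=1, bytes_per_elem=2):
    """KV cache size: one K and one V vector of size hidden_size per layer per token."""
    return 2 * num_layers * hidden_size * context_len * batch * bytes_per_elem


# Llama 2 7B: 32 layers, hidden size 4096, full multi-head attention, float16
per_token = kv_cache_bytes(32, 4096, context_len=1)
at_8k = kv_cache_bytes(32, 4096, context_len=8192)
print(per_token / 2**20, "MiB per token")  # 0.5 MiB per token
print(at_8k / 2**30, "GiB at 8K context")  # 4.0 GiB at 8K context
```

At 8K context, a single sequence therefore reads about 4 GiB of cache per generated token, on top of the model weights. Memory bandwidth, rather than attention FLOPs, sets the pace.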

## Flash Attention versions comparison

### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)

| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |

**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)

### Features by version

| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |

## Real-world model benchmarks

### Llama 2 models (A100 80GB, batch=4, seq=2048)

| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |

### GPT-style models (seq=1024)

| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|-----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |

## Recommendations by use case

**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use FlashAttention-3 on H100 for maximum speed (its FP8 kernels are forward-only, so train in FP16/BF16)
- Expected: 2.5-3x speedup

**Long context inference (>4K tokens)**:
- Flash Attention is essential (it enables context lengths standard attention cannot handle)
- Expected: 2-4x speedup, 5-10x memory reduction

**Short sequences (<512 tokens)**:
- Flash Attention provides a 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (results are exact, so there is no accuracy cost)

**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows larger concurrent batch sizes
- Can serve 2-3x more users on the same hardware