EvoScientist 0.1.0rc1-py3-none-any.whl → 0.1.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. EvoScientist/EvoScientist.py +1 -1
  2. EvoScientist/cli.py +450 -178
  3. EvoScientist/middleware.py +5 -1
  4. EvoScientist/skills/accelerate/SKILL.md +332 -0
  5. EvoScientist/skills/accelerate/references/custom-plugins.md +453 -0
  6. EvoScientist/skills/accelerate/references/megatron-integration.md +489 -0
  7. EvoScientist/skills/accelerate/references/performance.md +525 -0
  8. EvoScientist/skills/bitsandbytes/SKILL.md +411 -0
  9. EvoScientist/skills/bitsandbytes/references/memory-optimization.md +521 -0
  10. EvoScientist/skills/bitsandbytes/references/qlora-training.md +521 -0
  11. EvoScientist/skills/bitsandbytes/references/quantization-formats.md +447 -0
  12. EvoScientist/skills/clip/SKILL.md +253 -0
  13. EvoScientist/skills/clip/references/applications.md +207 -0
  14. EvoScientist/skills/find-skills/SKILL.md +133 -0
  15. EvoScientist/skills/find-skills/scripts/install_skill.py +211 -0
  16. EvoScientist/skills/flash-attention/SKILL.md +367 -0
  17. EvoScientist/skills/flash-attention/references/benchmarks.md +215 -0
  18. EvoScientist/skills/flash-attention/references/transformers-integration.md +293 -0
  19. EvoScientist/skills/langgraph-docs/SKILL.md +36 -0
  20. EvoScientist/skills/llama-cpp/SKILL.md +258 -0
  21. EvoScientist/skills/llama-cpp/references/optimization.md +89 -0
  22. EvoScientist/skills/llama-cpp/references/quantization.md +213 -0
  23. EvoScientist/skills/llama-cpp/references/server.md +125 -0
  24. EvoScientist/skills/lm-evaluation-harness/SKILL.md +490 -0
  25. EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
  26. EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
  27. EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
  28. EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
  29. EvoScientist/skills/ml-paper-writing/SKILL.md +937 -0
  30. EvoScientist/skills/ml-paper-writing/references/checklists.md +361 -0
  31. EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +562 -0
  32. EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +367 -0
  33. EvoScientist/skills/ml-paper-writing/references/sources.md +159 -0
  34. EvoScientist/skills/ml-paper-writing/references/writing-guide.md +476 -0
  35. EvoScientist/skills/ml-paper-writing/templates/README.md +251 -0
  36. EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +534 -0
  37. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +144 -0
  38. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +952 -0
  39. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +111 -0
  40. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +1493 -0
  41. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +315 -0
  42. EvoScientist/skills/ml-paper-writing/templates/acl/README.md +50 -0
  43. EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +312 -0
  44. EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +377 -0
  45. EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +101 -0
  46. EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +1940 -0
  47. EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +26 -0
  48. EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +70 -0
  49. EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +326 -0
  50. EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +3 -0
  51. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +11 -0
  52. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +1440 -0
  53. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
  54. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +218 -0
  55. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +305 -0
  56. EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +485 -0
  57. EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +508 -0
  58. EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +1246 -0
  59. EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +485 -0
  60. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +24 -0
  61. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +1440 -0
  62. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
  63. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +246 -0
  64. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +414 -0
  65. EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +508 -0
  66. EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +1246 -0
  67. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +79 -0
  68. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +201 -0
  69. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +75 -0
  70. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
  71. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +662 -0
  72. EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +864 -0
  73. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +1443 -0
  74. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +767 -0
  75. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
  76. EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +36 -0
  77. EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +53 -0
  78. EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +38 -0
  79. EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +382 -0
  80. EvoScientist/skills/peft/SKILL.md +431 -0
  81. EvoScientist/skills/peft/references/advanced-usage.md +514 -0
  82. EvoScientist/skills/peft/references/troubleshooting.md +480 -0
  83. EvoScientist/skills/ray-data/SKILL.md +326 -0
  84. EvoScientist/skills/ray-data/references/integration.md +82 -0
  85. EvoScientist/skills/ray-data/references/transformations.md +83 -0
  86. EvoScientist/skills/skill-creator/LICENSE.txt +202 -0
  87. EvoScientist/skills/skill-creator/SKILL.md +356 -0
  88. EvoScientist/skills/skill-creator/references/output-patterns.md +82 -0
  89. EvoScientist/skills/skill-creator/references/workflows.md +28 -0
  90. EvoScientist/skills/skill-creator/scripts/init_skill.py +303 -0
  91. EvoScientist/skills/skill-creator/scripts/package_skill.py +110 -0
  92. EvoScientist/skills/skill-creator/scripts/quick_validate.py +95 -0
  93. EvoScientist/skills/tensorboard/SKILL.md +629 -0
  94. EvoScientist/skills/tensorboard/references/integrations.md +638 -0
  95. EvoScientist/skills/tensorboard/references/profiling.md +545 -0
  96. EvoScientist/skills/tensorboard/references/visualization.md +620 -0
  97. EvoScientist/skills/vllm/SKILL.md +364 -0
  98. EvoScientist/skills/vllm/references/optimization.md +226 -0
  99. EvoScientist/skills/vllm/references/quantization.md +284 -0
  100. EvoScientist/skills/vllm/references/server-deployment.md +255 -0
  101. EvoScientist/skills/vllm/references/troubleshooting.md +447 -0
  102. {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/METADATA +26 -3
  103. evoscientist-0.1.0rc2.dist-info/RECORD +119 -0
  104. evoscientist-0.1.0rc1.dist-info/RECORD +0 -21
  105. {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/WHEEL +0 -0
  106. {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/entry_points.txt +0 -0
  107. {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/licenses/LICENSE +0 -0
  108. {evoscientist-0.1.0rc1.dist-info → evoscientist-0.1.0rc2.dist-info}/top_level.txt +0 -0
EvoScientist/skills/flash-attention/SKILL.md
@@ -0,0 +1,367 @@
---
name: flash-attention
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training or running transformers with long sequences (>512 tokens), when attention causes GPU memory issues, or when faster inference is needed. Supports PyTorch native SDPA, the flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
dependencies: [flash-attn, torch, transformers]
---

# Flash Attention - Fast Memory-Efficient Attention

## Quick start

Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.

**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```

**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

## Common workflows

### Workflow 1: Enable in existing PyTorch model

Copy this checklist:

```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```

**Step 1: Check PyTorch version**

```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```

If <2.2, upgrade:
```bash
pip install --upgrade torch
```

**Step 2: Enable Flash Attention backend**

Replace standard attention:
```python
# Before (standard attention)
import math
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)  # d_k = head dimension
out = attn_weights @ v

# After (Flash Attention)
import torch.nn.functional as F
# Note: an explicit attn_mask routes SDPA to the math/memory-efficient backends;
# pass is_causal=True instead when you only need causal masking.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```
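
On newer PyTorch releases (2.3+), `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel`; an equivalent sketch using the newer API (same `q`, `k`, `v` as above):
```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# PyTorch 2.3+ replacement for torch.backends.cuda.sdp_kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```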

**Step 3: Verify speedup with profiling**

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```

Expected: 2-4x speedup for sequences >512 tokens.

**Step 4: Test accuracy matches baseline**

```python
import torch
import torch.nn.functional as F

# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention (1 / 8.0 = 1 / sqrt(head_dim=64))
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
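
For an automated check, `torch.testing.assert_close` raises if the outputs drift beyond the given tolerances (the values here are a reasonable starting point for float16, not an official threshold):
```python
import torch

# Raises AssertionError if the outputs differ beyond the tolerances
torch.testing.assert_close(out_flash, out_standard, rtol=1e-3, atol=1e-3)
```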

### Workflow 2: Use flash-attn library for advanced features

For multi-query attention, sliding window, or H100 FP8.

Copy this checklist:

```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```

**Step 1: Install flash-attn library**

```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```

**Step 2: Modify attention code**

```python
from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,           # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None     # Auto-scale
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```

**Step 3: Enable advanced features**

Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
out = flash_attn_func(q, k, v)  # Automatically handles MQA
```

Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True
)
```

**Step 4: Benchmark performance**

```python
import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```

### Workflow 3: H100 FP8 optimization (FlashAttention-3)

For maximum performance on H100 GPUs.

```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```

**Step 1: Verify H100 GPU**

```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```

**Step 2: Install flash-attn with FP8 support**

```bash
pip install flash-attn --no-build-isolation
# FP8 attention kernels come from FlashAttention-3 and target Hopper (H100/H800);
# depending on the release, FA3 may need to be built from the flash-attention repo.
```

**Step 3: Convert inputs to FP8**

```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```

**Step 4: Run with FP8 attention**

```python
from flash_attn import flash_attn_func

# With a FlashAttention-3 build, FP8 kernels are used on H100
# (the FA3 Python entry point can differ from the flash_attn 2.x API shown here)
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Reported result: ~1.2 PFLOPS, roughly 1.5-2x faster than FP16
```

## When to use vs alternatives

**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn

**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)

## Common issues

**Issue: ImportError: cannot import flash_attn**

Install with the --no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```

Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

**Issue: Slower than expected (no speedup)**

Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup

Check that your sequence length is long enough to benefit.
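
To find the crossover point on your own hardware, a small sketch (illustrative shapes) that times SDPA against naive attention at several sequence lengths:
```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def time_attention(seq_len, flash=True):
    q, k, v = [torch.randn(2, 8, seq_len, 64, device='cuda', dtype=torch.float16)
               for _ in range(3)]
    if flash:
        stmt = 'F.scaled_dot_product_attention(q, k, v)'
    else:
        stmt = '((q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)) @ v'
    timer = benchmark.Timer(stmt=stmt, globals={'F': F, 'q': q, 'k': k, 'v': v})
    return timer.timeit(50).mean  # seconds per call

for seq_len in (256, 512, 1024, 2048, 4096):
    flash_t, std_t = time_attention(seq_len, True), time_attention(seq_len, False)
    print(f"seq={seq_len}: flash {flash_t*1e3:.2f}ms, standard {std_t*1e3:.2f}ms, "
          f"speedup {std_t / flash_t:.1f}x")
```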

**Issue: RuntimeError: CUDA error**

Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(8, 0) (Ampere or newer) for FlashAttention-2;
# (7, 5) Turing is only enough for the older FlashAttention-1
```

Flash Attention support by architecture:
- Ampere and newer (A100, A10, H100): ✅ Full support
- Turing (T4): ⚠️ FlashAttention-1 only; FlashAttention-2 needs Ampere or newer
- Volta (V100): ❌ Not supported
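
If the same code must also run on pre-Ampere GPUs, one option is to pick the allowed SDPA backends from the compute capability. A sketch assuming the PyTorch 2.3+ `sdpa_kernel` API and the `q`, `k`, `v` from the earlier examples:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

major, minor = torch.cuda.get_device_capability()
if (major, minor) >= (8, 0):
    # Ampere or newer: allow the Flash kernel, with mem-efficient as fallback
    backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
else:
    # Older GPUs: skip Flash and rely on memory-efficient / math kernels
    backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v)
```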

**Issue: Accuracy degradation**

Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

Flash Attention kernels run in float16/bfloat16 for speed; float32 is not supported (PyTorch SDPA silently falls back to a non-Flash backend for float32 inputs).
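
Rather than casting tensors by hand, model code is usually run under autocast so matmuls and attention execute in half precision. A minimal sketch (`model` and `input_ids` stand in for your own module and batch):
```python
import torch

# Run the forward pass in bfloat16 so SDPA can dispatch to the Flash kernel
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(input_ids)
```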

## Advanced topics

**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
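
A quick taste of that integration (a sketch; the checkpoint name is only an example, and `attn_implementation="flash_attention_2"` requires the flash-attn package plus an Ampere or newer GPU):
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",              # example checkpoint
    torch_dtype=torch.bfloat16,               # Flash kernels need fp16/bf16
    attn_implementation="flash_attention_2",  # falls back with an error if unsupported
)
```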

**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.

**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.

**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.

## Hardware requirements

- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support

**Not supported**: V100 (Volta), CPU inference
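
Before debugging further, a quick environment check (prints only):
```python
import torch

print("torch:", torch.__version__)
print("CUDA runtime:", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))
```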

## Resources

- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

EvoScientist/skills/flash-attention/references/benchmarks.md
@@ -0,0 +1,215 @@
# Performance Benchmarks

## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
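
The numbers below are indicative and vary with GPU, driver, and PyTorch/flash-attn versions. A minimal sketch (shapes match the batch=8, heads=32, dim=64 config used in the tables) for reproducing this kind of forward-pass measurement on your own hardware:
```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def profile_sdpa(seq_len, batch=8, heads=32, dim=64):
    """Return (forward time in ms, peak memory in GB) for one SDPA forward pass."""
    q, k, v = [torch.randn(batch, heads, seq_len, dim, device='cuda', dtype=torch.float16)
               for _ in range(3)]
    torch.cuda.reset_peak_memory_stats()
    t = benchmark.Timer(stmt='F.scaled_dot_product_attention(q, k, v)',
                        globals={'F': F, 'q': q, 'k': k, 'v': v}).timeit(100)
    return t.mean * 1e3, torch.cuda.max_memory_allocated() / 1e9

for seq_len in (512, 1024, 2048, 4096, 8192):
    ms, gb = profile_sdpa(seq_len)
    print(f"seq={seq_len}: {ms:.2f} ms, peak {gb:.2f} GB")
```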

## Speed comparisons across GPUs

### A100 80GB (Ampere)

**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |

### H100 80GB (Hopper)

**Forward pass time** (milliseconds, same config):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |

**Key insight**: Flash Attention 3 with FP8 reaches roughly 1.2 PFLOPS on H100 (its FP16 kernels run at about 75% of the GPU's theoretical peak).

### A10G 24GB (Ampere)

**Forward pass time** (milliseconds, batch=4):

| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |

## Memory usage analysis

### GPU memory consumption (batch=8, heads=32, dim=64)

**Standard attention memory**:

| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |

**Flash Attention 2 memory**:

| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |

**Key insight**: Flash Attention never materializes the full attention matrix, saving O(N²) memory.

### Memory scaling comparison

**Llama 2 7B model memory** (float16, batch=1):

| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|--------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |

### Training memory (Llama 2 7B, batch=4)

| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |

## Scaling with sequence length

### Computational complexity

**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)

**Flash Attention**:
- Time: O(N² × d) (same FLOPs; the gain comes from avoiding HBM reads/writes of the N×N matrix)
- Memory: O(N × d) (linear!)
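
The practical effect of dropping the O(N²) term is easy to see with back-of-the-envelope arithmetic (per head, float16, illustrative sizes):
```python
# Per-head, float16 (2 bytes per element)
head_dim = 64
for n in (2_048, 8_192, 32_768):
    attn_matrix_mb = n * n * 2 / 2**20      # N x N score matrix, materialized
    qkv_mb = 3 * n * head_dim * 2 / 2**20   # Q, K, V activations, O(N x d)
    print(f"N={n}: attention matrix {attn_matrix_mb:,.0f} MB vs Q/K/V {qkv_mb:.2f} MB")
```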

### Empirical scaling (A100, batch=1, heads=32, dim=64)

**Time per token (milliseconds)**:

| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |

**Observation**: The speedup grows roughly linearly with sequence length (about 2x for every doubling), since standard attention's per-token cost rises far faster than Flash Attention's.

### Memory per token (MB)

| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |

**Observation**: Flash Attention memory per token is constant!

## Training vs inference performance

### Training (forward + backward, Llama 2 7B, A100)

| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |

### Inference (generation, Llama 2 7B, A100)

| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|-----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |

**Note**: The inference speedup is less dramatic than the training speedup because autoregressive generation is memory-bound (dominated by KV-cache reads) rather than compute-bound.

## Flash Attention versions comparison

### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)

| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |

**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)

### Features by version

| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |
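
To confirm which generation is actually installed, check the package version and GPU (a quick sketch; FlashAttention-2 ships as flash-attn 2.x, and FP8/FA3 kernels target Hopper):
```python
from importlib.metadata import version

import torch

print("flash-attn:", version("flash-attn"))
print("compute capability:", torch.cuda.get_device_capability(0))  # (9, 0) == Hopper
```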

## Real-world model benchmarks

### Llama 2 models (A100 80GB, batch=4, seq=2048)

| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |

### GPT-style models (seq=1024)

| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|-----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |

## Recommendations by use case

**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup

**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction

**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)

**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on the same hardware