EvoScientist 0.0.1.dev4__py3-none-any.whl → 0.1.0rc1__py3-none-any.whl
- EvoScientist/EvoScientist.py +26 -62
- EvoScientist/__init__.py +0 -19
- EvoScientist/backends.py +0 -26
- EvoScientist/cli.py +1111 -498
- EvoScientist/middleware.py +8 -61
- EvoScientist/stream/__init__.py +0 -25
- EvoScientist/stream/utils.py +16 -23
- EvoScientist/tools.py +2 -75
- evoscientist-0.1.0rc1.dist-info/METADATA +199 -0
- evoscientist-0.1.0rc1.dist-info/RECORD +21 -0
- evoscientist-0.1.0rc1.dist-info/entry_points.txt +2 -0
- EvoScientist/config.py +0 -274
- EvoScientist/llm/__init__.py +0 -21
- EvoScientist/llm/models.py +0 -99
- EvoScientist/memory.py +0 -715
- EvoScientist/onboard.py +0 -725
- EvoScientist/paths.py +0 -44
- EvoScientist/skills/accelerate/SKILL.md +0 -332
- EvoScientist/skills/accelerate/references/custom-plugins.md +0 -453
- EvoScientist/skills/accelerate/references/megatron-integration.md +0 -489
- EvoScientist/skills/accelerate/references/performance.md +0 -525
- EvoScientist/skills/bitsandbytes/SKILL.md +0 -411
- EvoScientist/skills/bitsandbytes/references/memory-optimization.md +0 -521
- EvoScientist/skills/bitsandbytes/references/qlora-training.md +0 -521
- EvoScientist/skills/bitsandbytes/references/quantization-formats.md +0 -447
- EvoScientist/skills/find-skills/SKILL.md +0 -133
- EvoScientist/skills/find-skills/scripts/install_skill.py +0 -211
- EvoScientist/skills/flash-attention/SKILL.md +0 -367
- EvoScientist/skills/flash-attention/references/benchmarks.md +0 -215
- EvoScientist/skills/flash-attention/references/transformers-integration.md +0 -293
- EvoScientist/skills/llama-cpp/SKILL.md +0 -258
- EvoScientist/skills/llama-cpp/references/optimization.md +0 -89
- EvoScientist/skills/llama-cpp/references/quantization.md +0 -213
- EvoScientist/skills/llama-cpp/references/server.md +0 -125
- EvoScientist/skills/lm-evaluation-harness/SKILL.md +0 -490
- EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +0 -490
- EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +0 -488
- EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +0 -602
- EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +0 -519
- EvoScientist/skills/ml-paper-writing/SKILL.md +0 -937
- EvoScientist/skills/ml-paper-writing/references/checklists.md +0 -361
- EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +0 -562
- EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +0 -367
- EvoScientist/skills/ml-paper-writing/references/sources.md +0 -159
- EvoScientist/skills/ml-paper-writing/references/writing-guide.md +0 -476
- EvoScientist/skills/ml-paper-writing/templates/README.md +0 -251
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +0 -534
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +0 -144
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +0 -952
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +0 -111
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +0 -1493
- EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +0 -315
- EvoScientist/skills/ml-paper-writing/templates/acl/README.md +0 -50
- EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +0 -312
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +0 -377
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +0 -101
- EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +0 -1940
- EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +0 -26
- EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +0 -70
- EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +0 -326
- EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +0 -3
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +0 -11
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +0 -1440
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +0 -218
- EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +0 -305
- EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +0 -485
- EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +0 -508
- EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +0 -1246
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +0 -485
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +0 -24
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +0 -1440
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +0 -246
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +0 -414
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +0 -508
- EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +0 -1246
- EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +0 -79
- EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +0 -201
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +0 -75
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +0 -662
- EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +0 -864
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +0 -1443
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +0 -767
- EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +0 -36
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +0 -53
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +0 -38
- EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +0 -382
- EvoScientist/skills/peft/SKILL.md +0 -431
- EvoScientist/skills/peft/references/advanced-usage.md +0 -514
- EvoScientist/skills/peft/references/troubleshooting.md +0 -480
- EvoScientist/skills/ray-data/SKILL.md +0 -326
- EvoScientist/skills/ray-data/references/integration.md +0 -82
- EvoScientist/skills/ray-data/references/transformations.md +0 -83
- EvoScientist/skills/skill-creator/LICENSE.txt +0 -202
- EvoScientist/skills/skill-creator/SKILL.md +0 -356
- EvoScientist/skills/skill-creator/references/output-patterns.md +0 -82
- EvoScientist/skills/skill-creator/references/workflows.md +0 -28
- EvoScientist/skills/skill-creator/scripts/init_skill.py +0 -303
- EvoScientist/skills/skill-creator/scripts/package_skill.py +0 -110
- EvoScientist/skills/skill-creator/scripts/quick_validate.py +0 -95
- EvoScientist/skills_manager.py +0 -391
- EvoScientist/stream/display.py +0 -604
- EvoScientist/stream/events.py +0 -415
- EvoScientist/stream/state.py +0 -343
- evoscientist-0.0.1.dev4.dist-info/METADATA +0 -367
- evoscientist-0.0.1.dev4.dist-info/RECORD +0 -117
- evoscientist-0.0.1.dev4.dist-info/entry_points.txt +0 -5
- {evoscientist-0.0.1.dev4.dist-info → evoscientist-0.1.0rc1.dist-info}/WHEEL +0 -0
- {evoscientist-0.0.1.dev4.dist-info → evoscientist-0.1.0rc1.dist-info}/licenses/LICENSE +0 -0
- {evoscientist-0.0.1.dev4.dist-info → evoscientist-0.1.0rc1.dist-info}/top_level.txt +0 -0
--- EvoScientist/skills/flash-attention/references/benchmarks.md
+++ /dev/null
@@ -1,215 +0,0 @@
# Performance Benchmarks

## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison

## Speed comparisons across GPUs

### A100 80GB (Ampere)

**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512        | 1.2      | 0.9          | N/A          | 1.3x          |
| 1024       | 3.8      | 1.4          | N/A          | 2.7x          |
| 2048       | 14.2     | 4.8          | N/A          | 3.0x          |
| 4096       | 55.1     | 17.3         | N/A          | 3.2x          |
| 8192       | 218.5    | 66.2         | N/A          | 3.3x          |

### H100 80GB (Hopper)

**Forward pass time** (milliseconds, same config):

| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512        | 0.8      | 0.6          | 0.4                 | 0.3                | 2.7x         |
| 1024       | 2.6      | 1.0          | 0.6                 | 0.4                | 6.5x         |
| 2048       | 9.8      | 3.4          | 2.0                 | 1.3                | 7.5x         |
| 4096       | 38.2     | 12.5         | 7.2                 | 4.8                | 8.0x         |
| 8192       | 151.4    | 47.8         | 27.1                | 18.2               | 8.3x         |

**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).

### A10G 24GB (Ampere)

**Forward pass time** (milliseconds, batch=4):

| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512        | 2.1      | 1.6          | 1.3x    |
| 1024       | 6.8      | 2.8          | 2.4x    |
| 2048       | 25.9     | 9.4          | 2.8x    |
| 4096       | 102.1    | 35.2         | 2.9x    |
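
Comparisons of this shape can be approximated with a CUDA-event timing loop. The sketch below is illustrative, not the script behind the tables: it assumes a CUDA build of `flash-attn` is installed and uses `torch.backends.cuda.sdp_kernel` to force PyTorch's unfused math path as the "standard attention" baseline.

```python
# Sketch of a forward-pass timing harness in the spirit of the tables above
# (batch=8, heads=32, head_dim=64, fp16). Results depend on GPU, driver, and
# library versions; the math path needs the full N x N score matrix in memory.
import torch
from flash_attn import flash_attn_func  # requires a CUDA build of flash-attn

def time_ms(fn, iters=20):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

batch, heads, dim = 8, 32, 64
for seq in (512, 1024, 2048, 4096):
    # flash_attn_func takes (batch, seq, heads, dim); SDPA takes (batch, heads, seq, dim)
    q, k, v = (torch.randn(batch, seq, heads, dim, dtype=torch.float16, device="cuda")
               for _ in range(3))
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    # Force the unfused math kernel as the "standard attention" baseline
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False,
                                        enable_math=True):
        std = time_ms(lambda: torch.nn.functional.scaled_dot_product_attention(
            qt, kt, vt, is_causal=True))
    fa2 = time_ms(lambda: flash_attn_func(q, k, v, causal=True))
    print(f"seq={seq}: standard {std:.2f} ms, flash-attn 2 {fa2:.2f} ms")
```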
## Memory usage analysis
### GPU memory consumption (batch=8, heads=32, dim=64)

**Standard attention memory**:

| Seq Length | Attention Matrix | KV Cache | Total  | Notes            |
|------------|------------------|----------|--------|------------------|
| 512        | 8 MB             | 32 MB    | 40 MB  | Manageable       |
| 2048       | 128 MB           | 128 MB   | 256 MB | Growing          |
| 8192       | 2048 MB (2 GB)   | 512 MB   | 2.5 GB | Large            |
| 32768      | 32768 MB (32 GB) | 2048 MB  | 34 GB  | OOM on 24GB GPUs |

**Flash Attention 2 memory**:

| Seq Length | Attention (on-chip) | KV Cache | Total  | Reduction |
|------------|---------------------|----------|--------|-----------|
| 512        | 0 MB (recomputed)   | 32 MB    | 32 MB  | 20%       |
| 2048       | 0 MB                | 128 MB   | 128 MB | 50%       |
| 8192       | 0 MB                | 512 MB   | 512 MB | 80%       |
| 32768      | 0 MB                | 2048 MB  | 2 GB   | 94%       |

**Key insight**: Flash Attention never materializes the attention matrix, saving O(N²) memory.
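
The KV Cache column in both tables is plain arithmetic: two tensors (K and V) of shape batch × heads × seq_len × head_dim in fp16. A quick check for the stated config:

```python
# KV-cache size for the config above (batch=8, heads=32, head_dim=64, fp16,
# a single attention layer): 2 tensors (K and V) of shape (batch, heads, seq, dim).
def kv_cache_mib(seq_len, batch=8, heads=32, head_dim=64, dtype_bytes=2):
    return 2 * batch * heads * seq_len * head_dim * dtype_bytes / 2**20

for seq in (512, 2048, 8192, 32768):
    print(f"{seq:>6}: {kv_cache_mib(seq):>6.0f} MiB")
# 512: 32 MiB, 2048: 128 MiB, 8192: 512 MiB, 32768: 2048 MiB -- the KV Cache column above
```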
### Memory scaling comparison
**Llama 2 7B model memory** (float16, batch=1):

| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|--------------------|-------------------|-------------------|
| 2K             | 3.2 GB             | 2.1 GB            | Both: Yes         |
| 4K             | 5.8 GB             | 2.8 GB            | Both: Yes         |
| 8K             | 12.1 GB            | 4.2 GB            | Both: Yes         |
| 16K            | 26.3 GB (OOM)      | 7.8 GB            | Only Flash: Yes   |
| 32K            | OOM                | 14.2 GB           | Only Flash: Yes   |

### Training memory (Llama 2 7B, batch=4)

| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K      | 18.2          | 12.4            | 32%       |
| 4K      | 34.8          | 16.8            | 52%       |
| 8K      | OOM (>40GB)   | 26.2            | Fits!     |

## Scaling with sequence length

### Computational complexity

**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)

**Flash Attention**:
- Time: O(N² × d) (same asymptotics, but with better constants)
- Memory: O(N × d) (linear; see the sketch below)
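
The linear-memory bound is easiest to see in code. Below is a minimal single-head PyTorch sketch of the tiling and online-softmax idea behind Flash Attention (an illustration of the algorithm, not the fused CUDA kernel): keys and values are processed block by block with a running max and normalizer, so only an N × block score tile is ever materialized.

```python
# Minimal single-head sketch of the tiling + online-softmax trick behind Flash
# Attention. Illustrative only: no causal mask, no kernel fusion; block_size is
# an arbitrary choice. Memory stays O(N x d) because only an (N, block) score
# tile ever exists, never the full (N, N) matrix.
import torch

def tiled_attention(q, k, v, block_size=256):
    # q, k, v: (seq_len, head_dim)
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq_len, 1), float("-inf"))
    row_sum = torch.zeros(seq_len, 1)
    for start in range(0, seq_len, block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        scores = (q @ kb.T) * scale                   # (N, block), never (N, N)
        block_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, block_max)
        rescale = torch.exp(row_max - new_max)        # correct the earlier blocks
        p = torch.exp(scores - new_max)
        out = out * rescale + p @ vb
        row_sum = row_sum * rescale + p.sum(dim=-1, keepdim=True)
        row_max = new_max
    return out / row_sum

q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print((tiled_attention(q, k, v) - ref).abs().max())   # ~1e-6: same result, less memory
```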
### Empirical scaling (A100, batch=1, heads=32, dim=64)
**Time per token (milliseconds)**:

| Sequence     | 512  | 1K   | 2K   | 4K   | 8K    | 16K   |
|--------------|------|------|------|------|-------|-------|
| Standard     | 0.15 | 0.37 | 1.11 | 3.44 | 13.4  | 52.8  |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83  | 1.64  |
| Speedup      | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |

**Observation**: the speedup grows roughly linearly with sequence length, doubling each time the sequence length doubles.

### Memory per token (MB)

| Sequence     | 512  | 1K   | 2K   | 4K   | 8K   | 16K  |
|--------------|------|------|------|------|------|------|
| Standard     | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |

**Observation**: Flash Attention's memory per token stays constant as the sequence grows.
## Training vs inference performance
### Training (forward + backward, Llama 2 7B, A100)

| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K      | 1.2                    | 3.1                      | 2.6x    |
| 8 × 2K      | 2.1                    | 5.8                      | 2.8x    |
| 4 × 4K      | 0.4                    | 1.3                      | 3.3x    |
| 8 × 4K      | OOM                    | 2.4                      | Enabled |
| 2 × 8K      | 0.1                    | 0.4                      | 4.0x    |

### Inference (generation, Llama 2 7B, A100)

| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|-----------------------|-------------------------|---------|
| 512            | 48                    | 52                      | 1.1x    |
| 2K             | 42                    | 62                      | 1.5x    |
| 4K             | 31                    | 58                      | 1.9x    |
| 8K             | 18                    | 51                      | 2.8x    |
| 16K            | OOM                   | 42                      | Enabled |

**Note**: The inference speedup is less dramatic than the training speedup because generation is memory-bound (dominated by KV cache accesses).
## Flash Attention versions comparison
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)

| Metric            | FA1  | FA2  | FA3 (FP16) | FA3 (FP8) |
|-------------------|------|------|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2        | 4.8       |
| Memory (GB)       | 4.8  | 4.2  | 4.2        | 2.8       |
| TFLOPS            | 180  | 420  | 740        | 1150      |
| GPU util %        | 35%  | 55%  | 75%        | 82%       |

**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)

### Features by version

| Feature               | FA1   | FA2      | FA3            |
|-----------------------|-------|----------|----------------|
| Basic attention       | ✅    | ✅       | ✅             |
| Causal masking        | ✅    | ✅       | ✅             |
| Multi-query attention | ❌    | ✅       | ✅             |
| Sliding window        | ❌    | ✅       | ✅             |
| Paged KV cache        | ❌    | ✅       | ✅             |
| FP8 support           | ❌    | ❌       | ✅ (H100 only) |
| Work partitioning     | Basic | Advanced | Optimal        |
## Real-world model benchmarks
### Llama 2 models (A100 80GB, batch=4, seq=2048)

| Model       | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|--------|------------------------|--------------------------|---------|
| Llama 2 7B  | 7B     | 1.2                    | 3.1                      | 2.6x    |
| Llama 2 13B | 13B    | 0.6                    | 1.7                      | 2.8x    |
| Llama 2 70B | 70B    | 0.12                   | 0.34                     | 2.8x    |

### GPT-style models (seq=1024)

| Model          | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|-----------------------|-------------------------|---------|
| GPT-2 (124M)   | 520                   | 680                     | 1.3x    |
| GPT-J (6B)     | 42                    | 98                      | 2.3x    |
| GPT-NeoX (20B) | 8                     | 22                      | 2.75x   |
## Recommendations by use case
**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup

**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction

**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)

**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware
--- EvoScientist/skills/flash-attention/references/transformers-integration.md
+++ /dev/null
@@ -1,293 +0,0 @@
# HuggingFace Transformers Integration

## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues

## Enabling Flash Attention in Transformers

HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.

**Simple enable for any supported model**:
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)
```

**Install requirements**:
```bash
pip install "transformers>=4.36"
pip install flash-attn --no-build-isolation
```
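
A quick sanity check that the wheel imports cleanly before loading a model (a sketch; it only prints the package's standard `__version__` and whether CUDA is visible):

```python
# Verify the install before using attn_implementation="flash_attention_2"
import torch
import flash_attn

print("CUDA available:", torch.cuda.is_available())
print("flash-attn version:", flash_attn.__version__)
```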
## Supported model architectures
As of Transformers 4.40:

**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM

**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper

**Check support**:
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```

## Configuration examples
### Llama 2 with Flash Attention
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "meta-llama/Llama-2-7b-hf"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention for long context
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,     # Better numerical stability for long context
    device_map="auto",
    max_position_embeddings=32768,  # Extended context window
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Process a long document (up to 32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-tuning with Flash Attention
```python
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    fp16=True,                    # Must match model dtype
    optim="adamw_torch_fused",    # Fast fused optimizer
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # your tokenized dataset
)

trainer.train()
```
### Multi-GPU training
```python
from transformers import AutoModelForCausalLM
import torch

# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",                  # Automatic multi-GPU placement
    max_memory={0: "20GB", 1: "20GB"},  # Limit per GPU
)
```
## Performance comparisons
### Memory usage (Llama 2 7B, batch=1)

| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|--------------------|-------------------|-----------|
| 512             | 1.2 GB             | 0.9 GB            | 25%       |
| 2048            | 3.8 GB             | 1.4 GB            | 63%       |
| 8192            | 14.2 GB            | 3.2 GB            | 77%       |
| 32768           | OOM (>24GB)        | 10.8 GB           | Fits!     |
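
Figures like these can be approximated by watching the peak-memory counter around a single forward pass. The sketch below measures how far the peak rises above the post-load baseline; what exactly gets counted varies between setups, so expect the absolute numbers to differ from the table.

```python
# Sketch: load with a given attention implementation, run one forward pass at a
# given length, and report how far peak memory rose above the post-load baseline.
import torch
from transformers import AutoModelForCausalLM

def forward_overhead_gib(attn_impl, seq_len, model_id="meta-llama/Llama-2-7b-hf"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation=attn_impl,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device=model.device)
    with torch.no_grad():
        model(input_ids)
    overhead = (torch.cuda.max_memory_allocated() - baseline) / 2**30
    del model
    torch.cuda.empty_cache()
    return overhead

for impl in ("eager", "flash_attention_2"):
    print(impl, f"{forward_overhead_gib(impl, 8192):.2f} GiB above the loaded model")
```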
### Speed (tokens/sec, A100 80GB)
| Model                  | Standard | Flash Attn 2 | Speedup |
|------------------------|----------|--------------|---------|
| Llama 2 7B (seq=2048)  | 42       | 118          | 2.8x    |
| Llama 2 13B (seq=4096) | 18       | 52           | 2.9x    |
| Llama 2 70B (seq=2048) | 4        | 11           | 2.75x   |

### Training throughput (samples/sec)

| Model       | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------------|------------|----------|--------------|---------|
| Llama 2 7B  | 4          | 1.2      | 3.1          | 2.6x    |
| Llama 2 7B  | 8          | 2.1      | 5.8          | 2.8x    |
| Llama 2 13B | 2          | 0.6      | 1.7          | 2.8x    |
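
Throughput numbers of this kind reduce to timing one `generate()` call and dividing by the number of new tokens. The helper below is an illustrative sketch (greedy decoding, arbitrary prompt and lengths), not the benchmark script itself.

```python
# Illustrative tokens/sec helper: time a single generate() call and divide by
# the number of newly generated tokens.
import time
import torch

def tokens_per_sec(model, tokenizer, prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    new_tokens = out.shape[-1] - inputs["input_ids"].shape[-1]
    return new_tokens / (time.perf_counter() - start)

# e.g. with the Llama 2 model and tokenizer loaded earlier:
# print(tokens_per_sec(model, tokenizer, "Once upon a time"))
```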
## Troubleshooting model-specific issues
### Issue: Model doesn't support Flash Attention

Check the support list above. If the model is not supported, use PyTorch SDPA as a fallback:

```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="sdpa",  # PyTorch native (still faster than eager)
    torch_dtype=torch.float16,
)
```

### Issue: CUDA out of memory during loading

Reduce the memory footprint:

```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "18GB"},  # Reserve memory for the KV cache
    low_cpu_mem_usage=True,
)
```

### Issue: Slower inference than expected

Ensure the dtypes match:

```python
# Model and inputs must both be float16/bfloat16
model = model.to(torch.float16)
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
          for k, v in inputs.items()}
```
### Issue: Different outputs vs standard attention
Flash Attention is mathematically equivalent to standard attention but uses a different computation order, so small numerical differences (<1e-3) are normal:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Compare outputs of the two implementations
model_standard = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="eager",  # standard attention
    torch_dtype=torch.float16,
).to("cuda")
model_flash = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("model-name")

inputs = tokenizer("Test", return_tensors="pt").to("cuda")

with torch.no_grad():
    out_standard = model_standard(**inputs).logits
    out_flash = model_flash(**inputs).logits

diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}")  # Should be ~1e-3 to 1e-4
```
### Issue: ImportError during model loading
Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```

Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="eager",  # Standard PyTorch
    torch_dtype=torch.float16,
)
```
## Best practices
1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()` (see the snippet after the example below)

**Example with all best practices**:
```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # Better for training
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Enable gradient checkpointing to save activation memory
model.gradient_checkpointing_enable()

# Training with optimizations
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    bf16=True,                   # Match the model dtype
    optim="adamw_torch_fused",
    gradient_checkpointing=True,
)
```
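
For item 5 in the list above, the monitoring itself is two lines after a run:

```python
# Report and reset the peak-memory counter after a training or generation run.
import torch

print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB")
torch.cuda.reset_peak_memory_stats()  # start fresh before the next measurement
```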