EvoScientist 0.0.1.dev3__py3-none-any.whl → 0.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. EvoScientist/EvoScientist.py +17 -49
  2. EvoScientist/backends.py +0 -26
  3. EvoScientist/cli.py +1109 -255
  4. EvoScientist/middleware.py +8 -61
  5. EvoScientist/stream/__init__.py +0 -25
  6. EvoScientist/stream/utils.py +16 -23
  7. EvoScientist/tools.py +0 -64
  8. evoscientist-0.1.0rc1.dist-info/METADATA +199 -0
  9. evoscientist-0.1.0rc1.dist-info/RECORD +21 -0
  10. evoscientist-0.1.0rc1.dist-info/entry_points.txt +2 -0
  11. EvoScientist/memory.py +0 -715
  12. EvoScientist/paths.py +0 -45
  13. EvoScientist/skills/accelerate/SKILL.md +0 -332
  14. EvoScientist/skills/accelerate/references/custom-plugins.md +0 -453
  15. EvoScientist/skills/accelerate/references/megatron-integration.md +0 -489
  16. EvoScientist/skills/accelerate/references/performance.md +0 -525
  17. EvoScientist/skills/bitsandbytes/SKILL.md +0 -411
  18. EvoScientist/skills/bitsandbytes/references/memory-optimization.md +0 -521
  19. EvoScientist/skills/bitsandbytes/references/qlora-training.md +0 -521
  20. EvoScientist/skills/bitsandbytes/references/quantization-formats.md +0 -447
  21. EvoScientist/skills/find-skills/SKILL.md +0 -133
  22. EvoScientist/skills/find-skills/scripts/install_skill.py +0 -211
  23. EvoScientist/skills/flash-attention/SKILL.md +0 -367
  24. EvoScientist/skills/flash-attention/references/benchmarks.md +0 -215
  25. EvoScientist/skills/flash-attention/references/transformers-integration.md +0 -293
  26. EvoScientist/skills/llama-cpp/SKILL.md +0 -258
  27. EvoScientist/skills/llama-cpp/references/optimization.md +0 -89
  28. EvoScientist/skills/llama-cpp/references/quantization.md +0 -213
  29. EvoScientist/skills/llama-cpp/references/server.md +0 -125
  30. EvoScientist/skills/lm-evaluation-harness/SKILL.md +0 -490
  31. EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +0 -490
  32. EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +0 -488
  33. EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +0 -602
  34. EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +0 -519
  35. EvoScientist/skills/ml-paper-writing/SKILL.md +0 -937
  36. EvoScientist/skills/ml-paper-writing/references/checklists.md +0 -361
  37. EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +0 -562
  38. EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +0 -367
  39. EvoScientist/skills/ml-paper-writing/references/sources.md +0 -159
  40. EvoScientist/skills/ml-paper-writing/references/writing-guide.md +0 -476
  41. EvoScientist/skills/ml-paper-writing/templates/README.md +0 -251
  42. EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +0 -534
  43. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +0 -144
  44. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +0 -952
  45. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +0 -111
  46. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +0 -1493
  47. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +0 -315
  48. EvoScientist/skills/ml-paper-writing/templates/acl/README.md +0 -50
  49. EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +0 -312
  50. EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +0 -377
  51. EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +0 -101
  52. EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +0 -1940
  53. EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +0 -26
  54. EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +0 -70
  55. EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +0 -326
  56. EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +0 -3
  57. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +0 -11
  58. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +0 -1440
  59. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
  60. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +0 -218
  61. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +0 -305
  62. EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +0 -485
  63. EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +0 -508
  64. EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +0 -1246
  65. EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +0 -485
  66. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +0 -24
  67. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +0 -1440
  68. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
  69. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +0 -246
  70. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +0 -414
  71. EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +0 -508
  72. EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +0 -1246
  73. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +0 -79
  74. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +0 -201
  75. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +0 -75
  76. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
  77. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +0 -662
  78. EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +0 -864
  79. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +0 -1443
  80. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +0 -767
  81. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
  82. EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +0 -36
  83. EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +0 -53
  84. EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +0 -38
  85. EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +0 -382
  86. EvoScientist/skills/peft/SKILL.md +0 -431
  87. EvoScientist/skills/peft/references/advanced-usage.md +0 -514
  88. EvoScientist/skills/peft/references/troubleshooting.md +0 -480
  89. EvoScientist/skills/ray-data/SKILL.md +0 -326
  90. EvoScientist/skills/ray-data/references/integration.md +0 -82
  91. EvoScientist/skills/ray-data/references/transformations.md +0 -83
  92. EvoScientist/skills/skill-creator/LICENSE.txt +0 -202
  93. EvoScientist/skills/skill-creator/SKILL.md +0 -356
  94. EvoScientist/skills/skill-creator/references/output-patterns.md +0 -82
  95. EvoScientist/skills/skill-creator/references/workflows.md +0 -28
  96. EvoScientist/skills/skill-creator/scripts/init_skill.py +0 -303
  97. EvoScientist/skills/skill-creator/scripts/package_skill.py +0 -110
  98. EvoScientist/skills/skill-creator/scripts/quick_validate.py +0 -95
  99. EvoScientist/skills_manager.py +0 -392
  100. EvoScientist/stream/display.py +0 -604
  101. EvoScientist/stream/events.py +0 -415
  102. EvoScientist/stream/state.py +0 -343
  103. evoscientist-0.0.1.dev3.dist-info/METADATA +0 -321
  104. evoscientist-0.0.1.dev3.dist-info/RECORD +0 -113
  105. evoscientist-0.0.1.dev3.dist-info/entry_points.txt +0 -5
  106. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/WHEEL +0 -0
  107. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  108. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/top_level.txt +0 -0
--- a/EvoScientist/skills/flash-attention/references/benchmarks.md
+++ /dev/null
@@ -1,215 +0,0 @@
- # Performance Benchmarks
-
- ## Contents
- - Speed comparisons across GPUs
- - Memory usage analysis
- - Scaling with sequence length
- - Training vs inference performance
- - Flash Attention versions comparison
-
- ## Speed comparisons across GPUs
-
- ### A100 80GB (Ampere)
-
- **Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
-
- | Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
- |------------|----------|--------------|--------------|---------------|
- | 512 | 1.2 | 0.9 | N/A | 1.3x |
- | 1024 | 3.8 | 1.4 | N/A | 2.7x |
- | 2048 | 14.2 | 4.8 | N/A | 3.0x |
- | 4096 | 55.1 | 17.3 | N/A | 3.2x |
- | 8192 | 218.5 | 66.2 | N/A | 3.3x |
-
- ### H100 80GB (Hopper)
-
- **Forward pass time** (milliseconds, same config):
-
- | Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
- |------------|----------|--------------|---------------------|--------------------|--------------|
- | 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
- | 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
- | 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
- | 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
- | 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
-
- **Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
-
- ### A10G 24GB (Ampere)
-
- **Forward pass time** (milliseconds, batch=4):
-
- | Seq Length | Standard | Flash Attn 2 | Speedup |
- |------------|----------|--------------|---------|
- | 512 | 2.1 | 1.6 | 1.3x |
- | 1024 | 6.8 | 2.8 | 2.4x |
- | 2048 | 25.9 | 9.4 | 2.8x |
- | 4096 | 102.1 | 35.2 | 2.9x |
-
- ## Memory usage analysis
-
- ### GPU memory consumption (batch=8, heads=32, dim=64)
-
- **Standard attention memory**:
-
- | Seq Length | Attention Matrix | KV Cache | Total | Notes |
- |------------|------------------|----------|-------|-------|
- | 512 | 8 MB | 32 MB | 40 MB | Manageable |
- | 2048 | 128 MB | 128 MB | 256 MB | Growing |
- | 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
- | 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
-
- **Flash Attention 2 memory**:
-
- | Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
- |------------|---------------------|----------|-------|-----------|
- | 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
- | 2048 | 0 MB | 128 MB | 128 MB | 50% |
- | 8192 | 0 MB | 512 MB | 512 MB | 80% |
- | 32768 | 0 MB | 2048 MB | 2 GB | 94% |
-
- **Key insight**: Flash Attention never materializes the full attention matrix, saving O(N²) memory.
-
- ### Memory scaling comparison
-
- **Llama 2 7B model memory** (float16, batch=1):
-
- | Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
- |----------------|-------------------|-------------------|-------------------|
- | 2K | 3.2 GB | 2.1 GB | Both: Yes |
- | 4K | 5.8 GB | 2.8 GB | Both: Yes |
- | 8K | 12.1 GB | 4.2 GB | Both: Yes |
- | 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
- | 32K | OOM | 14.2 GB | Only Flash: Yes |
-
- ### Training memory (Llama 2 7B, batch=4)
-
- | Context | Standard (GB) | Flash Attn (GB) | Reduction |
- |---------|---------------|-----------------|-----------|
- | 2K | 18.2 | 12.4 | 32% |
- | 4K | 34.8 | 16.8 | 52% |
- | 8K | OOM (>40GB) | 26.2 | Fits! |
-
- ## Scaling with sequence length
-
- ### Computational complexity
-
- **Standard attention**:
- - Time: O(N² × d)
- - Memory: O(N² + N × d)
-
- **Flash Attention**:
- - Time: O(N² × d) (same, but with better constants)
- - Memory: O(N × d) (linear!)
-
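To see what the complexity difference means in practice, here is a rough, illustrative estimate (a sketch under the stated shapes, not part of the deleted file) of per-layer activation memory with batch B, heads H, sequence length N, and head dim d in fp16:

```python
# Back-of-the-envelope activation memory for one attention layer (fp16 = 2 bytes/element).
# Illustrative only: real totals also include the KV cache, MLP activations, etc.

def standard_attn_bytes(B, H, N, d, elem=2):
    # Materializes the full N x N score matrix per head, plus Q/K/V.
    return B * H * N * N * elem + 3 * B * H * N * d * elem

def flash_attn_bytes(B, H, N, d, elem=2):
    # Score tiles stay in on-chip SRAM; only the O(N * d) tensors reach HBM.
    return 3 * B * H * N * d * elem

B, H, d = 8, 32, 64
for N in (2048, 4096, 8192, 16384):
    ratio = standard_attn_bytes(B, H, N, d) / flash_attn_bytes(B, H, N, d)
    print(f"N={N:5d}: standard attention needs ~{ratio:.0f}x the activation memory of Flash Attention")
```

The ratio grows roughly as N / (3d), i.e. linearly in sequence length, which is why the gap in the tables above widens so quickly at long context.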
- ### Empirical scaling (A100, batch=1, heads=32, dim=64)
-
- **Time per token (milliseconds)**:
-
- | Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
- |----------|-----|-----|-----|-----|-----|------|
- | Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
- | Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
- | Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
-
- **Observation**: The speedup grows roughly linearly with sequence length, doubling each time the sequence length doubles.
-
- ### Memory per token (MB)
-
- | Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
- |----------|-----|-----|-----|-----|-----|------|
- | Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
- | Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
-
- **Observation**: Flash Attention memory per token is constant!
-
- ## Training vs inference performance
-
- ### Training (forward + backward, Llama 2 7B, A100)
-
- | Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
- |-------------|------------------------|--------------------------|---------|
- | 4 × 2K | 1.2 | 3.1 | 2.6x |
- | 8 × 2K | 2.1 | 5.8 | 2.8x |
- | 4 × 4K | 0.4 | 1.3 | 3.3x |
- | 8 × 4K | OOM | 2.4 | Enabled |
- | 2 × 8K | 0.1 | 0.4 | 4.0x |
-
- ### Inference (generation, Llama 2 7B, A100)
-
- | Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
- |----------------|----------------------|-------------------------|---------|
- | 512 | 48 | 52 | 1.1x |
- | 2K | 42 | 62 | 1.5x |
- | 4K | 31 | 58 | 1.9x |
- | 8K | 18 | 51 | 2.8x |
- | 16K | OOM | 42 | Enabled |
-
- **Note**: The inference speedup is less dramatic than in training because autoregressive generation is memory-bound: each step is dominated by KV-cache reads rather than attention math.
-
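To make the memory-bound point concrete, here is a rough sketch (assuming a Llama-2-7B-like shape of 32 layers, 32 KV heads, head dim 128 in fp16; treat the numbers as ballpark, not measurements) of how much KV cache every generated token has to re-read:

```python
# Why generation is memory-bound: each new token re-reads the entire KV cache.
layers, kv_heads, head_dim, elem_bytes = 32, 32, 128, 2   # assumed Llama-2-7B-like config, fp16
kv_bytes_per_token = 2 * layers * kv_heads * head_dim * elem_bytes   # K and V

for context in (512, 2048, 8192, 16384):
    cache_gib = context * kv_bytes_per_token / 2**30
    print(f"context={context:5d}: ~{cache_gib:.2f} GiB of KV cache streamed per generated token")
```

At long contexts the bandwidth spent streaming the cache, not the attention FLOPs, sets the per-token latency, so Flash Attention's compute savings translate into smaller relative gains than in training.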
- ## Flash Attention versions comparison
-
- ### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
-
- | Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
- |--------|-----|-----|------------|-----------|
- | Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
- | Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
- | TFLOPS | 180 | 420 | 740 | 1150 |
- | GPU util % | 35% | 55% | 75% | 82% |
-
- **Key improvements**:
- - FA2: 2.3x faster than FA1 (better parallelism)
- - FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- - FA3 (FP8): 2.6x faster than FA2 (low precision)
-
- ### Features by version
-
- | Feature | FA1 | FA2 | FA3 |
- |---------|-----|-----|-----|
- | Basic attention | ✅ | ✅ | ✅ |
- | Causal masking | ✅ | ✅ | ✅ |
- | Multi-query attention | ❌ | ✅ | ✅ |
- | Sliding window | ❌ | ✅ | ✅ |
- | Paged KV cache | ❌ | ✅ | ✅ |
- | FP8 support | ❌ | ❌ | ✅ (H100 only) |
- | Work partitioning | Basic | Advanced | Optimal |
-
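Several of the features in the table above are exposed as arguments when calling the flash-attn kernel directly. A minimal sketch, assuming flash-attn 2.3+ (where `window_size` was added) and a CUDA GPU; shapes follow the library's `(batch, seqlen, nheads, headdim)` convention:

```python
import torch
from flash_attn import flash_attn_func

# fp16 tensors on the GPU, shape (batch, seqlen, nheads, headdim)
q = torch.randn(2, 4096, 32, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 4096, 32, 64, dtype=torch.float16, device="cuda")  # fewer KV heads here would give MQA/GQA
v = torch.randn(2, 4096, 32, 64, dtype=torch.float16, device="cuda")

# Causal masking plus a sliding window over the last 1024 tokens.
out = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
```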
- ## Real-world model benchmarks
-
- ### Llama 2 models (A100 80GB, batch=4, seq=2048)
-
- | Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
- |-------|--------|------------------------|--------------------------|---------|
- | Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
- | Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
- | Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
-
- ### GPT-style models (seq=1024)
-
- | Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
- |-------|----------------------|-------------------------|---------|
- | GPT-2 (124M) | 520 | 680 | 1.3x |
- | GPT-J (6B) | 42 | 98 | 2.3x |
- | GPT-NeoX (20B) | 8 | 22 | 2.75x |
-
- ## Recommendations by use case
-
- **Training large models (>7B parameters)**:
- - Use Flash Attention 2 on A100
- - Use Flash Attention 3 FP8 on H100 for maximum speed
- - Expected: 2.5-3x speedup
-
- **Long context inference (>4K tokens)**:
- - Flash Attention essential (enables contexts standard attention can't handle)
- - Expected: 2-4x speedup, 5-10x memory reduction
-
- **Short sequences (<512 tokens)**:
- - Flash Attention provides 1.2-1.5x speedup
- - Minimal memory benefit
- - Still worth enabling (no downside)
-
- **Multi-user serving**:
- - Flash Attention reduces per-request memory
- - Allows higher concurrent batch sizes
- - Can serve 2-3x more users on same hardware
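One way to act on these recommendations at load time is to probe the hardware and installed packages before choosing an attention backend. This is a sketch with a hypothetical helper (`pick_attn_implementation` is not part of the package); it assumes Transformers 4.36+ and falls back to PyTorch SDPA when flash-attn is unavailable or the GPU is pre-Ampere:

```python
import importlib.util

import torch
from transformers import AutoModelForCausalLM


def pick_attn_implementation() -> str:
    """Hypothetical helper: choose an attention backend this machine can actually run."""
    if not torch.cuda.is_available():
        return "sdpa"
    major, _ = torch.cuda.get_device_capability()
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    # flash-attn 2 targets Ampere (SM 8.0) and newer GPUs.
    return "flash_attention_2" if (major >= 8 and has_flash_attn) else "sdpa"


model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation=pick_attn_implementation(),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```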
--- a/EvoScientist/skills/flash-attention/references/transformers-integration.md
+++ /dev/null
@@ -1,293 +0,0 @@
- # HuggingFace Transformers Integration
-
- ## Contents
- - Enabling Flash Attention in Transformers
- - Supported model architectures
- - Configuration examples
- - Performance comparisons
- - Troubleshooting model-specific issues
-
- ## Enabling Flash Attention in Transformers
-
- HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
-
- **Simple enable for any supported model**:
- ```python
- import torch
- from transformers import AutoModel
-
- model = AutoModel.from_pretrained(
-     "meta-llama/Llama-2-7b-hf",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16,
-     device_map="auto"
- )
- ```
-
- **Install requirements**:
- ```bash
- pip install "transformers>=4.36"
- pip install flash-attn --no-build-isolation
- ```
-
- ## Supported model architectures
-
- As of Transformers 4.40:
-
- **Fully supported**:
- - Llama / Llama 2 / Llama 3
- - Mistral / Mixtral
- - Falcon
- - GPT-NeoX
- - Phi / Phi-2 / Phi-3
- - Qwen / Qwen2
- - Gemma
- - Starcoder2
- - GPT-J
- - OPT
- - BLOOM
-
- **Partially supported** (encoder-decoder):
- - BART
- - T5 / Flan-T5
- - Whisper
-
- **Check support**:
- ```python
- from transformers import AutoConfig
-
- config = AutoConfig.from_pretrained("model-name")
- print(config._attn_implementation_internal)
- # 'flash_attention_2' if supported
- ```
-
- ## Configuration examples
-
- ### Llama 2 with Flash Attention
-
- ```python
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- model_id = "meta-llama/Llama-2-7b-hf"
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16,
-     device_map="auto"
- )
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
-
- # Generate
- inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
- outputs = model.generate(**inputs, max_length=100)
- print(tokenizer.decode(outputs[0]))
- ```
-
- ### Mistral with Flash Attention for long context
-
- ```python
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- model_id = "mistralai/Mistral-7B-v0.1"
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.bfloat16, # Better for long context
-     device_map="auto",
-     max_position_embeddings=32768 # Extended context
- )
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
-
- # Process long document (32K tokens)
- long_text = "..." * 10000
- inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
- outputs = model.generate(**inputs, max_new_tokens=512)
- ```
-
- ### Fine-tuning with Flash Attention
-
- ```python
- import torch
- from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
-
- model = AutoModelForCausalLM.from_pretrained(
-     "meta-llama/Llama-2-7b-hf",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16
- )
-
- training_args = TrainingArguments(
-     output_dir="./results",
-     per_device_train_batch_size=4,
-     gradient_accumulation_steps=4,
-     num_train_epochs=3,
-     fp16=True, # Must match model dtype
-     optim="adamw_torch_fused" # Fast optimizer
- )
-
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=train_dataset # your prepared dataset
- )
-
- trainer.train()
- ```
-
- ### Multi-GPU training
-
- ```python
- from transformers import AutoModelForCausalLM
- import torch
-
- # Model parallelism with Flash Attention
- model = AutoModelForCausalLM.from_pretrained(
-     "meta-llama/Llama-2-13b-hf",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16,
-     device_map="auto", # Automatic multi-GPU placement
-     max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
- )
- ```
-
- ## Performance comparisons
-
- ### Memory usage (Llama 2 7B, batch=1)
-
- | Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
- |-----------------|-------------------|-------------------|-----------|
- | 512 | 1.2 GB | 0.9 GB | 25% |
- | 2048 | 3.8 GB | 1.4 GB | 63% |
- | 8192 | 14.2 GB | 3.2 GB | 77% |
- | 32768 | OOM (>24GB) | 10.8 GB | Fits! |
-
- ### Speed (tokens/sec, A100 80GB)
-
- | Model | Standard | Flash Attn 2 | Speedup |
- |-------|----------|--------------|---------|
- | Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
- | Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
- | Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
-
- ### Training throughput (samples/sec)
-
- | Model | Batch Size | Standard | Flash Attn 2 | Speedup |
- |-------|------------|----------|--------------|---------|
- | Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
- | Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
- | Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
-
- ## Troubleshooting model-specific issues
-
- ### Issue: Model doesn't support Flash Attention
-
- Check the support list above. If the architecture is not supported, use PyTorch SDPA as a fallback:
-
- ```python
- model = AutoModelForCausalLM.from_pretrained(
-     "model-name",
-     attn_implementation="sdpa", # PyTorch-native SDPA (still faster than eager)
-     torch_dtype=torch.float16
- )
- ```
-
- ### Issue: CUDA out of memory during loading
-
- Reduce memory footprint:
-
- ```python
- model = AutoModelForCausalLM.from_pretrained(
-     "model-name",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16,
-     device_map="auto",
-     max_memory={0: "18GB"}, # Reserve memory for KV cache
-     low_cpu_mem_usage=True
- )
- ```
-
- ### Issue: Slower inference than expected
-
- Ensure dtypes match:
-
- ```python
- # Model and inputs must both be float16/bfloat16
- model = model.to(torch.float16)
- inputs = tokenizer(..., return_tensors="pt").to("cuda")
- inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
-           for k, v in inputs.items()}
- ```
-
- ### Issue: Different outputs vs standard attention
-
- Flash Attention is mathematically equivalent but uses a different summation order, so small floating-point differences (<1e-3) are normal:
-
- ```python
- # Compare outputs (both models must live on the same device as the inputs)
- tokenizer = AutoTokenizer.from_pretrained("model-name")
- model_standard = AutoModelForCausalLM.from_pretrained(
-     "model-name",
-     attn_implementation="eager",
-     torch_dtype=torch.float16
- ).to("cuda")
- model_flash = AutoModelForCausalLM.from_pretrained(
-     "model-name",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.float16
- ).to("cuda")
-
- inputs = tokenizer("Test", return_tensors="pt").to("cuda")
-
- with torch.no_grad():
-     out_standard = model_standard(**inputs).logits
-     out_flash = model_flash(**inputs).logits
-
- diff = (out_standard - out_flash).abs().max()
- print(f"Max diff: {diff:.6f}") # Typically ~1e-4 to 1e-3
- ```
-
- ### Issue: ImportError during model loading
-
- Install flash-attn:
- ```bash
- pip install flash-attn --no-build-isolation
- ```
-
- Or disable Flash Attention:
- ```python
- model = AutoModelForCausalLM.from_pretrained(
-     "model-name",
-     attn_implementation="eager", # Standard PyTorch attention
-     torch_dtype=torch.float16
- )
- ```
-
- ## Best practices
-
- 1. **Always use float16/bfloat16** with Flash Attention (not float32)
- 2. **Set device_map="auto"** for automatic memory management
- 3. **Use bfloat16 for long context** (better numerical stability)
- 4. **Enable gradient checkpointing** for training large models
- 5. **Monitor memory** with `torch.cuda.max_memory_allocated()` (see the short monitoring sketch after the example below)
-
- **Example with all best practices**:
- ```python
- import torch
- from transformers import AutoModelForCausalLM, TrainingArguments
-
- model = AutoModelForCausalLM.from_pretrained(
-     "meta-llama/Llama-2-7b-hf",
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.bfloat16, # Better for training
-     device_map="auto",
-     low_cpu_mem_usage=True
- )
-
- # Enable gradient checkpointing for memory
- model.gradient_checkpointing_enable()
-
- # Training with optimizations
- training_args = TrainingArguments(
-     output_dir="./results",
-     per_device_train_batch_size=8,
-     gradient_accumulation_steps=2,
-     bf16=True, # Match model dtype
-     optim="adamw_torch_fused",
-     gradient_checkpointing=True
- )
- ```
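Best practice 5 mentions memory monitoring but none of the examples show it; here is a minimal sketch (illustrative, using the standard PyTorch CUDA memory statistics) of wrapping a step with it:

```python
import torch

# Reset the peak-memory counter, run one step, then read the peak.
torch.cuda.reset_peak_memory_stats()

# ... run a forward/backward pass or model.generate(...) here ...

peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"Peak GPU memory: {peak_gib:.2f} GiB")
```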