EvoScientist 0.0.1.dev3__py3-none-any.whl → 0.1.0rc1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (108)
  1. EvoScientist/EvoScientist.py +17 -49
  2. EvoScientist/backends.py +0 -26
  3. EvoScientist/cli.py +1109 -255
  4. EvoScientist/middleware.py +8 -61
  5. EvoScientist/stream/__init__.py +0 -25
  6. EvoScientist/stream/utils.py +16 -23
  7. EvoScientist/tools.py +0 -64
  8. evoscientist-0.1.0rc1.dist-info/METADATA +199 -0
  9. evoscientist-0.1.0rc1.dist-info/RECORD +21 -0
  10. evoscientist-0.1.0rc1.dist-info/entry_points.txt +2 -0
  11. EvoScientist/memory.py +0 -715
  12. EvoScientist/paths.py +0 -45
  13. EvoScientist/skills/accelerate/SKILL.md +0 -332
  14. EvoScientist/skills/accelerate/references/custom-plugins.md +0 -453
  15. EvoScientist/skills/accelerate/references/megatron-integration.md +0 -489
  16. EvoScientist/skills/accelerate/references/performance.md +0 -525
  17. EvoScientist/skills/bitsandbytes/SKILL.md +0 -411
  18. EvoScientist/skills/bitsandbytes/references/memory-optimization.md +0 -521
  19. EvoScientist/skills/bitsandbytes/references/qlora-training.md +0 -521
  20. EvoScientist/skills/bitsandbytes/references/quantization-formats.md +0 -447
  21. EvoScientist/skills/find-skills/SKILL.md +0 -133
  22. EvoScientist/skills/find-skills/scripts/install_skill.py +0 -211
  23. EvoScientist/skills/flash-attention/SKILL.md +0 -367
  24. EvoScientist/skills/flash-attention/references/benchmarks.md +0 -215
  25. EvoScientist/skills/flash-attention/references/transformers-integration.md +0 -293
  26. EvoScientist/skills/llama-cpp/SKILL.md +0 -258
  27. EvoScientist/skills/llama-cpp/references/optimization.md +0 -89
  28. EvoScientist/skills/llama-cpp/references/quantization.md +0 -213
  29. EvoScientist/skills/llama-cpp/references/server.md +0 -125
  30. EvoScientist/skills/lm-evaluation-harness/SKILL.md +0 -490
  31. EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +0 -490
  32. EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +0 -488
  33. EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +0 -602
  34. EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +0 -519
  35. EvoScientist/skills/ml-paper-writing/SKILL.md +0 -937
  36. EvoScientist/skills/ml-paper-writing/references/checklists.md +0 -361
  37. EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +0 -562
  38. EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +0 -367
  39. EvoScientist/skills/ml-paper-writing/references/sources.md +0 -159
  40. EvoScientist/skills/ml-paper-writing/references/writing-guide.md +0 -476
  41. EvoScientist/skills/ml-paper-writing/templates/README.md +0 -251
  42. EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +0 -534
  43. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +0 -144
  44. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +0 -952
  45. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +0 -111
  46. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +0 -1493
  47. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +0 -315
  48. EvoScientist/skills/ml-paper-writing/templates/acl/README.md +0 -50
  49. EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +0 -312
  50. EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +0 -377
  51. EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +0 -101
  52. EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +0 -1940
  53. EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +0 -26
  54. EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +0 -70
  55. EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +0 -326
  56. EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +0 -3
  57. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +0 -11
  58. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +0 -1440
  59. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
  60. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +0 -218
  61. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +0 -305
  62. EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +0 -485
  63. EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +0 -508
  64. EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +0 -1246
  65. EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +0 -485
  66. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +0 -24
  67. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +0 -1440
  68. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
  69. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +0 -246
  70. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +0 -414
  71. EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +0 -508
  72. EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +0 -1246
  73. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +0 -79
  74. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +0 -201
  75. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +0 -75
  76. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
  77. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +0 -662
  78. EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +0 -864
  79. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +0 -1443
  80. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +0 -767
  81. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
  82. EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +0 -36
  83. EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +0 -53
  84. EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +0 -38
  85. EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +0 -382
  86. EvoScientist/skills/peft/SKILL.md +0 -431
  87. EvoScientist/skills/peft/references/advanced-usage.md +0 -514
  88. EvoScientist/skills/peft/references/troubleshooting.md +0 -480
  89. EvoScientist/skills/ray-data/SKILL.md +0 -326
  90. EvoScientist/skills/ray-data/references/integration.md +0 -82
  91. EvoScientist/skills/ray-data/references/transformations.md +0 -83
  92. EvoScientist/skills/skill-creator/LICENSE.txt +0 -202
  93. EvoScientist/skills/skill-creator/SKILL.md +0 -356
  94. EvoScientist/skills/skill-creator/references/output-patterns.md +0 -82
  95. EvoScientist/skills/skill-creator/references/workflows.md +0 -28
  96. EvoScientist/skills/skill-creator/scripts/init_skill.py +0 -303
  97. EvoScientist/skills/skill-creator/scripts/package_skill.py +0 -110
  98. EvoScientist/skills/skill-creator/scripts/quick_validate.py +0 -95
  99. EvoScientist/skills_manager.py +0 -392
  100. EvoScientist/stream/display.py +0 -604
  101. EvoScientist/stream/events.py +0 -415
  102. EvoScientist/stream/state.py +0 -343
  103. evoscientist-0.0.1.dev3.dist-info/METADATA +0 -321
  104. evoscientist-0.0.1.dev3.dist-info/RECORD +0 -113
  105. evoscientist-0.0.1.dev3.dist-info/entry_points.txt +0 -5
  106. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/WHEEL +0 -0
  107. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  108. {evoscientist-0.0.1.dev3.dist-info → evoscientist-0.1.0rc1.dist-info}/top_level.txt +0 -0
EvoScientist/skills/find-skills/scripts/install_skill.py
@@ -1,211 +0,0 @@
- #!/usr/bin/env python3
- """Install a skill from GitHub into a local skills directory.
-
- Self-contained installer — no external dependencies beyond git.
-
- Usage examples:
-     # Install from a GitHub URL (auto-detects repo, ref, path)
-     python install_skill.py --url https://github.com/anthropics/skills/tree/main/excel
-
-     # Install from repo + path
-     python install_skill.py --repo anthropics/skills --path excel
-
-     # Install multiple skills from the same repo
-     python install_skill.py --repo anthropics/skills --path excel --path pdf
-
-     # Install with a specific git ref
-     python install_skill.py --repo org/repo --path my-skill --ref v2.0
- """
-
- from __future__ import annotations
-
- import argparse
- import os
- import re
- import shutil
- import subprocess
- import sys
- import tempfile
-
-
- def parse_github_url(url: str) -> tuple[str, str | None, str | None]:
-     """Parse a GitHub URL into (repo, ref, path).
-
-     Supports formats:
-         https://github.com/owner/repo
-         https://github.com/owner/repo/tree/main/path/to/skill
-         github.com/owner/repo/tree/branch/path
-         owner/repo@skill-name (shorthand from skills.sh)
-
-     Returns:
-         (repo, ref_or_none, path_or_none)
-     """
-     # Shorthand: owner/repo@path
-     if "@" in url and "://" not in url:
-         repo, path = url.split("@", 1)
-         return repo.strip(), None, path.strip()
-
-     # Strip protocol and github.com prefix
-     cleaned = re.sub(r"^https?://", "", url)
-     cleaned = re.sub(r"^github\.com/", "", cleaned)
-     cleaned = cleaned.rstrip("/")
-
-     # Match: owner/repo/tree/ref/path...
-     m = re.match(r"^([^/]+/[^/]+)/tree/([^/]+)(?:/(.+))?$", cleaned)
-     if m:
-         return m.group(1), m.group(2), m.group(3)
-
-     # Match: owner/repo (no tree)
-     m = re.match(r"^([^/]+/[^/]+)$", cleaned)
-     if m:
-         return m.group(1), None, None
-
-     raise ValueError(f"Cannot parse GitHub URL: {url}")
-
-
- def clone_repo(repo: str, ref: str | None, dest: str) -> None:
-     """Shallow-clone a GitHub repo."""
-     clone_url = f"https://github.com/{repo}.git"
-     cmd = ["git", "clone", "--depth", "1"]
-     if ref:
-         cmd += ["--branch", ref]
-     cmd += [clone_url, dest]
-
-     result = subprocess.run(cmd, capture_output=True, text=True)
-     if result.returncode != 0:
-         raise RuntimeError(f"git clone failed: {result.stderr.strip()}")
-
-
- def copy_skill(src: str, dest_dir: str) -> str:
-     """Copy a skill directory to the destination.
-
-     Returns:
-         The skill name (directory basename).
-     """
-     skill_name = os.path.basename(src.rstrip("/"))
-     target = os.path.join(dest_dir, skill_name)
-
-     if os.path.exists(target):
-         shutil.rmtree(target)
-         print(f" Replaced existing: {skill_name}")
-
-     shutil.copytree(src, target)
-     return skill_name
-
-
- def validate_skill(path: str) -> bool:
-     """Check that a directory looks like a valid skill (has SKILL.md)."""
-     return os.path.isfile(os.path.join(path, "SKILL.md"))
-
-
- def install(
-     repo: str,
-     paths: list[str],
-     ref: str | None,
-     dest: str,
- ) -> list[str]:
-     """Install skill(s) from a GitHub repo.
-
-     Returns:
-         List of installed skill names.
-     """
-     os.makedirs(dest, exist_ok=True)
-     installed: list[str] = []
-
-     with tempfile.TemporaryDirectory(prefix="skill-install-") as tmp:
-         clone_dir = os.path.join(tmp, "repo")
-         print(f"Cloning {repo}" + (f" @{ref}" if ref else "") + "...")
-         clone_repo(repo, ref, clone_dir)
-
-         if not paths:
-             # No path specified — treat entire repo as a single skill
-             if validate_skill(clone_dir):
-                 name = copy_skill(clone_dir, dest)
-                 installed.append(name)
-             else:
-                 # List top-level directories that look like skills
-                 for entry in sorted(os.listdir(clone_dir)):
-                     entry_path = os.path.join(clone_dir, entry)
-                     if os.path.isdir(entry_path) and validate_skill(entry_path):
-                         name = copy_skill(entry_path, dest)
-                         installed.append(name)
-
-                 if not installed:
-                     print("No valid skills found in repository root.", file=sys.stderr)
-         else:
-             for p in paths:
-                 skill_path = os.path.join(clone_dir, p.strip("/"))
-                 if not os.path.isdir(skill_path):
-                     print(f" Path not found: {p}", file=sys.stderr)
-                     continue
-                 if not validate_skill(skill_path):
-                     print(f" No SKILL.md in: {p}", file=sys.stderr)
-                     continue
-                 name = copy_skill(skill_path, dest)
-                 installed.append(name)
-
-     return installed
-
-
- def main() -> int:
-     parser = argparse.ArgumentParser(
-         description="Install skills from GitHub into a local skills directory.",
-     )
-     src = parser.add_mutually_exclusive_group(required=True)
-     src.add_argument(
-         "--url",
-         help="GitHub URL (e.g. https://github.com/owner/repo/tree/main/skill-name)",
-     )
-     src.add_argument(
-         "--repo",
-         help="GitHub repo (e.g. owner/repo)",
-     )
-     parser.add_argument(
-         "--path",
-         action="append",
-         default=[],
-         help="Path to skill inside repo (repeatable)",
-     )
-     parser.add_argument(
-         "--ref",
-         default=None,
-         help="Git branch or tag (default: repo default branch)",
-     )
-     parser.add_argument(
-         "--dest",
-         default="./skills",
-         help="Destination directory (default: ./skills)",
-     )
-
-     args = parser.parse_args()
-     dest = args.dest
-
-     # Parse source
-     if args.url:
-         repo, ref, path = parse_github_url(args.url)
-         ref = args.ref or ref
-         paths = [path] if path else args.path
-     else:
-         repo = args.repo
-         ref = args.ref
-         paths = args.path
-
-     try:
-         installed = install(repo, paths, ref, dest)
-     except RuntimeError as e:
-         print(f"Error: {e}", file=sys.stderr)
-         return 1
-
-     if installed:
-         print(f"\nInstalled {len(installed)} skill(s) to {dest}/:")
-         for name in installed:
-             print(f" - {name}")
-     else:
-         print("No skills were installed.", file=sys.stderr)
-         return 1
-
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
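For reference, the URL forms the removed installer accepted are easiest to see from `parse_github_url` above. A minimal sketch of the `(repo, ref, path)` tuples it returns for each form, derived from the deleted code (the import is hypothetical, since the module no longer ships in 0.1.0rc1):

```python
# Illustrative only: expected results of the removed parse_github_url helper.
from install_skill import parse_github_url  # hypothetical import of the deleted module

# Full GitHub URL with branch and skill path
assert parse_github_url("https://github.com/anthropics/skills/tree/main/excel") == ("anthropics/skills", "main", "excel")
# skills.sh shorthand: owner/repo@path, no ref
assert parse_github_url("anthropics/skills@excel") == ("anthropics/skills", None, "excel")
# Bare repo: default branch, whole repo treated as the skill source
assert parse_github_url("github.com/owner/repo") == ("owner/repo", None, None)
```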
EvoScientist/skills/flash-attention/SKILL.md
@@ -1,367 +0,0 @@
- ---
- name: flash-attention
- description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
- version: 1.0.0
- author: Orchestra Research
- license: MIT
- tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
- dependencies: [flash-attn, torch, transformers]
- ---
-
- # Flash Attention - Fast Memory-Efficient Attention
-
- ## Quick start
-
- Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
-
- **PyTorch native (easiest, PyTorch 2.2+)**:
- ```python
- import torch
- import torch.nn.functional as F
-
- q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
- k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
- v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
-
- # Automatically uses Flash Attention if available
- out = F.scaled_dot_product_attention(q, k, v)
- ```
-
- **flash-attn library (more features)**:
- ```bash
- pip install flash-attn --no-build-isolation
- ```
-
- ```python
- from flash_attn import flash_attn_func
-
- # q, k, v: [batch, seqlen, nheads, headdim]
- out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
- ```
-
- ## Common workflows
-
- ### Workflow 1: Enable in existing PyTorch model
-
- Copy this checklist:
-
- ```
- Flash Attention Integration:
- - [ ] Step 1: Check PyTorch version (≥2.2)
- - [ ] Step 2: Enable Flash Attention backend
- - [ ] Step 3: Verify speedup with profiling
- - [ ] Step 4: Test accuracy matches baseline
- ```
-
- **Step 1: Check PyTorch version**
-
- ```bash
- python -c "import torch; print(torch.__version__)"
- # Should be ≥2.2.0
- ```
-
- If <2.2, upgrade:
- ```bash
- pip install --upgrade torch
- ```
-
- **Step 2: Enable Flash Attention backend**
-
- Replace standard attention:
- ```python
- # Before (standard attention)
- attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
- out = attn_weights @ v
-
- # After (Flash Attention)
- import torch.nn.functional as F
- out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
- ```
-
- Force Flash Attention backend:
- ```python
- with torch.backends.cuda.sdp_kernel(
-     enable_flash=True,
-     enable_math=False,
-     enable_mem_efficient=False
- ):
-     out = F.scaled_dot_product_attention(q, k, v)
- ```
-
- **Step 3: Verify speedup with profiling**
-
- ```python
- import torch.utils.benchmark as benchmark
-
- def test_attention(use_flash):
-     q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
-
-     if use_flash:
-         with torch.backends.cuda.sdp_kernel(enable_flash=True):
-             return F.scaled_dot_product_attention(q, k, v)
-     else:
-         attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
-         return attn @ v
-
- # Benchmark
- t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
- t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
-
- print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
- print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
- ```
-
- Expected: 2-4x speedup for sequences >512 tokens.
-
- **Step 4: Test accuracy matches baseline**
-
- ```python
- # Compare outputs
- q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
-
- # Flash Attention
- out_flash = F.scaled_dot_product_attention(q, k, v)
-
- # Standard attention
- attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
- out_standard = attn_weights @ v
-
- # Check difference
- diff = (out_flash - out_standard).abs().max()
- print(f"Max difference: {diff:.6f}")
- # Should be <1e-3 for float16
- ```
-
- ### Workflow 2: Use flash-attn library for advanced features
-
- For multi-query attention, sliding window, or H100 FP8.
-
- Copy this checklist:
-
- ```
- flash-attn Library Setup:
- - [ ] Step 1: Install flash-attn library
- - [ ] Step 2: Modify attention code
- - [ ] Step 3: Enable advanced features
- - [ ] Step 4: Benchmark performance
- ```
-
- **Step 1: Install flash-attn library**
-
- ```bash
- # NVIDIA GPUs (CUDA 12.0+)
- pip install flash-attn --no-build-isolation
-
- # Verify installation
- python -c "from flash_attn import flash_attn_func; print('Success')"
- ```
-
- **Step 2: Modify attention code**
-
- ```python
- from flash_attn import flash_attn_func
-
- # Input: [batch_size, seq_len, num_heads, head_dim]
- # Transpose from [batch, heads, seq, dim] if needed
- q = q.transpose(1, 2) # [batch, seq, heads, dim]
- k = k.transpose(1, 2)
- v = v.transpose(1, 2)
-
- out = flash_attn_func(
-     q, k, v,
-     dropout_p=0.1,
-     causal=True, # For autoregressive models
-     window_size=(-1, -1), # No sliding window
-     softmax_scale=None # Auto-scale
- )
-
- out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
- ```
-
- **Step 3: Enable advanced features**
-
- Multi-query attention (shared K/V across heads):
- ```python
- from flash_attn import flash_attn_func
-
- # q: [batch, seq, num_q_heads, dim]
- # k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
- out = flash_attn_func(q, k, v) # Automatically handles MQA
- ```
-
- Sliding window attention (local attention):
- ```python
- # Only attend to window of 256 tokens before/after
- out = flash_attn_func(
-     q, k, v,
-     window_size=(256, 256), # (left, right) window
-     causal=True
- )
- ```
-
- **Step 4: Benchmark performance**
-
- ```python
- import torch
- from flash_attn import flash_attn_func
- import time
-
- q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
-
- # Warmup
- for _ in range(10):
-     _ = flash_attn_func(q, k, v)
-
- # Benchmark
- torch.cuda.synchronize()
- start = time.time()
- for _ in range(100):
-     out = flash_attn_func(q, k, v)
- torch.cuda.synchronize()
- end = time.time()
-
- print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
- print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
- ```
-
- ### Workflow 3: H100 FP8 optimization (FlashAttention-3)
-
- For maximum performance on H100 GPUs.
-
- ```
- FP8 Setup:
- - [ ] Step 1: Verify H100 GPU available
- - [ ] Step 2: Install flash-attn with FP8 support
- - [ ] Step 3: Convert inputs to FP8
- - [ ] Step 4: Run with FP8 attention
- ```
-
- **Step 1: Verify H100 GPU**
-
- ```bash
- nvidia-smi --query-gpu=name --format=csv
- # Should show "H100" or "H800"
- ```
-
- **Step 2: Install flash-attn with FP8 support**
-
- ```bash
- pip install flash-attn --no-build-isolation
- # FP8 support included for H100
- ```
-
- **Step 3: Convert inputs to FP8**
-
- ```python
- import torch
-
- q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
- k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
- v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
-
- # Convert to float8_e4m3 (FP8)
- q_fp8 = q.to(torch.float8_e4m3fn)
- k_fp8 = k.to(torch.float8_e4m3fn)
- v_fp8 = v.to(torch.float8_e4m3fn)
- ```
-
- **Step 4: Run with FP8 attention**
-
- ```python
- from flash_attn import flash_attn_func
-
- # FlashAttention-3 automatically uses FP8 kernels on H100
- out = flash_attn_func(q_fp8, k_fp8, v_fp8)
- # Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
- ```
-
- ## When to use vs alternatives
-
- **Use Flash Attention when:**
- - Training transformers with sequences >512 tokens
- - Running inference with long context (>2K tokens)
- - GPU memory constrained (OOM with standard attention)
- - Need 2-4x speedup without accuracy loss
- - Using PyTorch 2.2+ or can install flash-attn
-
- **Use alternatives instead:**
- - **Standard attention**: Sequences <256 tokens (overhead not worth it)
- - **xFormers**: Need more attention variants (not just speed)
- - **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
-
- ## Common issues
-
- **Issue: ImportError: cannot import flash_attn**
-
- Install with no-build-isolation flag:
- ```bash
- pip install flash-attn --no-build-isolation
- ```
-
- Or install CUDA toolkit first:
- ```bash
- conda install cuda -c nvidia
- pip install flash-attn --no-build-isolation
- ```
-
- **Issue: Slower than expected (no speedup)**
-
- Flash Attention benefits increase with sequence length:
- - <512 tokens: Minimal speedup (10-20%)
- - 512-2K tokens: 2-3x speedup
- - >2K tokens: 3-4x speedup
-
- Check sequence length is sufficient.
-
- **Issue: RuntimeError: CUDA error**
-
- Verify GPU supports Flash Attention:
- ```python
- import torch
- print(torch.cuda.get_device_capability())
- # Should be ≥(7, 5) for Turing+
- ```
-
- Flash Attention requires:
- - Ampere (A100, A10): ✅ Full support
- - Turing (T4): ✅ Supported
- - Volta (V100): ❌ Not supported
-
- **Issue: Accuracy degradation**
-
- Check dtype is float16 or bfloat16 (not float32):
- ```python
- q = q.to(torch.float16) # Or torch.bfloat16
- ```
-
- Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
-
- ## Advanced topics
-
- **Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
-
- **Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
-
- **Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
-
- **Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
-
- ## Hardware requirements
-
- - **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- - **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- - **CUDA**: 12.0+ (11.8 minimum)
- - **PyTorch**: 2.2+ for native support
-
- **Not supported**: V100 (Volta), CPU inference
-
- ## Resources
-
- - Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- - Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- - Blog: https://tridao.me/blog/2024/flash3/
- - GitHub: https://github.com/Dao-AILab/flash-attention
- - PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
-
-
-
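The Advanced topics section of the removed SKILL.md points to references/transformers-integration.md, which this release also removes. For orientation only, and not taken from that file, a minimal sketch of requesting Flash Attention 2 through Hugging Face Transformers (assumes transformers ≥ 4.36 with flash-attn installed and an Ampere-or-newer GPU; the model name is just an example):

```python
# Minimal sketch; any architecture with a Flash Attention 2 implementation works.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,                # FA2 needs fp16/bf16, matching the dtype note in the SKILL.md above
    attn_implementation="flash_attention_2",  # fails with a clear error if flash-attn is not installed
    device_map="cuda",
)
```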