evoscientist-0.0.1.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. EvoScientist/EvoScientist.py +157 -0
  2. EvoScientist/__init__.py +24 -0
  3. EvoScientist/__main__.py +4 -0
  4. EvoScientist/backends.py +392 -0
  5. EvoScientist/cli.py +1553 -0
  6. EvoScientist/middleware.py +35 -0
  7. EvoScientist/prompts.py +277 -0
  8. EvoScientist/skills/accelerate/SKILL.md +332 -0
  9. EvoScientist/skills/accelerate/references/custom-plugins.md +453 -0
  10. EvoScientist/skills/accelerate/references/megatron-integration.md +489 -0
  11. EvoScientist/skills/accelerate/references/performance.md +525 -0
  12. EvoScientist/skills/bitsandbytes/SKILL.md +411 -0
  13. EvoScientist/skills/bitsandbytes/references/memory-optimization.md +521 -0
  14. EvoScientist/skills/bitsandbytes/references/qlora-training.md +521 -0
  15. EvoScientist/skills/bitsandbytes/references/quantization-formats.md +447 -0
  16. EvoScientist/skills/find-skills/SKILL.md +133 -0
  17. EvoScientist/skills/find-skills/scripts/install_skill.py +211 -0
  18. EvoScientist/skills/flash-attention/SKILL.md +367 -0
  19. EvoScientist/skills/flash-attention/references/benchmarks.md +215 -0
  20. EvoScientist/skills/flash-attention/references/transformers-integration.md +293 -0
  21. EvoScientist/skills/llama-cpp/SKILL.md +258 -0
  22. EvoScientist/skills/llama-cpp/references/optimization.md +89 -0
  23. EvoScientist/skills/llama-cpp/references/quantization.md +213 -0
  24. EvoScientist/skills/llama-cpp/references/server.md +125 -0
  25. EvoScientist/skills/lm-evaluation-harness/SKILL.md +490 -0
  26. EvoScientist/skills/lm-evaluation-harness/references/api-evaluation.md +490 -0
  27. EvoScientist/skills/lm-evaluation-harness/references/benchmark-guide.md +488 -0
  28. EvoScientist/skills/lm-evaluation-harness/references/custom-tasks.md +602 -0
  29. EvoScientist/skills/lm-evaluation-harness/references/distributed-eval.md +519 -0
  30. EvoScientist/skills/ml-paper-writing/SKILL.md +937 -0
  31. EvoScientist/skills/ml-paper-writing/references/checklists.md +361 -0
  32. EvoScientist/skills/ml-paper-writing/references/citation-workflow.md +562 -0
  33. EvoScientist/skills/ml-paper-writing/references/reviewer-guidelines.md +367 -0
  34. EvoScientist/skills/ml-paper-writing/references/sources.md +159 -0
  35. EvoScientist/skills/ml-paper-writing/references/writing-guide.md +476 -0
  36. EvoScientist/skills/ml-paper-writing/templates/README.md +251 -0
  37. EvoScientist/skills/ml-paper-writing/templates/aaai2026/README.md +534 -0
  38. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-supp.tex +144 -0
  39. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026-unified-template.tex +952 -0
  40. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bib +111 -0
  41. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.bst +1493 -0
  42. EvoScientist/skills/ml-paper-writing/templates/aaai2026/aaai2026.sty +315 -0
  43. EvoScientist/skills/ml-paper-writing/templates/acl/README.md +50 -0
  44. EvoScientist/skills/ml-paper-writing/templates/acl/acl.sty +312 -0
  45. EvoScientist/skills/ml-paper-writing/templates/acl/acl_latex.tex +377 -0
  46. EvoScientist/skills/ml-paper-writing/templates/acl/acl_lualatex.tex +101 -0
  47. EvoScientist/skills/ml-paper-writing/templates/acl/acl_natbib.bst +1940 -0
  48. EvoScientist/skills/ml-paper-writing/templates/acl/anthology.bib.txt +26 -0
  49. EvoScientist/skills/ml-paper-writing/templates/acl/custom.bib +70 -0
  50. EvoScientist/skills/ml-paper-writing/templates/acl/formatting.md +326 -0
  51. EvoScientist/skills/ml-paper-writing/templates/colm2025/README.md +3 -0
  52. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bib +11 -0
  53. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.bst +1440 -0
  54. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.pdf +0 -0
  55. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.sty +218 -0
  56. EvoScientist/skills/ml-paper-writing/templates/colm2025/colm2025_conference.tex +305 -0
  57. EvoScientist/skills/ml-paper-writing/templates/colm2025/fancyhdr.sty +485 -0
  58. EvoScientist/skills/ml-paper-writing/templates/colm2025/math_commands.tex +508 -0
  59. EvoScientist/skills/ml-paper-writing/templates/colm2025/natbib.sty +1246 -0
  60. EvoScientist/skills/ml-paper-writing/templates/iclr2026/fancyhdr.sty +485 -0
  61. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bib +24 -0
  62. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.bst +1440 -0
  63. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.pdf +0 -0
  64. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.sty +246 -0
  65. EvoScientist/skills/ml-paper-writing/templates/iclr2026/iclr2026_conference.tex +414 -0
  66. EvoScientist/skills/ml-paper-writing/templates/iclr2026/math_commands.tex +508 -0
  67. EvoScientist/skills/ml-paper-writing/templates/iclr2026/natbib.sty +1246 -0
  68. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithm.sty +79 -0
  69. EvoScientist/skills/ml-paper-writing/templates/icml2026/algorithmic.sty +201 -0
  70. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.bib +75 -0
  71. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.pdf +0 -0
  72. EvoScientist/skills/ml-paper-writing/templates/icml2026/example_paper.tex +662 -0
  73. EvoScientist/skills/ml-paper-writing/templates/icml2026/fancyhdr.sty +864 -0
  74. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.bst +1443 -0
  75. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml2026.sty +767 -0
  76. EvoScientist/skills/ml-paper-writing/templates/icml2026/icml_numpapers.pdf +0 -0
  77. EvoScientist/skills/ml-paper-writing/templates/neurips2025/Makefile +36 -0
  78. EvoScientist/skills/ml-paper-writing/templates/neurips2025/extra_pkgs.tex +53 -0
  79. EvoScientist/skills/ml-paper-writing/templates/neurips2025/main.tex +38 -0
  80. EvoScientist/skills/ml-paper-writing/templates/neurips2025/neurips.sty +382 -0
  81. EvoScientist/skills/peft/SKILL.md +431 -0
  82. EvoScientist/skills/peft/references/advanced-usage.md +514 -0
  83. EvoScientist/skills/peft/references/troubleshooting.md +480 -0
  84. EvoScientist/skills/ray-data/SKILL.md +326 -0
  85. EvoScientist/skills/ray-data/references/integration.md +82 -0
  86. EvoScientist/skills/ray-data/references/transformations.md +83 -0
  87. EvoScientist/skills/skill-creator/LICENSE.txt +202 -0
  88. EvoScientist/skills/skill-creator/SKILL.md +356 -0
  89. EvoScientist/skills/skill-creator/references/output-patterns.md +82 -0
  90. EvoScientist/skills/skill-creator/references/workflows.md +28 -0
  91. EvoScientist/skills/skill-creator/scripts/init_skill.py +303 -0
  92. EvoScientist/skills/skill-creator/scripts/package_skill.py +110 -0
  93. EvoScientist/skills/skill-creator/scripts/quick_validate.py +95 -0
  94. EvoScientist/stream/__init__.py +53 -0
  95. EvoScientist/stream/emitter.py +94 -0
  96. EvoScientist/stream/formatter.py +168 -0
  97. EvoScientist/stream/tracker.py +115 -0
  98. EvoScientist/stream/utils.py +255 -0
  99. EvoScientist/subagent.yaml +147 -0
  100. EvoScientist/tools.py +135 -0
  101. EvoScientist/utils.py +207 -0
  102. evoscientist-0.0.1.dev1.dist-info/METADATA +222 -0
  103. evoscientist-0.0.1.dev1.dist-info/RECORD +107 -0
  104. evoscientist-0.0.1.dev1.dist-info/WHEEL +5 -0
  105. evoscientist-0.0.1.dev1.dist-info/entry_points.txt +2 -0
  106. evoscientist-0.0.1.dev1.dist-info/licenses/LICENSE +21 -0
  107. evoscientist-0.0.1.dev1.dist-info/top_level.txt +1 -0
EvoScientist/skills/find-skills/scripts/install_skill.py
@@ -0,0 +1,211 @@
+ #!/usr/bin/env python3
+ """Install a skill from GitHub into a local skills directory.
+
+ Self-contained installer — no external dependencies beyond git.
+
+ Usage examples:
+     # Install from a GitHub URL (auto-detects repo, ref, path)
+     python install_skill.py --url https://github.com/anthropics/skills/tree/main/excel
+
+     # Install from repo + path
+     python install_skill.py --repo anthropics/skills --path excel
+
+     # Install multiple skills from the same repo
+     python install_skill.py --repo anthropics/skills --path excel --path pdf
+
+     # Install with a specific git ref
+     python install_skill.py --repo org/repo --path my-skill --ref v2.0
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import re
+ import shutil
+ import subprocess
+ import sys
+ import tempfile
+
+
+ def parse_github_url(url: str) -> tuple[str, str | None, str | None]:
+     """Parse a GitHub URL into (repo, ref, path).
+
+     Supports formats:
+         https://github.com/owner/repo
+         https://github.com/owner/repo/tree/main/path/to/skill
+         github.com/owner/repo/tree/branch/path
+         owner/repo@skill-name (shorthand from skills.sh)
+
+     Returns:
+         (repo, ref_or_none, path_or_none)
+     """
+     # Shorthand: owner/repo@path
+     if "@" in url and "://" not in url:
+         repo, path = url.split("@", 1)
+         return repo.strip(), None, path.strip()
+
+     # Strip protocol and github.com prefix
+     cleaned = re.sub(r"^https?://", "", url)
+     cleaned = re.sub(r"^github\.com/", "", cleaned)
+     cleaned = cleaned.rstrip("/")
+
+     # Match: owner/repo/tree/ref/path...
+     m = re.match(r"^([^/]+/[^/]+)/tree/([^/]+)(?:/(.+))?$", cleaned)
+     if m:
+         return m.group(1), m.group(2), m.group(3)
+
+     # Match: owner/repo (no tree)
+     m = re.match(r"^([^/]+/[^/]+)$", cleaned)
+     if m:
+         return m.group(1), None, None
+
+     raise ValueError(f"Cannot parse GitHub URL: {url}")
+
+
+ def clone_repo(repo: str, ref: str | None, dest: str) -> None:
+     """Shallow-clone a GitHub repo."""
+     clone_url = f"https://github.com/{repo}.git"
+     cmd = ["git", "clone", "--depth", "1"]
+     if ref:
+         cmd += ["--branch", ref]
+     cmd += [clone_url, dest]
+
+     result = subprocess.run(cmd, capture_output=True, text=True)
+     if result.returncode != 0:
+         raise RuntimeError(f"git clone failed: {result.stderr.strip()}")
+
+
+ def copy_skill(src: str, dest_dir: str) -> str:
+     """Copy a skill directory to the destination.
+
+     Returns:
+         The skill name (directory basename).
+     """
+     skill_name = os.path.basename(src.rstrip("/"))
+     target = os.path.join(dest_dir, skill_name)
+
+     if os.path.exists(target):
+         shutil.rmtree(target)
+         print(f" Replaced existing: {skill_name}")
+
+     shutil.copytree(src, target)
+     return skill_name
+
+
+ def validate_skill(path: str) -> bool:
+     """Check that a directory looks like a valid skill (has SKILL.md)."""
+     return os.path.isfile(os.path.join(path, "SKILL.md"))
+
+
+ def install(
+     repo: str,
+     paths: list[str],
+     ref: str | None,
+     dest: str,
+ ) -> list[str]:
+     """Install skill(s) from a GitHub repo.
+
+     Returns:
+         List of installed skill names.
+     """
+     os.makedirs(dest, exist_ok=True)
+     installed: list[str] = []
+
+     with tempfile.TemporaryDirectory(prefix="skill-install-") as tmp:
+         clone_dir = os.path.join(tmp, "repo")
+         print(f"Cloning {repo}" + (f" @{ref}" if ref else "") + "...")
+         clone_repo(repo, ref, clone_dir)
+
+         if not paths:
+             # No path specified — treat entire repo as a single skill
+             if validate_skill(clone_dir):
+                 name = copy_skill(clone_dir, dest)
+                 installed.append(name)
+             else:
+                 # List top-level directories that look like skills
+                 for entry in sorted(os.listdir(clone_dir)):
+                     entry_path = os.path.join(clone_dir, entry)
+                     if os.path.isdir(entry_path) and validate_skill(entry_path):
+                         name = copy_skill(entry_path, dest)
+                         installed.append(name)
+
+                 if not installed:
+                     print("No valid skills found in repository root.", file=sys.stderr)
+         else:
+             for p in paths:
+                 skill_path = os.path.join(clone_dir, p.strip("/"))
+                 if not os.path.isdir(skill_path):
+                     print(f" Path not found: {p}", file=sys.stderr)
+                     continue
+                 if not validate_skill(skill_path):
+                     print(f" No SKILL.md in: {p}", file=sys.stderr)
+                     continue
+                 name = copy_skill(skill_path, dest)
+                 installed.append(name)
+
+     return installed
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(
+         description="Install skills from GitHub into a local skills directory.",
+     )
+     src = parser.add_mutually_exclusive_group(required=True)
+     src.add_argument(
+         "--url",
+         help="GitHub URL (e.g. https://github.com/owner/repo/tree/main/skill-name)",
+     )
+     src.add_argument(
+         "--repo",
+         help="GitHub repo (e.g. owner/repo)",
+     )
+     parser.add_argument(
+         "--path",
+         action="append",
+         default=[],
+         help="Path to skill inside repo (repeatable)",
+     )
+     parser.add_argument(
+         "--ref",
+         default=None,
+         help="Git branch or tag (default: repo default branch)",
+     )
+     parser.add_argument(
+         "--dest",
+         default="./skills",
+         help="Destination directory (default: ./skills)",
+     )
+
+     args = parser.parse_args()
+     dest = args.dest
+
+     # Parse source
+     if args.url:
+         repo, ref, path = parse_github_url(args.url)
+         ref = args.ref or ref
+         paths = [path] if path else args.path
+     else:
+         repo = args.repo
+         ref = args.ref
+         paths = args.path
+
+     try:
+         installed = install(repo, paths, ref, dest)
+     except RuntimeError as e:
+         print(f"Error: {e}", file=sys.stderr)
+         return 1
+
+     if installed:
+         print(f"\nInstalled {len(installed)} skill(s) to {dest}/:")
+         for name in installed:
+             print(f" - {name}")
+     else:
+         print("No skills were installed.", file=sys.stderr)
+         return 1
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
EvoScientist/skills/flash-attention/SKILL.md
@@ -0,0 +1,367 @@
+ ---
+ name: flash-attention
+ description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training or running transformers with long sequences (>512 tokens), hitting GPU memory limits in attention, or needing faster inference. Supports PyTorch native SDPA, the flash-attn library, H100 FP8, and sliding window attention.
+ version: 1.0.0
+ author: Orchestra Research
+ license: MIT
+ tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
+ dependencies: [flash-attn, torch, transformers]
+ ---
+
+ # Flash Attention - Fast Memory-Efficient Attention
+
+ ## Quick start
+
+ Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
+
+ **PyTorch native (easiest, PyTorch 2.2+)**:
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
+ k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
+ v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
+
+ # Automatically uses Flash Attention if available
+ out = F.scaled_dot_product_attention(q, k, v)
+ ```
+
+ **flash-attn library (more features)**:
+ ```bash
+ pip install flash-attn --no-build-isolation
+ ```
+
+ ```python
+ from flash_attn import flash_attn_func
+
+ # q, k, v: [batch, seqlen, nheads, headdim]
+ out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
+ ```
+
+ ## Common workflows
+
+ ### Workflow 1: Enable in existing PyTorch model
+
+ Copy this checklist:
+
+ ```
+ Flash Attention Integration:
+ - [ ] Step 1: Check PyTorch version (≥2.2)
+ - [ ] Step 2: Enable Flash Attention backend
+ - [ ] Step 3: Verify speedup with profiling
+ - [ ] Step 4: Test accuracy matches baseline
+ ```
+
+ **Step 1: Check PyTorch version**
+
+ ```bash
+ python -c "import torch; print(torch.__version__)"
+ # Should be ≥2.2.0
+ ```
+
+ If <2.2, upgrade:
+ ```bash
+ pip install --upgrade torch
+ ```
+
+ **Step 2: Enable Flash Attention backend**
+
+ Replace standard attention:
+ ```python
+ # Before (standard attention)
+ attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
+ out = attn_weights @ v
+
+ # After (Flash Attention)
+ import torch.nn.functional as F
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+ ```
+
+ Force Flash Attention backend:
+ ```python
+ with torch.backends.cuda.sdp_kernel(
+     enable_flash=True,
+     enable_math=False,
+     enable_mem_efficient=False
+ ):
+     out = F.scaled_dot_product_attention(q, k, v)
+ ```
+
+ **Step 3: Verify speedup with profiling**
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.benchmark as benchmark
+
+ def test_attention(use_flash):
+     q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
+
+     if use_flash:
+         # Restrict SDPA to the Flash Attention backend for a clean comparison
+         with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+             return F.scaled_dot_product_attention(q, k, v)
+     else:
+         attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
+         return attn @ v
+
+ # Benchmark
+ t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
+ t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
+
+ print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
+ print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
+ ```
+
+ Expected: 2-4x speedup for sequences >512 tokens.
+
+ **Step 4: Test accuracy matches baseline**
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Compare outputs
+ q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
+
+ # Flash Attention
+ out_flash = F.scaled_dot_product_attention(q, k, v)
+
+ # Standard attention
+ attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
+ out_standard = attn_weights @ v
+
+ # Check difference
+ diff = (out_flash - out_standard).abs().max()
+ print(f"Max difference: {diff:.6f}")
+ # Should be <1e-3 for float16
+ ```
+
+ ### Workflow 2: Use flash-attn library for advanced features
+
+ For multi-query attention, sliding window, or H100 FP8.
+
+ Copy this checklist:
+
+ ```
+ flash-attn Library Setup:
+ - [ ] Step 1: Install flash-attn library
+ - [ ] Step 2: Modify attention code
+ - [ ] Step 3: Enable advanced features
+ - [ ] Step 4: Benchmark performance
+ ```
+
+ **Step 1: Install flash-attn library**
+
+ ```bash
+ # NVIDIA GPUs (CUDA 12.0+)
+ pip install flash-attn --no-build-isolation
+
+ # Verify installation
+ python -c "from flash_attn import flash_attn_func; print('Success')"
+ ```
+
+ **Step 2: Modify attention code**
+
+ ```python
+ from flash_attn import flash_attn_func
+
+ # Input: [batch_size, seq_len, num_heads, head_dim]
+ # Transpose from [batch, heads, seq, dim] if needed
+ q = q.transpose(1, 2)  # [batch, seq, heads, dim]
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ out = flash_attn_func(
+     q, k, v,
+     dropout_p=0.1,
+     causal=True,  # For autoregressive models
+     window_size=(-1, -1),  # No sliding window
+     softmax_scale=None  # Auto-scale
+ )
+
+ out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
+ ```
+
+ **Step 3: Enable advanced features**
+
+ Multi-query attention (shared K/V across heads):
+ ```python
+ from flash_attn import flash_attn_func
+
+ # q: [batch, seq, num_q_heads, dim]
+ # k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
+ out = flash_attn_func(q, k, v)  # Automatically handles MQA
+ ```
+
+ Sliding window attention (local attention):
+ ```python
+ # Only attend to window of 256 tokens before/after
+ out = flash_attn_func(
+     q, k, v,
+     window_size=(256, 256),  # (left, right) window
+     causal=True
+ )
+ ```
+
+ **Step 4: Benchmark performance**
+
+ ```python
+ import torch
+ from flash_attn import flash_attn_func
+ import time
+
+ q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
+
+ # Warmup
+ for _ in range(10):
+     _ = flash_attn_func(q, k, v)
+
+ # Benchmark
+ torch.cuda.synchronize()
+ start = time.time()
+ for _ in range(100):
+     out = flash_attn_func(q, k, v)
+ torch.cuda.synchronize()
+ end = time.time()
+
+ print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
+ print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
+ ```
+
+ ### Workflow 3: H100 FP8 optimization (FlashAttention-3)
+
+ For maximum performance on H100 GPUs.
+
+ ```
+ FP8 Setup:
+ - [ ] Step 1: Verify H100 GPU available
+ - [ ] Step 2: Install flash-attn with FP8 support
+ - [ ] Step 3: Convert inputs to FP8
+ - [ ] Step 4: Run with FP8 attention
+ ```
+
+ **Step 1: Verify H100 GPU**
+
+ ```bash
+ nvidia-smi --query-gpu=name --format=csv
+ # Should show "H100" or "H800"
+ ```
+
+ **Step 2: Install flash-attn with FP8 support**
+
+ ```bash
+ pip install flash-attn --no-build-isolation
+ # Note: FP8 (FlashAttention-3) kernels target H100/H800 and, depending on the
+ # flash-attn version, may need to be built from the FlashAttention repo rather
+ # than the PyPI wheel; check the project README for current FP8 status
+ ```
+
+ **Step 3: Convert inputs to FP8**
+
+ ```python
+ import torch
+
+ q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
+ k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
+ v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
+
+ # Convert to float8_e4m3 (FP8)
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ k_fp8 = k.to(torch.float8_e4m3fn)
+ v_fp8 = v.to(torch.float8_e4m3fn)
+ ```
+
+ **Step 4: Run with FP8 attention**
+
+ ```python
+ from flash_attn import flash_attn_func
+
+ # Requires a FlashAttention-3 build with FP8 support on H100; FlashAttention-2
+ # kernels only accept fp16/bf16 inputs and will reject FP8 tensors
+ out = flash_attn_func(q_fp8, k_fp8, v_fp8)
+ # Reported result on H100: ~1.2 PFLOPS, 1.5-2x faster than FP16
+ ```
+
+ ## When to use vs alternatives
+
+ **Use Flash Attention when:**
+ - Training transformers with sequences >512 tokens
+ - Running inference with long context (>2K tokens)
+ - GPU memory constrained (OOM with standard attention)
+ - Need 2-4x speedup without accuracy loss
+ - Using PyTorch 2.2+ or can install flash-attn
+
+ **Use alternatives instead:**
+ - **Standard attention**: Sequences <256 tokens (overhead not worth it)
+ - **xFormers**: Need more attention variants (not just speed)
+ - **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
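+
+ As a quick sanity check before choosing, you can ask PyTorch which SDPA backends are currently enabled; a minimal sketch (these query functions live under `torch.backends.cuda` in PyTorch 2.x):
+
+ ```python
+ import torch
+
+ # True/False per scaled_dot_product_attention backend (enabled, not necessarily the one used)
+ print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
+ print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
+ print("math:         ", torch.backends.cuda.math_sdp_enabled())
+ ```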
+
+ ## Common issues
+
+ **Issue: ImportError: cannot import flash_attn**
+
+ Install with no-build-isolation flag:
+ ```bash
+ pip install flash-attn --no-build-isolation
+ ```
+
+ Or install CUDA toolkit first:
+ ```bash
+ conda install cuda -c nvidia
+ pip install flash-attn --no-build-isolation
+ ```
+
+ **Issue: Slower than expected (no speedup)**
+
+ Flash Attention benefits increase with sequence length:
+ - <512 tokens: Minimal speedup (10-20%)
+ - 512-2K tokens: 2-3x speedup
+ - >2K tokens: 3-4x speedup
+
+ Check that your sequence length is long enough to benefit; the sketch below times both SDPA backends at several lengths.
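+
+ A minimal sketch for confirming this on your own hardware (assumes a CUDA GPU and PyTorch 2.2+, and reuses the `sdp_kernel` context manager shown earlier; exact numbers vary by GPU):
+
+ ```python
+ import time
+ import torch
+ import torch.nn.functional as F
+
+ def time_sdpa(seq_len, flash, iters=20):
+     q, k, v = [torch.randn(1, 16, seq_len, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
+     with torch.backends.cuda.sdp_kernel(enable_flash=flash, enable_math=not flash, enable_mem_efficient=False):
+         torch.cuda.synchronize()
+         start = time.time()
+         for _ in range(iters):
+             F.scaled_dot_product_attention(q, k, v)
+         torch.cuda.synchronize()
+     return (time.time() - start) / iters * 1000  # ms per call
+
+ for seq_len in (256, 512, 2048, 4096):
+     flash_ms, math_ms = time_sdpa(seq_len, True), time_sdpa(seq_len, False)
+     print(f"seq_len={seq_len:5d}: flash {flash_ms:.2f} ms, math {math_ms:.2f} ms, speedup {math_ms/flash_ms:.1f}x")
+ ```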
+
+ **Issue: RuntimeError: CUDA error**
+
+ Verify GPU supports Flash Attention:
+ ```python
+ import torch
+ print(torch.cuda.get_device_capability())
+ # FlashAttention-2 needs ≥(8, 0) (Ampere+); (7, 5) Turing works only with FlashAttention 1.x
+ ```
+
+ GPU support:
+ - Ampere and newer (A100, A10, H100): ✅ Full support
+ - Turing (T4): ⚠️ FlashAttention 1.x only
+ - Volta (V100): ❌ Not supported
+
+ **Issue: Accuracy degradation**
+
+ Check dtype is float16 or bfloat16 (not float32):
+ ```python
+ q = q.to(torch.float16)  # Or torch.bfloat16
+ ```
+
+ Flash Attention kernels run in float16/bfloat16 for speed; float32 is not supported (with SDPA, float32 inputs silently fall back to a non-flash backend).
+
+ ## Advanced topics
+
+ **Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
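+
+ A minimal example of that integration (the checkpoint name is illustrative; `attn_implementation="flash_attention_2"` requires flash-attn to be installed and a fp16/bf16 dtype):
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "meta-llama/Llama-3.1-8B",            # illustrative checkpoint
+     torch_dtype=torch.bfloat16,           # FA2 kernels need fp16/bf16
+     attn_implementation="flash_attention_2",
+ ).to("cuda")
+ ```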
+
+ **Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
+
+ **Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
+
+ **Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
+
+ ## Hardware requirements
+
+ - **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
+ - **VRAM**: No more than standard attention (the N×N attention matrix is never materialized)
+ - **CUDA**: 11.8 minimum, 12.x recommended
+ - **PyTorch**: 2.2+ for native support
+
+ **Not supported**: V100 (Volta), CPU inference
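+
+ A quick way to check an environment against these requirements (a sketch; adjust the thresholds to the backend you plan to use):
+
+ ```python
+ import torch
+
+ print("PyTorch:", torch.__version__)          # 2.2+ for native SDPA flash support
+ print("CUDA runtime:", torch.version.cuda)    # 11.8 minimum, 12.x recommended
+ if torch.cuda.is_available():
+     major, minor = torch.cuda.get_device_capability()
+     print("GPU:", torch.cuda.get_device_name(0), f"(sm_{major}{minor})")
+     print("Ampere or newer:", major >= 8)      # needed for FlashAttention-2
+ else:
+     print("No CUDA GPU detected - Flash Attention is unavailable on CPU")
+ ```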
+
+ ## Resources
+
+ - Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
+ - Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
+ - Blog: https://tridao.me/blog/2024/flash3/
+ - GitHub: https://github.com/Dao-AILab/flash-attention
+ - PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+