autosteer 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autosteer-1.0.0/.gitignore +38 -0
- autosteer-1.0.0/PKG-INFO +90 -0
- autosteer-1.0.0/README.md +68 -0
- autosteer-1.0.0/auto_steer.py +1037 -0
- autosteer-1.0.0/pyproject.toml +35 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Python
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
*.egg-info/
|
|
6
|
+
*.egg
|
|
7
|
+
dist/
|
|
8
|
+
build/
|
|
9
|
+
*.whl
|
|
10
|
+
|
|
11
|
+
# Environment variables
|
|
12
|
+
.env
|
|
13
|
+
|
|
14
|
+
# Virtual environments
|
|
15
|
+
.venv/
|
|
16
|
+
venv/
|
|
17
|
+
env/
|
|
18
|
+
|
|
19
|
+
# IDE
|
|
20
|
+
.idea/
|
|
21
|
+
.vscode/
|
|
22
|
+
*.swp
|
|
23
|
+
*.swo
|
|
24
|
+
*~
|
|
25
|
+
.DS_Store
|
|
26
|
+
|
|
27
|
+
# Testing
|
|
28
|
+
.pytest_cache/
|
|
29
|
+
.coverage
|
|
30
|
+
htmlcov/
|
|
31
|
+
.mypy_cache/
|
|
32
|
+
|
|
33
|
+
# Claude Code working memory
|
|
34
|
+
.memory/
|
|
35
|
+
|
|
36
|
+
# Distribution
|
|
37
|
+
*.tar.gz
|
|
38
|
+
.docs/
|
autosteer-1.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: autosteer
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Research direction generator for autoresearch — analyzes experiment history and suggests next steps
|
|
5
|
+
Project-URL: Homepage, https://github.com/dean0x/autolab
|
|
6
|
+
Project-URL: Repository, https://github.com/dean0x/autolab
|
|
7
|
+
Project-URL: Issues, https://github.com/dean0x/autolab/issues
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Keywords: autoresearch,experiment-suggestions,gpt,karpathy,pretraining
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Environment :: Console
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Requires-Dist: click>=8.0
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
|
|
23
|
+
# autosteer
|
|
24
|
+
|
|
25
|
+
Research direction generator for [autoresearch](https://github.com/karpathy/autoresearch). Analyzes experiment history and suggests data-driven next steps instead of random-walking through experiment space.
|
|
26
|
+
|
|
27
|
+
## Install
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install autosteer
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Usage
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
# Get 5 suggestions (default)
|
|
37
|
+
autosteer --results results.tsv
|
|
38
|
+
|
|
39
|
+
# Explore mode — favor untried directions (good when stuck)
|
|
40
|
+
autosteer --results results.tsv --strategy explore
|
|
41
|
+
|
|
42
|
+
# Exploit mode — double down on what works
|
|
43
|
+
autosteer --results results.tsv --strategy exploit
|
|
44
|
+
|
|
45
|
+
# More suggestions
|
|
46
|
+
autosteer --results results.tsv --num-suggestions 10
|
|
47
|
+
|
|
48
|
+
# Quick numbered list
|
|
49
|
+
autosteer --results results.tsv --quiet
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## Strategy Modes
|
|
53
|
+
|
|
54
|
+
| Mode | When to Use |
|
|
55
|
+
|------|-------------|
|
|
56
|
+
| `auto` | Default. Balances explore/exploit based on experiment count. |
|
|
57
|
+
| `explore` | Early research, or stuck after 3+ discards. Favors untried categories. |
|
|
58
|
+
| `exploit` | You have proven wins. Doubles down on what works. |
|
|
59
|
+
|
|
60
|
+
## Output
|
|
61
|
+
|
|
62
|
+
Each suggestion includes:
|
|
63
|
+
- **Badge**: `[EXPLORE]` or `[EXPLOIT]` indicating category status
|
|
64
|
+
- **Risk level**: `low`, `medium`, or `high`
|
|
65
|
+
- **Rationale**: Why this was ranked where it is
|
|
66
|
+
|
|
67
|
+
```
|
|
68
|
+
[1] [EXPLOIT] Tune learning rate warmup schedule risk: low
|
|
69
|
+
Rationale: Learning rate experiments have 3 keeps in 4 attempts.
|
|
70
|
+
[2] [EXPLORE] Try rotary position embeddings risk: medium
|
|
71
|
+
Rationale: Positional encoding category untested. High potential.
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## How It Works
|
|
75
|
+
|
|
76
|
+
- 20 built-in research directions specific to GPT pretraining
|
|
77
|
+
- Categorizes past experiments (architecture, hyperparams, optimizer, etc.)
|
|
78
|
+
- Keyword deduplication: won't re-suggest failed directions
|
|
79
|
+
- Git integration: reads diffs to classify experiments automatically
|
|
80
|
+
- Strategy-weighted scoring that adapts to experiment count
|
|
81
|
+
|
|
82
|
+
## Requirements
|
|
83
|
+
|
|
84
|
+
- Python >= 3.10
|
|
85
|
+
- A `results.tsv` file from autoresearch
|
|
86
|
+
- Git repository (for diff-based experiment classification)
|
|
87
|
+
|
|
88
|
+
## License
|
|
89
|
+
|
|
90
|
+
MIT
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# autosteer
|
|
2
|
+
|
|
3
|
+
Research direction generator for [autoresearch](https://github.com/karpathy/autoresearch). Analyzes experiment history and suggests data-driven next steps instead of random-walking through experiment space.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install autosteer
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Usage
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
# Get 5 suggestions (default)
|
|
15
|
+
autosteer --results results.tsv
|
|
16
|
+
|
|
17
|
+
# Explore mode — favor untried directions (good when stuck)
|
|
18
|
+
autosteer --results results.tsv --strategy explore
|
|
19
|
+
|
|
20
|
+
# Exploit mode — double down on what works
|
|
21
|
+
autosteer --results results.tsv --strategy exploit
|
|
22
|
+
|
|
23
|
+
# More suggestions
|
|
24
|
+
autosteer --results results.tsv --num-suggestions 10
|
|
25
|
+
|
|
26
|
+
# Quick numbered list
|
|
27
|
+
autosteer --results results.tsv --quiet
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## Strategy Modes
|
|
31
|
+
|
|
32
|
+
| Mode | When to Use |
|
|
33
|
+
|------|-------------|
|
|
34
|
+
| `auto` | Default. Balances explore/exploit based on experiment count. |
|
|
35
|
+
| `explore` | Early research, or stuck after 3+ discards. Favors untried categories. |
|
|
36
|
+
| `exploit` | You have proven wins. Doubles down on what works. |
|
|
37
|
+
|
|
38
|
+
## Output
|
|
39
|
+
|
|
40
|
+
Each suggestion includes:
|
|
41
|
+
- **Badge**: `[EXPLORE]` or `[EXPLOIT]` indicating category status
|
|
42
|
+
- **Risk level**: `low`, `medium`, or `high`
|
|
43
|
+
- **Rationale**: Why this was ranked where it is
|
|
44
|
+
|
|
45
|
+
```
|
|
46
|
+
[1] [EXPLOIT] Tune learning rate warmup schedule risk: low
|
|
47
|
+
Rationale: Learning rate experiments have 3 keeps in 4 attempts.
|
|
48
|
+
[2] [EXPLORE] Try rotary position embeddings risk: medium
|
|
49
|
+
Rationale: Positional encoding category untested. High potential.
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## How It Works
|
|
53
|
+
|
|
54
|
+
- 20 built-in research directions specific to GPT pretraining
|
|
55
|
+
- Categorizes past experiments (architecture, hyperparams, optimizer, etc.)
|
|
56
|
+
- Keyword deduplication: won't re-suggest failed directions
|
|
57
|
+
- Git integration: reads diffs to classify experiments automatically
|
|
58
|
+
- Strategy-weighted scoring that adapts to experiment count
|
|
59
|
+
|
|
60
|
+
## Requirements
|
|
61
|
+
|
|
62
|
+
- Python >= 3.10
|
|
63
|
+
- A `results.tsv` file from autoresearch
|
|
64
|
+
- Git repository (for diff-based experiment classification)
|
|
65
|
+
|
|
66
|
+
## License
|
|
67
|
+
|
|
68
|
+
MIT
|
|
@@ -0,0 +1,1037 @@
|
|
|
1
|
+
"""
|
|
2
|
+
auto-steer: Research direction generator for autoresearch.
|
|
3
|
+
|
|
4
|
+
Analyzes experiment history (results.tsv + git diffs) and generates
|
|
5
|
+
smart next-step suggestions for the research agent.
|
|
6
|
+
|
|
7
|
+
Companion tool for https://github.com/karpathy/autoresearch
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import csv
|
|
13
|
+
import json
|
|
14
|
+
import re
|
|
15
|
+
import subprocess
|
|
16
|
+
import sys
|
|
17
|
+
import textwrap
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from enum import Enum
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Optional
|
|
22
|
+
|
|
23
|
+
import click
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
# Output infrastructure
|
|
28
|
+
# ---------------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class OutputConfig:
    """Immutable rendering options for terminal output.

    Attributes:
        color: When true, ``styled`` decorates text with ``click.style``;
            when false, text is returned untouched.
        quiet: Terse-output flag. It is not read in this class; presumably it
            backs the README's ``--quiet`` "quick numbered list" mode. TODO confirm.
    """

    color: bool
    quiet: bool

    def styled(self, text: str, **kwargs) -> str:
        """Return *text* passed through ``click.style(**kwargs)``, or as-is when color is off."""
        if not self.color:
            return text
        return click.style(text, **kwargs)


# Unicode status glyphs (escaped so the source stays ASCII-only).
SYM_KEEP = "\u2714"   # ✔ heavy check mark
SYM_FAIL = "\u2718"   # ✘ heavy ballot X
SYM_CRASH = "\u2620"  # ☠ skull and crossbones
SYM_WARN = "\u26A0"   # ⚠ warning sign
SYM_ARROW = "\u2192"  # → rightwards arrow
SYM_STAR = "\u2605"   # ★ black star
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
# Domain types
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
|
|
50
|
+
class Category(Enum):
    """Research area that an experiment or a known direction belongs to.

    Member order is significant: ``compute_category_stats`` iterates
    ``Category`` to build its result dict, so reordering members reorders
    the stats.
    """

    ARCHITECTURE = "architecture"
    HYPERPARAMS = "hyperparams"
    OPTIMIZER = "optimizer"
    REGULARIZATION = "regularization"
    ACTIVATION = "activation"
    EMBEDDING = "embedding"
    EFFICIENCY = "efficiency"
    OTHER = "other"  # fallback when classify_experiment finds no keyword match
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Status(Enum):
    """Outcome of an experiment, taken from the ``status`` column of results.tsv.

    ``parse_results_tsv`` strips and lower-cases the cell before the lookup,
    and maps unrecognised values to ``DISCARD``.
    """

    KEEP = "keep"
    DISCARD = "discard"
    CRASH = "crash"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class RiskLevel(Enum):
    """Qualitative risk rating attached to known directions and suggestions."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class SuggestionKind(Enum):
    """Badge for a suggestion.

    Values are upper-case to match the ``[EXPLORE]`` / ``[EXPLOIT]`` badges
    shown in the README's sample output.
    """

    EXPLORE = "EXPLORE"
    EXPLOIT = "EXPLOIT"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Strategy(Enum):
    """Suggestion strategy; the values match the ``--strategy`` modes in the README.

    Per the README: ``auto`` balances explore/exploit by experiment count,
    ``explore`` favors untried categories, and ``exploit`` doubles down on
    what works. The selection logic itself is outside this view.
    """

    AUTO = "auto"
    EXPLORE = "explore"
    EXPLOIT = "exploit"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(frozen=True)
class Experiment:
    """One row of results.tsv, optionally enriched with git information.

    ``parse_results_tsv`` fills the first five fields. ``category`` and
    ``diff_text`` keep their defaults until ``enrich_experiments_with_git``
    rebuilds the record with the classified category and the commit's diff.
    Field order is the positional constructor signature; do not reorder.
    """

    commit: str  # commit identifier from the ``commit`` column
    # Validation metric; lower is better (improvement is computed as
    # baseline - val_bpb). 0.0 when the cell is unparseable. Presumably bits per byte.
    val_bpb: float
    memory_gb: float  # from the optional ``memory_gb`` column; 0.0 when absent/unparseable
    status: Status
    description: str  # free-text description of the change
    category: Category = Category.OTHER
    diff_text: str = ""  # git diff of train.py for this commit, or "" when unavailable
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass(frozen=True)
class CategoryStats:
    """Aggregated outcome counts for one category of non-baseline experiments."""

    category: Category
    total: int  # number of experiments classified into this category
    keeps: int
    discards: int
    crashes: int
    avg_improvement_pct: float  # average % improvement over baseline when kept

    @property
    def success_rate_pct(self) -> float:
        """Share of this category's experiments that were kept, as a percentage.

        Returns 0.0 for an empty category instead of dividing by zero.
        """
        return 0.0 if self.total == 0 else self.keeps / self.total * 100.0
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass(frozen=True)
class Suggestion:
    """A ranked next-step recommendation produced by the analysis.

    NOTE(review): the code that builds these is outside this view. The field
    notes below come from names, inline comments and the README sample output.
    """

    rank: int  # presumably the 1-based display position ("[1]", "[2]" in the README)
    kind: SuggestionKind  # EXPLORE / EXPLOIT badge
    title: str
    category: Category
    risk: RiskLevel
    expected_range: tuple[float, float]  # (low%, high%) improvement
    reasoning: str  # presumably rendered as the "Rationale:" line — TODO confirm
    priority_score: float  # internal score for ranking
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@dataclass(frozen=True)
class AnalysisResult:
    """Bundle of everything one analysis run produces.

    NOTE(review): the producer and consumers of this type are outside this
    view; the field notes are hedged accordingly. ``frozen`` prevents
    reassigning fields but does not make the contained list/dict immutable.
    """

    experiments: list[Experiment]  # presumably the parsed (and git-enriched) rows
    stats_by_category: dict[Category, CategoryStats]
    suggestions: list[Suggestion]
    strategy_label: str  # human-readable strategy description — TODO confirm format
    baseline_bpb: Optional[float]  # None presumably means no baseline was available
    best_bpb: Optional[float]  # presumably the lowest val_bpb seen; None if none
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# ---------------------------------------------------------------------------
|
|
134
|
+
# Known good directions — the knowledge base
|
|
135
|
+
# ---------------------------------------------------------------------------
|
|
136
|
+
|
|
137
|
+
@dataclass(frozen=True)
class KnownDirection:
    """A curated research direction from the built-in knowledge base (``KNOWN_DIRECTIONS``)."""

    category: Category
    title: str  # short headline for the direction
    description: str  # concrete advice, quoting specific train.py settings
    risk: RiskLevel
    expected_range: tuple[float, float]  # presumably (low%, high%) improvement, as on Suggestion
    keywords: tuple[str, ...]  # used to check if already tried
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Built-in knowledge base: 20 research directions for GPT pretraining,
# grouped by category. Each entry's ``keywords`` are used to detect whether a
# direction was already tried. ``_direction_already_tried`` lower-cases each
# experiment's description + diff and turns "_" into " " (as do the keywords)
# before comparing; presumably by substring, but the comparison is outside this view.
#
# NOTE(review): the descriptions quote hard-coded train.py defaults (e.g.
# ASPECT_RATIO=64, WARMDOWN_RATIO=0.5, WEIGHT_DECAY=0.2). They will be stale
# if the user's train.py differs.
# NOTE(review): some categories here disagree with ``_CATEGORY_KEYWORDS``,
# which drives experiment classification:
#   * "Tune weight decay" is OPTIMIZER, but "weight_decay" is a
#     REGULARIZATION keyword there.
#   * "Tune embedding learning rates" is HYPERPARAMS, but "embed"/"embedding"
#     are EMBEDDING keywords checked before HYPERPARAMS.
# Per-category stats for these directions may therefore come from a
# different bucket than the experiments that tested them.
KNOWN_DIRECTIONS: tuple[KnownDirection, ...] = (
    # Architecture
    KnownDirection(
        category=Category.ARCHITECTURE,
        title="Adjust depth/width ratio",
        description=(
            "Current ASPECT_RATIO=64 gives dim=512 at depth=8. Try ASPECT_RATIO=96 "
            "for a wider but same-depth model, or DEPTH=12 with ASPECT_RATIO=48 "
            "for deeper but narrower. Width tends to help more on short training runs."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.3, 0.8),
        keywords=("aspect_ratio", "depth", "width", "model_dim"),
    ),
    KnownDirection(
        category=Category.ARCHITECTURE,
        title="Modify GQA head configuration",
        description=(
            "Adjust HEAD_DIM or number of KV heads. Fewer KV heads (more aggressive "
            "GQA) reduces memory and may allow a larger model. Try HEAD_DIM=64 for "
            "more heads or HEAD_DIM=256 for fewer."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.2, 0.6),
        keywords=("head_dim", "gqa", "kv_head", "num_head", "query_head"),
    ),
    KnownDirection(
        category=Category.ARCHITECTURE,
        title="Change sliding window pattern",
        description=(
            "Current WINDOW_PATTERN='SSSL'. Try 'SSLL' for more full-attention layers, "
            "or 'SLSL' for alternating. More full-attention layers capture longer "
            "dependencies at the cost of memory."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.4),
        # NOTE(review): the bare "window" keyword is broad; any experiment
        # mentioning a window may mark this direction as tried.
        keywords=("window_pattern", "sliding window", "window", "attention pattern"),
    ),
    KnownDirection(
        category=Category.ARCHITECTURE,
        title="Adjust MLP expansion ratio",
        description=(
            "Default MLP is 4x expansion. Try 3x to free parameters for more layers, "
            "or 8/3x (~2.67x) which is the SwiGLU-optimal ratio. The freed parameters "
            "can go toward extra depth."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.2, 0.5),
        keywords=("mlp", "expansion", "ffn", "feed_forward", "intermediate_size"),
    ),
    # Hyperparams
    KnownDirection(
        category=Category.HYPERPARAMS,
        title="Tune matrix learning rate",
        description=(
            "MATRIX_LR has the most leverage on Muon's behavior. Try MATRIX_LR=0.06 "
            "(50% increase) or MATRIX_LR=0.02 (50% decrease). Muon is sensitive to "
            "this — small changes can have outsized impact."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.3, 0.5),
        keywords=("matrix_lr", "learning_rate", "muon_lr"),
    ),
    KnownDirection(
        category=Category.HYPERPARAMS,
        title="Tune embedding learning rates",
        description=(
            "EMBEDDING_LR=0.6 and UNEMBEDDING_LR=0.004 are quite asymmetric. Try "
            "reducing EMBEDDING_LR to 0.3 or increasing UNEMBEDDING_LR to 0.008. "
            "The large gap suggests room for tuning."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        # NOTE(review): "embed" is broad. With substring comparison, a
        # value-embedding experiment would also mark this direction as tried.
        keywords=("embedding_lr", "unembedding_lr", "embed"),
    ),
    KnownDirection(
        category=Category.HYPERPARAMS,
        title="Add warmup schedule",
        description=(
            "Currently WARMUP_RATIO=0.0 (no warmup). Try WARMUP_RATIO=0.05 for "
            "5% warmup — helps with training stability early on, especially with "
            "large learning rates."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("warmup", "warmup_ratio"),
    ),
    KnownDirection(
        category=Category.HYPERPARAMS,
        title="Adjust warmdown schedule",
        description=(
            "Currently WARMDOWN_RATIO=0.5. Try 0.3 for less cooldown (more time at "
            "peak LR) or 0.7 for more gradual decay. With a 5-minute budget, the "
            "warmdown schedule significantly affects final performance."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.4),
        keywords=("warmdown", "warmdown_ratio", "cooldown", "final_lr"),
    ),
    KnownDirection(
        category=Category.HYPERPARAMS,
        title="Increase batch size",
        description=(
            "Current TOTAL_BATCH_SIZE=2**19 (~524K tokens). Try 2**20 (~1M tokens) "
            "for smoother gradients. Larger batches can help with short training runs "
            "by reducing gradient noise. May need to reduce model size to fit."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.2, 0.5),
        keywords=("batch_size", "total_batch", "device_batch"),
    ),
    # Optimizer
    KnownDirection(
        category=Category.OPTIMIZER,
        title="Tune Muon orthogonalization steps",
        description=(
            "Muon uses Newton-Schulz iterations for polar decomposition (typically "
            "ns_steps=5). Try ns_steps=6 for better orthogonalization quality, or "
            "ns_steps=4 to save compute and allow a slightly larger model."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("ns_steps", "newton", "schulz", "orthogonal", "polar"),
    ),
    KnownDirection(
        category=Category.OPTIMIZER,
        title="Adjust Adam betas",
        description=(
            "Current ADAM_BETAS=(0.8, 0.95). Try (0.9, 0.95) for more momentum "
            "(standard Adam default) or (0.8, 0.99) for longer gradient memory. "
            "Beta1=0.8 is already aggressive — going higher may stabilize training."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("adam_beta", "beta1", "beta2", "momentum"),
    ),
    KnownDirection(
        category=Category.OPTIMIZER,
        title="Tune weight decay",
        description=(
            "Current WEIGHT_DECAY=0.2. Try 0.1 for less regularization (may help "
            "with short training runs where overfitting is not an issue) or 0.3 "
            "for stronger regularization."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("weight_decay",),
    ),
    # Regularization
    KnownDirection(
        category=Category.REGULARIZATION,
        title="Add z-loss regularization",
        description=(
            "Add a small penalty on the log of the softmax partition function. "
            "This stabilizes training and prevents logit drift. Typical coefficient "
            "is 1e-4. Used in PaLM and other large LMs."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("z_loss", "z-loss", "partition", "logit_reg"),
    ),
    KnownDirection(
        category=Category.REGULARIZATION,
        title="Adjust softcap value",
        description=(
            "The current logit softcap prevents extreme values. Try increasing it "
            "(less capping, more expressivity) or decreasing it (more regularization). "
            "Small changes can affect training dynamics significantly."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("softcap",),
    ),
    # Activation
    KnownDirection(
        category=Category.ACTIVATION,
        title="Try SwiGLU activation",
        description=(
            "Replace relu().square() MLP with SwiGLU (gate * silu(x)). Common in "
            "modern LLMs (LLaMA, Gemma). Requires adjusting MLP dimensions to 8/3 "
            "ratio instead of 4x to keep param count similar. Typically gives a "
            "meaningful quality boost."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.5, 1.0),
        keywords=("swiglu", "silu", "glu", "gated"),
    ),
    KnownDirection(
        category=Category.ACTIVATION,
        title="Try GELU activation",
        description=(
            "Replace ReSquared (relu^2) with GELU. GELU is the standard GPT "
            "activation. ReSquared may underperform GELU in some regimes, though "
            "it has theoretical advantages for feature learning."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.3, 0.7),
        keywords=("gelu",),
    ),
    # Embedding
    KnownDirection(
        category=Category.EMBEDDING,
        title="Adjust value embedding frequency",
        description=(
            "Value embeddings (ResFormer) currently apply at alternating layers. "
            "Try applying at every layer (more capacity but more parameters) or "
            "every 3rd layer (fewer parameters, may allow larger model)."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("value_embed", "resformer", "embedding_freq", "embed_layer"),
    ),
    KnownDirection(
        category=Category.EMBEDDING,
        title="Tune RoPE base frequency",
        description=(
            "Adjust the RoPE base frequency (theta). Higher theta extends effective "
            "context but may hurt short-range performance. Lower theta sharpens "
            "attention within the local window."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        # NOTE(review): with substring comparison "rope" also matches words
        # like "properly".
        keywords=("rope", "theta", "rotary", "base_freq"),
    ),
    # Efficiency
    KnownDirection(
        category=Category.EFFICIENCY,
        title="Optimize memory for larger model",
        description=(
            "Use gradient checkpointing, activation recomputation, or mixed precision "
            "to free VRAM. The saved memory can be reinvested into a larger model "
            "(more layers or wider). Even a small model size increase can improve "
            "val_bpb substantially."
        ),
        risk=RiskLevel.MEDIUM,
        expected_range=(0.3, 0.8),
        keywords=("checkpoint", "recompute", "memory", "vram", "mixed_precision"),
    ),
    KnownDirection(
        category=Category.EFFICIENCY,
        title="Reduce compilation overhead",
        description=(
            "If torch.compile takes significant time, try compiling fewer components "
            "or using mode='reduce-overhead'. The saved startup time becomes training "
            "time within the 5-minute budget."
        ),
        risk=RiskLevel.LOW,
        expected_range=(0.1, 0.3),
        keywords=("compile", "compilation", "torch.compile", "startup"),
    ),
)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
# ---------------------------------------------------------------------------
|
|
401
|
+
# Category classification
|
|
402
|
+
# ---------------------------------------------------------------------------
|
|
403
|
+
|
|
404
|
+
# Keywords for classifying experiment descriptions and diffs into categories.
# Checked in priority order — first match wins.
#
# ``classify_experiment`` lower-cases its input. It then matches keywords of
# three characters or fewer only at token boundaries, and longer keywords as
# plain substrings. Substring matching lets stems such as "regulariz" and
# "tokeniz" match their inflections. It also means e.g. "head" matches
# "overhead" and "rope" matches "properly".
#
# NOTE(review): the ordering has side effects worth knowing:
#   * OPTIMIZER, REGULARIZATION and EMBEDDING are checked before HYPERPARAMS.
#     So "weight decay" experiments can land in REGULARIZATION and "embedding
#     lr" experiments in EMBEDDING.
#   * KNOWN_DIRECTIONS files those same directions under OPTIMIZER and
#     HYPERPARAMS respectively.
#   * Diff text includes git context lines, so generic words ("layer",
#     "block", "attention") can decide the category of an otherwise
#     unrelated change.
_CATEGORY_KEYWORDS: tuple[tuple[Category, tuple[str, ...]], ...] = (
    (Category.ACTIVATION, (
        "activation", "relu", "gelu", "silu", "swiglu", "squared",
        "resquared", "relu^2", "glu",
    )),
    (Category.OPTIMIZER, (
        "optimizer", "muon", "adam", "sgd", "momentum", "ns_step",
        "newton", "schulz", "polar", "orthogonal",
    )),
    (Category.REGULARIZATION, (
        "dropout", "z-loss", "z_loss", "softcap", "regulariz",
        "weight_decay", "weight decay",
    )),
    (Category.EMBEDDING, (
        "embedding", "embed", "token embed", "resformer", "rope",
        "rotary", "vocab", "tokeniz", "unembedding",
    )),
    (Category.EFFICIENCY, (
        "memory", "vram", "compile", "checkpoint", "recompute",
        "mixed precision", "fp16", "bf16", "flash", "fused",
    )),
    (Category.HYPERPARAMS, (
        "lr", "learning rate", "batch_size", "batch size", "warmup",
        "warmdown", "cooldown", "schedule", "final_lr",
    )),
    (Category.ARCHITECTURE, (
        "layer", "depth", "width", "head", "attention", "mlp",
        "ffn", "aspect_ratio", "model_dim", "gqa", "window",
        "transformer", "block", "residual",
    )),
)
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def classify_experiment(description: str, diff_text: str) -> Category:
    """Classify an experiment into a category based on its description and diff.

    The description and diff are joined and lower-cased, then scanned against
    ``_CATEGORY_KEYWORDS`` in priority order. The first category with a
    matching keyword wins, and ``Category.OTHER`` is returned when nothing matches.

    Matching rules:
      * Keywords of three characters or fewer ("lr", "mlp", "gqa", ...) must
        stand alone as a token: not directly preceded or followed by a letter
        or digit. Underscores and punctuation count as separators. So "lr"
        matches "matrix_lr", "self.lr", "lr:" and "(lr)" but not "clearly".
      * Longer keywords are plain substring matches, so stems like
        "regulariz" still match "regularization". This also means "head"
        matches "overhead".

    Args:
        description: Free-text experiment description from results.tsv.
        diff_text: Git diff text for the experiment, or "" if unavailable.

    Returns:
        The first matching Category, or Category.OTHER.
    """
    combined = f"{description} {diff_text}".lower()
    for category, keywords in _CATEGORY_KEYWORDS:
        for keyword in keywords:
            if len(keyword) <= 3:
                # Token-boundary match for short keywords. The previous
                # hand-rolled boundary only accepted whitespace/underscore
                # before and one of "_=,." after. That silently missed common
                # forms like "lr:", "(higher lr)" and "optim.sgd(".
                pattern = r"(?<![a-z0-9])" + re.escape(keyword) + r"(?![a-z0-9])"
                if re.search(pattern, combined):
                    return category
            elif keyword in combined:
                return category
    return Category.OTHER
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
# ---------------------------------------------------------------------------
|
|
456
|
+
# Parsing
|
|
457
|
+
# ---------------------------------------------------------------------------
|
|
458
|
+
|
|
459
|
+
@dataclass(frozen=True)
class ParseResult:
    """Outcome of ``parse_results_tsv``: either parsed experiments or an error message."""

    ok: bool  # True when the file was parsed; when False, ``error`` says why
    experiments: list[Experiment]  # rows in file order; empty when ok is False
    error: str = ""  # human-readable failure reason; "" on success
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def _cell(row: dict, key: str) -> str:
    """Return ``row[key]`` as a string, treating absent or ``None`` cells as "".

    ``csv.DictReader`` fills columns missing from a short row with ``None``
    (its default ``restval``). Normalising here keeps such rows from crashing
    on ``.strip()`` / ``float()``.
    """
    value = row.get(key)
    return "" if value is None else value


def parse_results_tsv(path: Path) -> ParseResult:
    """Parse a results.tsv file into a list of Experiment records.

    The file must be tab-separated with a header containing at least
    ``commit``, ``val_bpb``, ``status`` and ``description``; ``memory_gb`` is
    optional.

    Rows are parsed leniently:
      * unknown or blank status values become ``Status.DISCARD``;
      * unparseable or missing ``val_bpb`` / ``memory_gb`` become ``0.0``;
      * cells missing from short rows are treated as empty strings (this
        happens e.g. when the last row's trailing empty description is
        removed by the whole-file ``strip()``).

    Args:
        path: Location of the results.tsv file.

    Returns:
        ``ParseResult(ok=True, experiments=[...])`` with rows in file order.
        On failure it returns ``ok=False`` with ``error`` set: the file is
        missing, unreadable or undecodable, empty, header-only, or lacks
        required columns.
    """
    if not path.exists():
        return ParseResult(ok=False, experiments=[], error=f"File not found: {path}")

    try:
        # Decode explicitly as UTF-8 rather than the locale default, since
        # descriptions may contain non-ASCII text. "-sig" also drops a leading
        # BOM that would otherwise be glued onto the first column name.
        text = path.read_text(encoding="utf-8-sig").strip()
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. permission denied, or *path* is a directory.
        return ParseResult(ok=False, experiments=[], error=f"Could not read {path}: {exc}")
    if not text:
        return ParseResult(ok=False, experiments=[], error=f"Empty file: {path}")

    lines = text.splitlines()
    if len(lines) < 2:
        return ParseResult(
            ok=False, experiments=[],
            error="Only header row found — no experiments yet",
        )

    reader = csv.DictReader(lines, delimiter="\t")
    required_fields = {"commit", "val_bpb", "status", "description"}
    missing = required_fields - set(reader.fieldnames or [])
    if missing:
        return ParseResult(
            ok=False, experiments=[],
            error=f"Missing columns in results.tsv: {missing}",
        )

    experiments: list[Experiment] = []
    for row in reader:
        try:
            status = Status(_cell(row, "status").strip().lower())
        except ValueError:
            status = Status.DISCARD

        try:
            val_bpb = float(_cell(row, "val_bpb"))
        except ValueError:
            val_bpb = 0.0

        try:
            memory_gb = float(_cell(row, "memory_gb"))
        except ValueError:
            # Covers both an absent column ("") and garbage values.
            memory_gb = 0.0

        experiments.append(Experiment(
            commit=_cell(row, "commit").strip(),
            val_bpb=val_bpb,
            memory_gb=memory_gb,
            status=status,
            description=_cell(row, "description").strip(),
        ))

    return ParseResult(ok=True, experiments=experiments)
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
# ---------------------------------------------------------------------------
|
|
520
|
+
# Git integration
|
|
521
|
+
# ---------------------------------------------------------------------------
|
|
522
|
+
|
|
523
|
+
def get_git_diff(commit: str, repo_dir: Path) -> str:
    """Get the train.py diff introduced by *commit*, or "" on any failure.

    Runs ``git diff <commit>~1 <commit> -- train.py`` inside *repo_dir*.
    Every failure returns an empty string, so callers fall back to
    description-only classification. Failures include:
      * git missing;
      * *repo_dir* not a repository or not existing;
      * unknown commit, or a root commit with no ``~1`` parent;
      * the 10-second timeout.

    Args:
        commit: Commit identifier, typically read from results.tsv.
        repo_dir: Directory in which git is run.

    Returns:
        The diff text, or "" if it could not be obtained.
    """
    # The identifier comes from a user-editable file, so reject empty values.
    # Also reject anything git would parse as an option rather than a
    # revision, e.g. "--output=<file>", which makes `git diff` write to an
    # arbitrary path.
    if not commit or commit.startswith("-"):
        return ""
    try:
        result = subprocess.run(
            ["git", "diff", f"{commit}~1", commit, "--", "train.py"],
            capture_output=True,
            text=True,
            # Decode as UTF-8 regardless of locale. Undecodable bytes become
            # U+FFFD instead of raising UnicodeDecodeError out of run().
            encoding="utf-8",
            errors="replace",
            cwd=str(repo_dir),
            timeout=10,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return ""
    return result.stdout if result.returncode == 0 else ""
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
# Set once the "classification based on descriptions only" note has been
# printed, so repeated calls do not repeat it on stderr.
_git_repo_warning_shown = False


def _git_unavailable_reason(repo_dir: Path) -> Optional[str]:
    """Return None if git works in *repo_dir*, else the user-facing note explaining why not.

    Runs ``git rev-parse --git-dir`` once:
      * a non-zero exit means *repo_dir* is not inside a repository;
      * a launch failure or timeout means git itself is unusable.
    Output is not decoded because it is never read.
    """
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--git-dir"],
            capture_output=True,
            cwd=str(repo_dir),
            timeout=5,
        )
    except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
        return "Note: git not available, category classification based on descriptions only"
    if result.returncode != 0:
        return "Note: not in a git repo, category classification based on descriptions only"
    return None


def enrich_experiments_with_git(
    experiments: list[Experiment],
    repo_dir: Path,
) -> list[Experiment]:
    """Add git diff text and category classification to each experiment.

    Returns new Experiment records (inputs are frozen and left untouched).
    Each record gets ``diff_text`` from ``get_git_diff`` and ``category``
    from ``classify_experiment``. If git is unusable in *repo_dir*:
      * a note is printed to stderr (at most once per process);
      * no per-experiment ``git diff`` is attempted;
      * classification falls back to descriptions alone.

    Args:
        experiments: Parsed experiments, in any order.
        repo_dir: Directory containing the experiment repository.

    Returns:
        Enriched experiments in the same order as the input.
    """
    global _git_repo_warning_shown
    from dataclasses import replace  # module imports only ``dataclass`` at top level

    reason = _git_unavailable_reason(repo_dir)
    if reason is not None and not _git_repo_warning_shown:
        click.echo(reason, err=True)
        _git_repo_warning_shown = True
    # Without a usable repo every `git diff` would just fail. Skip spawning
    # one doomed subprocess per experiment.
    git_usable = reason is None

    enriched: list[Experiment] = []
    for exp in experiments:
        diff_text = get_git_diff(exp.commit, repo_dir) if git_usable and exp.commit else ""
        enriched.append(replace(
            exp,
            category=classify_experiment(exp.description, diff_text),
            diff_text=diff_text,
        ))
    return enriched
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
# ---------------------------------------------------------------------------
|
|
585
|
+
# Analysis
|
|
586
|
+
# ---------------------------------------------------------------------------
|
|
587
|
+
|
|
588
|
+
def compute_category_stats(
    experiments: list[Experiment],
    baseline_bpb: Optional[float],
) -> dict[Category, CategoryStats]:
    """Summarise experiment outcomes per category.

    The first experiment is treated as the baseline run and left out of the
    counts. Every member of ``Category`` receives an entry, even if unused.

    ``avg_improvement_pct`` is the mean relative val_bpb reduction versus
    *baseline_bpb*, in percent. It is taken over kept experiments whose
    val_bpb is positive. It stays 0.0 when the baseline is missing or
    non-positive, or when no such experiment exists.
    """
    grouped: dict[Category, list[Experiment]] = {}
    for candidate in experiments[1:]:
        grouped.setdefault(candidate.category, []).append(candidate)

    usable_baseline = bool(baseline_bpb) and baseline_bpb > 0

    summary: dict[Category, CategoryStats] = {}
    for category in Category:
        members = grouped.get(category, [])
        kept = [m for m in members if m.status == Status.KEEP]
        n_discarded = sum(1 for m in members if m.status == Status.DISCARD)
        n_crashed = sum(1 for m in members if m.status == Status.CRASH)

        gains: list[float] = []
        if kept and usable_baseline:
            # Ignore non-positive val_bpb values (no valid measurement).
            gains = [
                ((baseline_bpb - m.val_bpb) / baseline_bpb) * 100.0
                for m in kept
                if m.val_bpb > 0
            ]
        mean_gain = sum(gains) / len(gains) if gains else 0.0

        summary[category] = CategoryStats(
            category=category,
            total=len(members),
            keeps=len(kept),
            discards=n_discarded,
            crashes=n_crashed,
            avg_improvement_pct=mean_gain,
        )

    return summary
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def _direction_already_tried(
    direction: KnownDirection,
    experiments: list[Experiment],
) -> bool:
    """Return True if some experiment appears to have already tried *direction*.

    Each keyword is searched case-insensitively, as a substring, in the
    experiment's description plus diff text. Underscores count as spaces on
    both sides. An experiment matches when it contains at least two of the
    keywords (or the only keyword, for single-keyword directions). A direction
    without keywords can never be matched and is reported as untried.
    """
    # Loop-invariant: normalise keywords once, not once per experiment.
    # Lowercase them too, because the searched text is lowercased.
    keywords = [kw.lower().replace("_", " ") for kw in direction.keywords]
    if not keywords:
        # With no keywords the threshold would be 0, and any single experiment
        # would wrongly mark the direction as tried.
        return False
    threshold = min(2, len(keywords))

    for exp in experiments:
        combined = (exp.description + " " + exp.diff_text).lower().replace("_", " ")
        matches = sum(1 for kw in keywords if kw in combined)
        if matches >= threshold:
            return True
    return False
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def _compute_priority_score(
    direction: KnownDirection,
    stats: dict[Category, CategoryStats],
    strategy: Strategy,
    total_experiments: int,
) -> float:
    """Score a candidate direction; a higher score means "try sooner".

    The score starts at the midpoint of the direction's expected range. It is
    scaled by a risk factor (low 1.2, medium 1.0, high 0.7), then by a
    strategy factor derived from the category's history:

    * EXPLOIT: categories with kept runs get ``1.5 + success_rate/100``;
      tried-without-keeps get 0.3; untried get 0.5.
    * EXPLORE: untried categories get 2.0; tried-without-keeps get 0.8;
      categories with kept runs get 0.7.
    * otherwise (auto): the exploration weight is
      ``max(0.2, 1 - total_experiments/20)`` and the exploitation weight is
      its complement. Untried categories get ``1 + exploration``; those with
      kept runs get ``1 + exploitation * success_rate/100``; the rest get 0.5.
    """
    if direction.category in stats:
        history = stats[direction.category]
    else:
        history = CategoryStats(
            category=direction.category, total=0, keeps=0,
            discards=0, crashes=0, avg_improvement_pct=0.0,
        )

    midpoint = (direction.expected_range[0] + direction.expected_range[1]) / 2.0
    risk_factor = {RiskLevel.LOW: 1.2, RiskLevel.MEDIUM: 1.0, RiskLevel.HIGH: 0.7}[direction.risk]
    base = midpoint * risk_factor

    tried = history.total > 0
    succeeded = history.keeps > 0

    if strategy == Strategy.EXPLOIT:
        if succeeded:
            factor = 1.5 + (history.success_rate_pct / 100.0)
        elif tried:
            factor = 0.3
        else:
            factor = 0.5
    elif strategy == Strategy.EXPLORE:
        if not tried:
            factor = 2.0
        elif not succeeded:
            factor = 0.8
        else:
            factor = 0.7
    else:
        explore_weight = max(0.2, 1.0 - (total_experiments / 20.0))
        exploit_weight = 1.0 - explore_weight
        if not tried:
            factor = 1.0 + explore_weight
        elif succeeded:
            factor = 1.0 + (exploit_weight * (history.success_rate_pct / 100.0))
        else:
            factor = 0.5

    return base * factor
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def generate_suggestions(
    experiments: list[Experiment],
    stats: dict[Category, CategoryStats],
    strategy: Strategy,
    num_suggestions: int,
) -> list[Suggestion]:
    """Rank the not-yet-tried known directions and return the top ones.

    Directions already matched by an experiment are dropped. The rest are
    scored with ``_compute_priority_score`` and sorted best first; ties keep
    their ``KNOWN_DIRECTIONS`` order because the sort is stable. A suggestion
    is labelled EXPLOIT if its category already has a kept experiment, and
    EXPLORE otherwise. Ranks start at 1.
    """
    history_size = sum(cs.total for cs in stats.values())

    candidates: list[tuple[KnownDirection, float, SuggestionKind]] = []
    for direction in KNOWN_DIRECTIONS:
        if _direction_already_tried(direction, experiments):
            continue
        score = _compute_priority_score(direction, stats, strategy, history_size)
        category_stats = stats.get(direction.category)
        kind = (
            SuggestionKind.EXPLOIT
            if category_stats and category_stats.keeps > 0
            else SuggestionKind.EXPLORE
        )
        candidates.append((direction, score, kind))

    ranked = sorted(candidates, key=lambda item: item[1], reverse=True)

    return [
        Suggestion(
            rank=position,
            kind=kind,
            title=direction.title,
            category=direction.category,
            risk=direction.risk,
            expected_range=direction.expected_range,
            reasoning=direction.description,
            priority_score=score,
        )
        for position, (direction, score, kind) in enumerate(ranked[:num_suggestions], start=1)
    ]
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _resolve_strategy_label(strategy: Strategy, total_experiments: int) -> str:
    """Describe *strategy* for display.

    Explicit strategies get a fixed label. Auto mode also reports the number
    of experiments so far: up to 3 is "favoring exploration", up to 10 is
    "mixed strategy", and beyond that "favoring exploitation".
    """
    if strategy == Strategy.EXPLORE:
        return "explore (favoring new directions)"
    if strategy == Strategy.EXPLOIT:
        return "exploit (favoring what works)"

    if total_experiments > 10:
        balance = "favoring exploitation"
    elif total_experiments > 3:
        balance = "mixed strategy"
    else:
        balance = "favoring exploration"
    return f"auto (balanced — {total_experiments} experiments so far, {balance})"
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def analyze(
    results_path: Path,
    repo_dir: Path,
    strategy: Strategy,
    num_suggestions: int,
) -> AnalysisResult:
    """Parse the results file, enrich it from git, and rank suggestions.

    If ``parse_results_tsv`` reports failure, the history is treated as
    empty: every category gets zeroed stats and suggestions are still
    produced. The parser's error text is appended to ``strategy_label`` as
    ``" (note: ...)"``.

    Otherwise the first experiment's positive val_bpb is the baseline. The
    best val_bpb is the lowest positive val_bpb among kept experiments,
    falling back to the baseline.
    """
    parsed = parse_results_tsv(results_path)

    if not parsed.ok:
        zeroed = {cat: CategoryStats(cat, 0, 0, 0, 0, 0.0) for cat in Category}
        fallback_label = _resolve_strategy_label(strategy, 0)
        fallback_suggestions = generate_suggestions([], zeroed, strategy, num_suggestions)
        return AnalysisResult(
            experiments=[],
            stats_by_category=zeroed,
            suggestions=fallback_suggestions,
            strategy_label=f"{fallback_label} (note: {parsed.error})",
            baseline_bpb=None,
            best_bpb=None,
        )

    history = enrich_experiments_with_git(parsed.experiments, repo_dir)

    baseline_bpb: Optional[float] = None
    best_bpb: Optional[float] = None
    if history:
        first_bpb = history[0].val_bpb
        baseline_bpb = first_bpb if first_bpb > 0 else None
        kept_bpbs = [e.val_bpb for e in history if e.status == Status.KEEP and e.val_bpb > 0]
        best_bpb = min(kept_bpbs, default=baseline_bpb)

    stats = compute_category_stats(history, baseline_bpb)
    label = _resolve_strategy_label(strategy, sum(cs.total for cs in stats.values()))
    ranked = generate_suggestions(history, stats, strategy, num_suggestions)

    return AnalysisResult(
        experiments=history,
        stats_by_category=stats,
        suggestions=ranked,
        strategy_label=label,
        baseline_bpb=baseline_bpb,
        best_bpb=best_bpb,
    )
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
# ---------------------------------------------------------------------------
|
|
812
|
+
# Output formatting
|
|
813
|
+
# ---------------------------------------------------------------------------
|
|
814
|
+
|
|
815
|
+
def format_quiet(result: AnalysisResult) -> str:
    """Render one ``"<rank>. [<kind>] <title>"`` line per suggestion.

    Returns ``"No suggestions available."`` when there are no suggestions.
    """
    rendered = [f"{s.rank}. [{s.kind.value}] {s.title}" for s in result.suggestions]
    if not rendered:
        return "No suggestions available."
    return "\n".join(rendered)
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def format_text(result: AnalysisResult, cfg: OutputConfig) -> str:
    """Render the analysis as human-readable text, styled through ``cfg.styled``.

    Sections, in order:
    * a header with the strategy label and the baseline/best val_bpb, when known;
    * per-category history: tried categories first, then untried ones;
    * the ranked suggestions, each with category, risk, expected gain and
      wrapped reasoning.
    """

    def category_line(category: Category, detail: str) -> str:
        # Escape codes would break fixed-width padding, so only the plain
        # variant pads the "<category>:" column.
        if cfg.color:
            return f" {cfg.styled(category.value, fg='blue')}: {detail}"
        return f" {category.value + ':':18s} {detail}"

    out: list[str] = [
        f"== {cfg.styled('autosteer', fg='cyan', bold=True)} suggestions ==",
        f"Strategy: {result.strategy_label}",
    ]
    if result.baseline_bpb is not None:
        out.append(f"Baseline val_bpb: {result.baseline_bpb:.6f}")
    if result.best_bpb is not None and result.best_bpb != result.baseline_bpb:
        out.append(f"Best val_bpb: {result.best_bpb:.6f}")

    out.append("")
    out.append(cfg.styled("Based on history:", dim=True))

    tried_lines: list[str] = []
    untried: list[Category] = []
    for category in Category:
        cat_stats = result.stats_by_category.get(category)
        if cat_stats is None:
            continue
        if cat_stats.total == 0:
            untried.append(category)
            continue
        success_str = f"{cat_stats.success_rate_pct:.0f}% success"
        if cat_stats.avg_improvement_pct > 0:
            success_str += f", avg +{cat_stats.avg_improvement_pct:.2f}%"
        tried_lines.append(category_line(
            category,
            f"{cat_stats.total} tried, {cat_stats.keeps} kept ({success_str})",
        ))

    out.extend(tried_lines)
    out.extend(category_line(category, "0 tried") for category in untried)
    if not tried_lines and not untried:
        out.append(" (no experiments recorded yet)")

    out.append("")
    if not result.suggestions:
        out.append("No suggestions available — all known directions have been tried.")
        out.append("Consider exploring novel ideas outside the standard playbook.")
        return "\n".join(out)

    out.append(cfg.styled("Suggestions (ranked by expected value):", dim=True))
    out.append("")

    risk_palette = {"low": "green", "medium": "yellow", "high": "red"}
    for s in result.suggestions:
        badge_color = "cyan" if s.kind == SuggestionKind.EXPLORE else "green"
        rank_txt = cfg.styled(f"{s.rank}.", bold=True)
        badge_txt = cfg.styled(f"[{s.kind.value}]", fg=badge_color)
        out.append(f"{rank_txt} {badge_txt} {s.title}")

        category_txt = cfg.styled(s.category.value, fg="blue")
        risk_txt = cfg.styled(s.risk.value, fg=risk_palette.get(s.risk.value, "white"))
        low, high = s.expected_range[0], s.expected_range[1]
        out.append(
            f" Category: {category_txt} | "
            f"Risk: {risk_txt} | "
            f"Expected: +{low:.1f}-{high:.1f}%"
        )
        out.append(_wrap_text(s.reasoning, width=72, indent=" "))
        out.append("")

    return "\n".join(out)
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
def _wrap_text(text: str, width: int, indent: str) -> str:
    """Fill *text* to *width* columns, prefixing every output line with *indent*."""
    wrapper = textwrap.TextWrapper(
        width=width,
        initial_indent=indent,
        subsequent_indent=indent,
    )
    return wrapper.fill(text)
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def format_json(result: AnalysisResult) -> str:
    """Serialise *result* as a JSON document with 2-space indentation.

    Enum members are emitted as their ``.value``. Rounding: success rate to
    1 decimal place, average improvement to 2, priority score to 3.
    """
    per_category = {}
    for cat, cs in result.stats_by_category.items():
        per_category[cat.value] = {
            "total": cs.total,
            "keeps": cs.keeps,
            "discards": cs.discards,
            "crashes": cs.crashes,
            "success_rate_pct": round(cs.success_rate_pct, 1),
            "avg_improvement_pct": round(cs.avg_improvement_pct, 2),
        }

    ranked = []
    for s in result.suggestions:
        ranked.append({
            "rank": s.rank,
            "kind": s.kind.value,
            "title": s.title,
            "category": s.category.value,
            "risk": s.risk.value,
            "expected_range_pct": [s.expected_range[0], s.expected_range[1]],
            "reasoning": s.reasoning,
            "priority_score": round(s.priority_score, 3),
        })

    payload = {
        "strategy": result.strategy_label,
        "baseline_bpb": result.baseline_bpb,
        "best_bpb": result.best_bpb,
        "category_stats": per_category,
        "suggestions": ranked,
    }
    return json.dumps(payload, indent=2)
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
# ---------------------------------------------------------------------------
|
|
949
|
+
# CLI
|
|
950
|
+
# ---------------------------------------------------------------------------
|
|
951
|
+
|
|
952
|
+
@click.command(epilog="Exit codes: 0 = success, 1 = file error")
@click.version_option(version="1.0.0", prog_name="autosteer")
@click.option(
    "--results", "results_path",
    default="results.tsv",
    type=click.Path(),
    help="Path to results.tsv file (default: results.tsv in current dir)",
)
@click.option(
    "--repo-dir", "repo_dir",
    default=".",
    type=click.Path(exists=True, file_okay=False),
    help="Path to the autoresearch repo (default: current dir)",
)
@click.option(
    "--num-suggestions", "num_suggestions",
    default=5,
    type=click.IntRange(min=1, max=20),
    help="Number of suggestions to generate (default: 5)",
)
@click.option(
    "--strategy",
    default="auto",
    type=click.Choice(["auto", "explore", "exploit"], case_sensitive=False),
    help="Strategy: auto (balanced), explore (new directions), exploit (what works)",
)
@click.option(
    "--format", "output_format",
    default="text",
    type=click.Choice(["text", "json"], case_sensitive=False),
    help="Output format (default: text)",
)
@click.option(
    "--no-color", "no_color",
    is_flag=True,
    default=False,
    help="Disable colored output",
)
@click.option(
    "--quiet", "-q",
    is_flag=True,
    default=False,
    help="Minimal output (one line per suggestion)",
)
def cli(
    results_path: str,
    repo_dir: str,
    num_suggestions: int,
    strategy: str,
    output_format: str,
    no_color: bool,
    quiet: bool,
) -> None:
    """Analyze autoresearch experiment history and suggest next steps."""
    # NOTE: click shows the docstring above as the command's --help text.
    # Colour only when the flag allows it AND stdout is an interactive terminal.
    use_color = not no_color and sys.stdout.isatty()
    cfg = OutputConfig(color=use_color, quiet=quiet)

    # A relative --results path is taken relative to --repo-dir, not the CWD.
    repo_path = Path(repo_dir)
    tsv_path = Path(results_path)
    if not tsv_path.is_absolute():
        tsv_path = repo_path / tsv_path

    result = analyze(
        results_path=tsv_path,
        repo_dir=repo_path,
        strategy=Strategy(strategy.lower()),
        num_suggestions=num_suggestions,
    )

    if output_format == "json":
        rendered = format_json(result)
    elif cfg.quiet:
        rendered = format_quiet(result)
    else:
        rendered = format_text(result, cfg)
    click.echo(rendered)

    # analyze() reports a parse failure only through the "(note: ...)" suffix
    # of strategy_label. This relies on the parser's missing-file message
    # containing "not found" (presumably; verify against parse_results_tsv).
    file_missing = not result.experiments and "not found" in result.strategy_label.lower()
    if file_missing:
        sys.exit(1)
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
def main() -> None:
    """Invoke the click command; ``[project.scripts]`` maps ``autosteer`` here."""
    cli()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "autosteer"
|
|
3
|
+
version = "1.0.0"
|
|
4
|
+
description = "Research direction generator for autoresearch — analyzes experiment history and suggests next steps"
|
|
5
|
+
requires-python = ">=3.10"
|
|
6
|
+
license = "MIT"
|
|
7
|
+
readme = "README.md"
|
|
8
|
+
keywords = ["autoresearch", "karpathy", "gpt", "pretraining", "experiment-suggestions"]
|
|
9
|
+
classifiers = [
|
|
10
|
+
"Development Status :: 4 - Beta",
|
|
11
|
+
"Environment :: Console",
|
|
12
|
+
"Intended Audience :: Science/Research",
|
|
13
|
+
"License :: OSI Approved :: MIT License",
|
|
14
|
+
"Programming Language :: Python :: 3",
|
|
15
|
+
"Programming Language :: Python :: 3.10",
|
|
16
|
+
"Programming Language :: Python :: 3.11",
|
|
17
|
+
"Programming Language :: Python :: 3.12",
|
|
18
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
19
|
+
]
|
|
20
|
+
dependencies = ["click>=8.0"]
|
|
21
|
+
|
|
22
|
+
[project.scripts]
|
|
23
|
+
autosteer = "auto_steer:main"
|
|
24
|
+
|
|
25
|
+
[project.urls]
|
|
26
|
+
Homepage = "https://github.com/dean0x/autolab"
|
|
27
|
+
Repository = "https://github.com/dean0x/autolab"
|
|
28
|
+
Issues = "https://github.com/dean0x/autolab/issues"
|
|
29
|
+
|
|
30
|
+
[build-system]
|
|
31
|
+
requires = ["hatchling"]
|
|
32
|
+
build-backend = "hatchling.build"
|
|
33
|
+
|
|
34
|
+
[tool.hatch.build.targets.wheel]
|
|
35
|
+
packages = ["auto_steer.py"]
|