@zigrivers/scaffold 3.14.0 → 3.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -9
- package/content/knowledge/research/research-architecture.md +385 -0
- package/content/knowledge/research/research-conventions.md +248 -0
- package/content/knowledge/research/research-dev-environment.md +303 -0
- package/content/knowledge/research/research-experiment-loop.md +429 -0
- package/content/knowledge/research/research-experiment-tracking.md +336 -0
- package/content/knowledge/research/research-ml-architecture-search.md +383 -0
- package/content/knowledge/research/research-ml-evaluation.md +407 -0
- package/content/knowledge/research/research-ml-experiment-tracking.md +466 -0
- package/content/knowledge/research/research-ml-training-patterns.md +413 -0
- package/content/knowledge/research/research-observability.md +395 -0
- package/content/knowledge/research/research-overfitting-prevention.md +306 -0
- package/content/knowledge/research/research-project-structure.md +264 -0
- package/content/knowledge/research/research-quant-backtesting.md +326 -0
- package/content/knowledge/research/research-quant-market-data.md +366 -0
- package/content/knowledge/research/research-quant-metrics.md +335 -0
- package/content/knowledge/research/research-quant-requirements.md +223 -0
- package/content/knowledge/research/research-quant-risk.md +469 -0
- package/content/knowledge/research/research-quant-strategy-patterns.md +412 -0
- package/content/knowledge/research/research-requirements.md +201 -0
- package/content/knowledge/research/research-security.md +374 -0
- package/content/knowledge/research/research-sim-compute-management.md +538 -0
- package/content/knowledge/research/research-sim-engine-patterns.md +448 -0
- package/content/knowledge/research/research-sim-parameter-spaces.md +425 -0
- package/content/knowledge/research/research-sim-validation.md +456 -0
- package/content/knowledge/research/research-testing.md +334 -0
- package/content/methodology/research-ml-research.yml +23 -0
- package/content/methodology/research-overlay.yml +65 -0
- package/content/methodology/research-quant-finance.yml +29 -0
- package/content/methodology/research-simulation.yml +23 -0
- package/dist/cli/commands/adopt.d.ts.map +1 -1
- package/dist/cli/commands/adopt.js +22 -1
- package/dist/cli/commands/adopt.js.map +1 -1
- package/dist/cli/commands/adopt.serialization.test.js +41 -0
- package/dist/cli/commands/adopt.serialization.test.js.map +1 -1
- package/dist/cli/commands/init.d.ts +4 -0
- package/dist/cli/commands/init.d.ts.map +1 -1
- package/dist/cli/commands/init.js +32 -2
- package/dist/cli/commands/init.js.map +1 -1
- package/dist/cli/init-flag-families.d.ts +6 -1
- package/dist/cli/init-flag-families.d.ts.map +1 -1
- package/dist/cli/init-flag-families.js +32 -1
- package/dist/cli/init-flag-families.js.map +1 -1
- package/dist/cli/init-flag-families.test.js +47 -0
- package/dist/cli/init-flag-families.test.js.map +1 -1
- package/dist/config/schema.d.ts +272 -16
- package/dist/config/schema.d.ts.map +1 -1
- package/dist/config/schema.js +25 -1
- package/dist/config/schema.js.map +1 -1
- package/dist/config/schema.test.js +103 -3
- package/dist/config/schema.test.js.map +1 -1
- package/dist/core/assembly/overlay-loader.d.ts +12 -0
- package/dist/core/assembly/overlay-loader.d.ts.map +1 -1
- package/dist/core/assembly/overlay-loader.js +30 -0
- package/dist/core/assembly/overlay-loader.js.map +1 -1
- package/dist/core/assembly/overlay-loader.test.js +66 -1
- package/dist/core/assembly/overlay-loader.test.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.d.ts.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.js +48 -19
- package/dist/core/assembly/overlay-state-resolver.js.map +1 -1
- package/dist/core/assembly/overlay-state-resolver.test.js +80 -0
- package/dist/core/assembly/overlay-state-resolver.test.js.map +1 -1
- package/dist/e2e/project-type-overlays.test.js +119 -0
- package/dist/e2e/project-type-overlays.test.js.map +1 -1
- package/dist/project/adopt.d.ts.map +1 -1
- package/dist/project/adopt.js +3 -1
- package/dist/project/adopt.js.map +1 -1
- package/dist/project/detectors/disambiguate.js +1 -1
- package/dist/project/detectors/disambiguate.js.map +1 -1
- package/dist/project/detectors/index.d.ts.map +1 -1
- package/dist/project/detectors/index.js +2 -1
- package/dist/project/detectors/index.js.map +1 -1
- package/dist/project/detectors/ml.d.ts.map +1 -1
- package/dist/project/detectors/ml.js +2 -6
- package/dist/project/detectors/ml.js.map +1 -1
- package/dist/project/detectors/research.d.ts +4 -0
- package/dist/project/detectors/research.d.ts.map +1 -0
- package/dist/project/detectors/research.js +141 -0
- package/dist/project/detectors/research.js.map +1 -0
- package/dist/project/detectors/research.test.d.ts +2 -0
- package/dist/project/detectors/research.test.d.ts.map +1 -0
- package/dist/project/detectors/research.test.js +235 -0
- package/dist/project/detectors/research.test.js.map +1 -0
- package/dist/project/detectors/shared-signals.d.ts +3 -0
- package/dist/project/detectors/shared-signals.d.ts.map +1 -0
- package/dist/project/detectors/shared-signals.js +9 -0
- package/dist/project/detectors/shared-signals.js.map +1 -0
- package/dist/project/detectors/types.d.ts +6 -2
- package/dist/project/detectors/types.d.ts.map +1 -1
- package/dist/project/detectors/types.js.map +1 -1
- package/dist/types/config.d.ts +7 -1
- package/dist/types/config.d.ts.map +1 -1
- package/dist/wizard/copy/core.d.ts.map +1 -1
- package/dist/wizard/copy/core.js +4 -0
- package/dist/wizard/copy/core.js.map +1 -1
- package/dist/wizard/copy/index.d.ts.map +1 -1
- package/dist/wizard/copy/index.js +2 -0
- package/dist/wizard/copy/index.js.map +1 -1
- package/dist/wizard/copy/research.d.ts +3 -0
- package/dist/wizard/copy/research.d.ts.map +1 -0
- package/dist/wizard/copy/research.js +27 -0
- package/dist/wizard/copy/research.js.map +1 -0
- package/dist/wizard/copy/types.d.ts +5 -1
- package/dist/wizard/copy/types.d.ts.map +1 -1
- package/dist/wizard/flags.d.ts +7 -1
- package/dist/wizard/flags.d.ts.map +1 -1
- package/dist/wizard/questions.d.ts +4 -2
- package/dist/wizard/questions.d.ts.map +1 -1
- package/dist/wizard/questions.js +27 -1
- package/dist/wizard/questions.js.map +1 -1
- package/dist/wizard/questions.test.js +51 -0
- package/dist/wizard/questions.test.js.map +1 -1
- package/dist/wizard/wizard.d.ts +3 -2
- package/dist/wizard/wizard.d.ts.map +1 -1
- package/dist/wizard/wizard.js +3 -1
- package/dist/wizard/wizard.js.map +1 -1
- package/package.json +1 -1
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: research-ml-experiment-tracking
|
|
3
|
+
description: Lightweight experiment tracking for ML research including W&B and MLflow integration, experiment tagging, parallel coordinate plots, run comparison dashboards, and checkpoint strategies
|
|
4
|
+
topics: [research, ml-research, experiment-tracking, wandb, mlflow, parallel-coordinates, checkpointing, comparison, tagging]
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
ML research experiment tracking prioritizes rapid iteration and comparison over production-grade audit trails. The goal is to answer "which of my 200 runs was best and why?" within seconds, not to maintain a production model registry. This means lightweight setup (local or cloud-hosted, not self-managed infrastructure), aggressive tagging and grouping for filtering, visualization tools that reveal patterns across many runs simultaneously, and checkpoint strategies that keep what matters and discard the rest to avoid filling storage with abandoned experiments.
|
|
8
|
+
|
|
9
|
+
## Summary
|
|
10
|
+
|
|
11
|
+
Use W&B or MLflow for research experiment tracking with minimal ceremony: log hyperparameters, metrics at each step, and final results with a single decorator or context manager. Tag runs by experiment group, hypothesis, and search phase for fast filtering. Use parallel coordinate plots to visualize relationships between hyperparameters and outcomes across hundreds of runs. Build run comparison dashboards that highlight what changed between the best and worst runs. Implement a checkpoint strategy that keeps the top-N checkpoints per experiment group and aggressively discards the rest -- research storage fills fast.
|
|
12
|
+
|
|
13
|
+
## Deep Guidance
|
|
14
|
+
|
|
15
|
+
### W&B for Rapid Research Comparison
|
|
16
|
+
|
|
17
|
+
Weights & Biases excels at research tracking because its UI is designed for comparing hundreds of runs:
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
# src/tracking/wandb_research.py
|
|
21
|
+
import wandb
|
|
22
|
+
from typing import Any
|
|
23
|
+
from functools import wraps
|
|
24
|
+
|
|
25
|
+
def init_research_run(
    project: str,
    experiment_group: str,
    config: dict[str, Any],
    tags: list[str] | None = None,
) -> wandb.Run:
    """Start a W&B run carrying research-oriented metadata.

    Args:
        project: W&B project the run is logged to.
        experiment_group: Passed as the W&B ``group`` and echoed in the run notes.
        config: Hyperparameters recorded as the run config.
        tags: Optional tags; ``None`` (or an empty list) means no tags.

    Returns:
        The object returned by ``wandb.init``.
    """
    run_tags = tags or []
    run = wandb.init(
        project=project,
        group=experiment_group,
        config=config,
        tags=run_tags,
        save_code=True,
        notes=f"Experiment group: {experiment_group}",
    )

    # Provenance is added after init so it lands in the same run config;
    # allow_val_change lets it overwrite identically named keys from `config`.
    provenance = {
        "git_sha": _get_git_sha(),
        "hostname": _get_hostname(),
    }
    wandb.config.update(provenance, allow_val_change=True)
    return run
|
|
47
|
+
|
|
48
|
+
def log_step_metrics(step: int, metrics: dict[str, float]) -> None:
    """Record ``metrics`` for training step ``step`` via ``wandb.log``."""
    wandb.log(data=metrics, step=step)
|
|
51
|
+
|
|
52
|
+
def log_summary_metrics(metrics: dict[str, float]) -> None:
    """Copy final metrics into the active run's summary.

    Each key/value pair is written to ``wandb.run.summary`` one at a time.
    """
    for name in metrics:
        wandb.run.summary[name] = metrics[name]
|
|
56
|
+
|
|
57
|
+
def research_run(project: str, group: str, tags: list[str] | None = None):
    """Decorator that wraps a training function in a tracked W&B run.

    The decorated function must take the config dict as its first positional
    argument. A run is started with ``init_research_run`` before the call; if
    the function returns a dict it is logged as summary metrics.

    Args:
        project: W&B project name.
        group: Experiment group name forwarded to ``init_research_run``.
        tags: Optional tags forwarded to ``init_research_run``.

    Returns:
        A decorator; the wrapped function returns whatever the original returns.

    Raises:
        Whatever the wrapped function (or summary logging) raises. The run is
        finished with ``exit_code=1`` first, then the exception is re-raised.
    """
    def decorator(fn):
        @wraps(fn)
        def wrapper(config: dict[str, Any], *args, **kwargs):
            init_research_run(project, group, config, tags=tags)
            try:
                result = fn(config, *args, **kwargs)
                if isinstance(result, dict):
                    log_summary_metrics(result)
            except BaseException:
                # BaseException so that Ctrl-C / SystemExit also close the run
                # as failed instead of leaving it open.
                wandb.finish(exit_code=1)
                raise
            # Outside the try: a failure while finishing must not trigger a
            # second finish() call from the except branch.
            wandb.finish(exit_code=0)
            return result
        return wrapper
    return decorator
|
|
74
|
+
|
|
75
|
+
def _get_git_sha() -> str:
    """Return the output of ``git rev-parse HEAD``, or ``"unknown"`` on failure.

    Best-effort: a missing git executable, a non-zero exit (for example when
    not inside a repository), or undecodable output all yield ``"unknown"``.
    """
    import subprocess

    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            text=True,
            # Keep git's "fatal: not a git repository" out of training logs.
            stderr=subprocess.DEVNULL,
        )
    except (subprocess.SubprocessError, OSError, UnicodeDecodeError):
        return "unknown"
    return output.strip()
|
|
83
|
+
|
|
84
|
+
def _get_hostname() -> str:
    """Return this machine's host name as reported by ``socket.gethostname``."""
    from socket import gethostname

    return gethostname()
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
### Experiment Tagging and Grouping
|
|
90
|
+
|
|
91
|
+
Effective tagging enables filtering hundreds of runs down to the relevant subset in seconds:
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
# src/tracking/tagging.py
|
|
95
|
+
from dataclasses import dataclass, field
|
|
96
|
+
from enum import Enum
|
|
97
|
+
|
|
98
|
+
class ExperimentPhase(Enum):
    """Research phase labels used to filter runs."""

    EXPLORATION = "exploration"  # Broad search, many configs
    REFINEMENT = "refinement"    # Narrowing in on promising region
    ABLATION = "ablation"        # Understanding component contributions
    FINAL = "final"              # Final evaluation with multiple seeds
    BASELINE = "baseline"        # Baseline method for comparison


@dataclass
class RunTags:
    """Structured tags describing one research run.

    The two ``to_*`` methods render the same information in the flat-list
    form and the key/value form respectively.
    """

    hypothesis: str        # Which hypothesis this tests (e.g., "H-003")
    phase: ExperimentPhase
    group: str             # Experiment group name
    architecture: str      # Architecture variant (e.g., "resnet50_modified")
    search_method: str = ""  # NAS method if applicable
    custom_tags: list[str] = field(default_factory=list)

    def to_wandb_tags(self) -> list[str]:
        """Return a flat list of ``prefix:value`` tags.

        ``group`` is not included in the list. The ``search:`` tag appears
        only when ``search_method`` is non-empty; ``custom_tags`` are appended
        verbatim at the end.
        """
        search_part = [f"search:{self.search_method}"] if self.search_method else []
        return [
            f"hypothesis:{self.hypothesis}",
            f"phase:{self.phase.value}",
            f"arch:{self.architecture}",
            *search_part,
            *self.custom_tags,
        ]

    def to_mlflow_tags(self) -> dict[str, str]:
        """Return the tags as a key/value dict.

        ``custom_tags`` are not included. ``search_method`` is present only
        when it is non-empty.
        """
        optional = {"search_method": self.search_method} if self.search_method else {}
        return {
            "hypothesis": self.hypothesis,
            "phase": self.phase.value,
            "group": self.group,
            "architecture": self.architecture,
            **optional,
        }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Recommended tag taxonomy for ML research:
|
|
142
|
+
# - hypothesis:<id> -- which research question this run addresses
|
|
143
|
+
# - phase:<exploration|refinement|ablation|final|baseline>
|
|
144
|
+
# - arch:<architecture_name> -- model architecture variant
|
|
145
|
+
# - search:<method> -- NAS/HPO method used
|
|
146
|
+
# - dataset:<name> -- dataset variant or split
|
|
147
|
+
# - scale:<small|medium|large> -- compute scale of the run
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
### Parallel Coordinate Visualization
|
|
151
|
+
|
|
152
|
+
Parallel coordinates reveal which hyperparameter ranges correlate with good performance:
|
|
153
|
+
|
|
154
|
+
```python
|
|
155
|
+
# src/tracking/visualization.py
|
|
156
|
+
import pandas as pd
|
|
157
|
+
from typing import Any
|
|
158
|
+
|
|
159
|
+
def prepare_parallel_coords_data(
|
|
160
|
+
runs: list[dict[str, Any]],
|
|
161
|
+
params: list[str],
|
|
162
|
+
metric: str,
|
|
163
|
+
top_k: int | None = None,
|
|
164
|
+
) -> pd.DataFrame:
|
|
165
|
+
"""Prepare data for parallel coordinate plot.
|
|
166
|
+
|
|
167
|
+
Each row is a run; columns are hyperparameters + the target metric.
|
|
168
|
+
"""
|
|
169
|
+
records = []
|
|
170
|
+
for run in runs:
|
|
171
|
+
record = {param: run["config"].get(param) for param in params}
|
|
172
|
+
record[metric] = run["metrics"].get(metric)
|
|
173
|
+
record["run_id"] = run["run_id"]
|
|
174
|
+
records.append(record)
|
|
175
|
+
|
|
176
|
+
df = pd.DataFrame(records)
|
|
177
|
+
df = df.dropna(subset=[metric])
|
|
178
|
+
|
|
179
|
+
if top_k:
|
|
180
|
+
df = df.nlargest(top_k, metric)
|
|
181
|
+
|
|
182
|
+
return df
|
|
183
|
+
|
|
184
|
+
def wandb_parallel_coords_query(
    project: str,
    group: str,
    params: list[str],
    metric: str,
) -> str:
    """Return a Python snippet that pulls parallel-coordinates data from W&B.

    The returned text is source code, not a query object: when executed it
    lists the finished runs of ``group`` in ``project`` and collects the
    requested config values plus the summary value of ``metric`` into a
    ``data`` list.

    Args:
        project: Project path handed to ``api.runs`` in the snippet.
        group: Run group to filter on.
        params: Config keys to extract per run.
        metric: Summary key to extract per run.

    Returns:
        The snippet as a string (starts with a newline).
    """
    import json

    def as_literal(text: str) -> str:
        # Emit a double-quoted, escaped literal so quotes or backslashes in
        # the value cannot break (or inject into) the generated code. For
        # plain names the output is identical to simple "..." quoting.
        return json.dumps(text, ensure_ascii=False)

    project_lit = as_literal(project)
    group_lit = as_literal(group)
    metric_lit = as_literal(metric)

    # W&B parallel coordinates are configured in the UI,
    # but we can query the data programmatically
    return f"""
import wandb
api = wandb.Api()
runs = api.runs(
    {project_lit},
    filters={{"group": {group_lit}, "state": "finished"}},
)
data = []
for run in runs:
    record = {{p: run.config.get(p) for p in {params!r}}}
    record[{metric_lit}] = run.summary.get({metric_lit})
    data.append(record)
"""
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
### Run Comparison Dashboards
|
|
212
|
+
|
|
213
|
+
Build comparison views that highlight differences between best and worst runs:
|
|
214
|
+
|
|
215
|
+
```python
|
|
216
|
+
# src/tracking/comparison.py
|
|
217
|
+
from typing import Any
|
|
218
|
+
import numpy as np
|
|
219
|
+
|
|
220
|
+
def _numeric_values(values):
    """Return non-None entries as floats, or None if any entry is not float-convertible."""
    try:
        return [float(v) for v in values if v is not None]
    except (TypeError, ValueError):
        return None


def _most_common_value(values):
    """Return the most frequent entry of *values* (None for an empty list)."""
    from collections import Counter

    if not values:
        return None
    try:
        return Counter(values).most_common(1)[0][0]
    except TypeError:
        # Unhashable config values (lists, dicts): count by repr and hand
        # back the first original object with the winning repr.
        first_seen = {}
        for value in values:
            first_seen.setdefault(repr(value), value)
        winner = Counter(repr(v) for v in values).most_common(1)[0][0]
        return first_seen[winner]


def compare_top_bottom(
    runs: list[dict[str, Any]],
    metric: str,
    n: int = 5,
) -> dict[str, Any]:
    """Compare top-N vs bottom-N runs to find discriminating hyperparameters.

    Runs are ranked by ``metric`` (higher is better); a missing or ``None``
    metric ranks last. The two groups overlap when there are fewer than
    ``2 * n`` runs.

    For each config key seen in any run:
      * if every non-None value in both groups converts to float and both
        groups have at least one value, a ``"numeric"`` entry with the group
        means and their absolute difference (``separation``) is produced;
      * otherwise the most common value of each group is compared and a
        ``"categorical"`` entry is produced when they differ.

    Args:
        runs: Run records with ``"config"`` and ``"metrics"`` mappings.
        metric: Metric name used for ranking.
        n: Size of the top and bottom groups; values below 1 give empty groups.

    Returns:
        Dict with ``top_runs``, ``bottom_runs`` and ``discriminating_params``
        (at most 10 entries, sorted by separation descending; categorical
        entries are ranked as if their separation were 1.0).
    """
    def score(run):
        value = run["metrics"].get(metric)
        return float("-inf") if value is None else value

    n = max(n, 0)
    sorted_runs = sorted(runs, key=score, reverse=True)
    top = sorted_runs[:n]
    # sorted_runs[-0:] would be the whole list, so n == 0 needs its own case.
    bottom = sorted_runs[-n:] if n else []

    all_params = set()
    for run in runs:
        all_params.update(run["config"].keys())

    discriminators = []
    # Sorted iteration keeps tie order reproducible between calls.
    for param in sorted(all_params, key=str):
        top_values = [r["config"].get(param) for r in top]
        bottom_values = [r["config"].get(param) for r in bottom]

        top_nums = _numeric_values(top_values)
        bottom_nums = _numeric_values(bottom_values)

        # Both lists must be non-empty: the mean of an empty list is NaN,
        # which would also make the final sort unreliable.
        if top_nums and bottom_nums:
            top_mean = np.mean(top_nums)
            bottom_mean = np.mean(bottom_nums)
            discriminators.append({
                "param": param,
                "top_mean": top_mean,
                "bottom_mean": bottom_mean,
                "separation": abs(top_mean - bottom_mean),
                "type": "numeric",
            })
            continue

        # Categorical param (or absent from one group) -- compare the modes.
        top_mode = _most_common_value(top_values)
        bottom_mode = _most_common_value(bottom_values)
        if top_mode != bottom_mode:
            discriminators.append({
                "param": param,
                "top_mode": top_mode,
                "bottom_mode": bottom_mode,
                "type": "categorical",
            })

    discriminators.sort(
        key=lambda d: d.get("separation", 1.0), reverse=True
    )
    return {
        "top_runs": top,
        "bottom_runs": bottom,
        "discriminating_params": discriminators[:10],
    }
|
|
275
|
+
|
|
276
|
+
def format_run_comparison(run_a: dict, run_b: dict, metric: str) -> str:
    """Render a plain-text, two-section comparison of two runs.

    The first section lists only the config parameters whose values differ
    (``---`` stands in for a parameter one run lacks; the last column is a
    ``*`` marker). The second section shows ``metric`` for both runs and
    ``run_b - run_a``; a run without the metric is shown as 0.

    Returns:
        The table as a single newline-joined string.
    """
    text_row = "{:<30} {:>15} {:>15} {:>10}".format
    number_row = "{:<30} {:>15.4f} {:>15.4f} {:>+10.4f}".format
    rule = "-" * 75

    report = [text_row("Parameter", "Run A", "Run B", "Delta"), rule]

    config_a = run_a["config"]
    config_b = run_b["config"]
    for name in sorted(set(config_a) | set(config_b)):
        left = config_a.get(name, "---")
        right = config_b.get(name, "---")
        if left != right:
            report.append(text_row(name, str(left), str(right), "*"))

    report.append("")
    report.append(text_row("Metric", "Run A", "Run B", "Delta"))
    report.append(rule)

    value_a = run_a["metrics"].get(metric, 0)
    value_b = run_b["metrics"].get(metric, 0)
    report.append(number_row(metric, value_a, value_b, value_b - value_a))

    return "\n".join(report)
|
|
296
|
+
```
|
|
297
|
+
|
|
298
|
+
### Model Checkpointing for Research
|
|
299
|
+
|
|
300
|
+
Research checkpointing differs from production: keep the best-N, discard everything else aggressively, and support cross-experiment checkpoint reuse:
|
|
301
|
+
|
|
302
|
+
```python
|
|
303
|
+
# src/tracking/research_checkpoints.py
|
|
304
|
+
from pathlib import Path
|
|
305
|
+
from dataclasses import dataclass
|
|
306
|
+
import json
|
|
307
|
+
import shutil
|
|
308
|
+
|
|
309
|
+
@dataclass
|
|
310
|
+
class CheckpointPolicy:
|
|
311
|
+
"""Research checkpoint retention policy."""
|
|
312
|
+
keep_top_n: int = 3 # Per experiment group
|
|
313
|
+
keep_final: bool = True # Always keep the last checkpoint
|
|
314
|
+
max_total_gb: float = 50.0 # Total storage budget
|
|
315
|
+
cleanup_on_abort: bool = True # Delete checkpoints from aborted runs
|
|
316
|
+
|
|
317
|
+
class ResearchCheckpointManager:
|
|
318
|
+
"""Manage checkpoints across many research runs."""
|
|
319
|
+
|
|
320
|
+
def __init__(self, base_dir: str, policy: CheckpointPolicy):
|
|
321
|
+
self.base_dir = Path(base_dir)
|
|
322
|
+
self.policy = policy
|
|
323
|
+
self.base_dir.mkdir(parents=True, exist_ok=True)
|
|
324
|
+
|
|
325
|
+
def register_checkpoint(
|
|
326
|
+
self,
|
|
327
|
+
run_id: str,
|
|
328
|
+
group: str,
|
|
329
|
+
step: int,
|
|
330
|
+
metric_value: float,
|
|
331
|
+
metric_name: str,
|
|
332
|
+
path: Path,
|
|
333
|
+
) -> None:
|
|
334
|
+
"""Register a checkpoint and enforce retention policy."""
|
|
335
|
+
meta = {
|
|
336
|
+
"run_id": run_id,
|
|
337
|
+
"group": group,
|
|
338
|
+
"step": step,
|
|
339
|
+
"metric_name": metric_name,
|
|
340
|
+
"metric_value": metric_value,
|
|
341
|
+
"path": str(path),
|
|
342
|
+
}
|
|
343
|
+
meta_path = self.base_dir / "registry" / group / f"{run_id}_step{step}.json"
|
|
344
|
+
meta_path.parent.mkdir(parents=True, exist_ok=True)
|
|
345
|
+
with open(meta_path, "w") as f:
|
|
346
|
+
json.dump(meta, f, indent=2)
|
|
347
|
+
|
|
348
|
+
self._enforce_policy(group)
|
|
349
|
+
|
|
350
|
+
def _enforce_policy(self, group: str) -> None:
|
|
351
|
+
"""Keep only top-N checkpoints per group."""
|
|
352
|
+
registry_dir = self.base_dir / "registry" / group
|
|
353
|
+
if not registry_dir.exists():
|
|
354
|
+
return
|
|
355
|
+
|
|
356
|
+
entries = []
|
|
357
|
+
for meta_path in registry_dir.glob("*.json"):
|
|
358
|
+
with open(meta_path) as f:
|
|
359
|
+
entries.append((meta_path, json.load(f)))
|
|
360
|
+
|
|
361
|
+
# Sort by metric (descending -- higher is better)
|
|
362
|
+
entries.sort(key=lambda e: e[1]["metric_value"], reverse=True)
|
|
363
|
+
|
|
364
|
+
# Remove checkpoints beyond top-N
|
|
365
|
+
for meta_path, meta in entries[self.policy.keep_top_n:]:
|
|
366
|
+
ckpt_path = Path(meta["path"])
|
|
367
|
+
if ckpt_path.exists():
|
|
368
|
+
ckpt_path.unlink()
|
|
369
|
+
meta_path.unlink()
|
|
370
|
+
|
|
371
|
+
def cleanup_aborted_runs(self, aborted_run_ids: list[str]) -> int:
|
|
372
|
+
"""Remove all checkpoints from aborted runs."""
|
|
373
|
+
if not self.policy.cleanup_on_abort:
|
|
374
|
+
return 0
|
|
375
|
+
removed = 0
|
|
376
|
+
for meta_path in self.base_dir.rglob("*.json"):
|
|
377
|
+
with open(meta_path) as f:
|
|
378
|
+
meta = json.load(f)
|
|
379
|
+
if meta["run_id"] in aborted_run_ids:
|
|
380
|
+
ckpt_path = Path(meta["path"])
|
|
381
|
+
if ckpt_path.exists():
|
|
382
|
+
ckpt_path.unlink()
|
|
383
|
+
removed += 1
|
|
384
|
+
meta_path.unlink()
|
|
385
|
+
return removed
|
|
386
|
+
|
|
387
|
+
def get_storage_usage_gb(self) -> float:
|
|
388
|
+
"""Calculate total checkpoint storage usage."""
|
|
389
|
+
total_bytes = sum(
|
|
390
|
+
f.stat().st_size for f in self.base_dir.rglob("*.pt")
|
|
391
|
+
)
|
|
392
|
+
return total_bytes / (1024**3)
|
|
393
|
+
|
|
394
|
+
def find_best_checkpoint(self, group: str) -> Path | None:
|
|
395
|
+
"""Find the best checkpoint in a group for warm-starting."""
|
|
396
|
+
registry_dir = self.base_dir / "registry" / group
|
|
397
|
+
if not registry_dir.exists():
|
|
398
|
+
return None
|
|
399
|
+
|
|
400
|
+
best_meta = None
|
|
401
|
+
best_value = float("-inf")
|
|
402
|
+
for meta_path in registry_dir.glob("*.json"):
|
|
403
|
+
with open(meta_path) as f:
|
|
404
|
+
meta = json.load(f)
|
|
405
|
+
if meta["metric_value"] > best_value:
|
|
406
|
+
best_value = meta["metric_value"]
|
|
407
|
+
best_meta = meta
|
|
408
|
+
|
|
409
|
+
if best_meta:
|
|
410
|
+
path = Path(best_meta["path"])
|
|
411
|
+
return path if path.exists() else None
|
|
412
|
+
return None
|
|
413
|
+
```
|
|
414
|
+
|
|
415
|
+
### MLflow for Research (Lightweight Setup)
|
|
416
|
+
|
|
417
|
+
MLflow works well for research when configured for minimal overhead:
|
|
418
|
+
|
|
419
|
+
```python
|
|
420
|
+
# src/tracking/mlflow_research.py
|
|
421
|
+
import mlflow
|
|
422
|
+
from typing import Any
|
|
423
|
+
|
|
424
|
+
def setup_mlflow_research(
    experiment_name: str,
    tracking_uri: str = "sqlite:///mlruns.db",
) -> None:
    """Configure MLflow for lightweight research tracking.

    Args:
        experiment_name: Experiment to activate via ``mlflow.set_experiment``.
        tracking_uri: Tracking backend. The default is a local SQLite file,
            so no server is needed for single-researcher projects; pass a
            different URI to point at another store.
    """
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    # Enable autologging for common frameworks
    mlflow.autolog(log_models=False)  # Skip model logging (saves storage)
|
|
431
|
+
|
|
432
|
+
def log_research_run(
    run_name: str,
    config: dict[str, Any],
    metrics: dict[str, float],
    tags: dict[str, str],
) -> str:
    """Record one finished research run in MLflow and return its run id.

    The config is flattened with ``_flatten_config`` before being logged as
    params; metrics are logged in one call and tags are set one by one.
    """
    with mlflow.start_run(run_name=run_name) as active_run:
        # MLflow params are flat key/value pairs, so nested config is flattened first.
        mlflow.log_params(_flatten_config(config))
        mlflow.log_metrics(metrics)
        for tag_name, tag_value in tags.items():
            mlflow.set_tag(tag_name, tag_value)
        run_id = active_run.info.run_id
    return run_id
|
|
447
|
+
|
|
448
|
+
def _flatten_config(config: dict, prefix: str = "") -> dict[str, str]:
    """Collapse a nested config dict into a single-level dict.

    Nested keys are joined with ``.`` (prefixed by ``prefix`` when given).
    Every leaf value is converted with ``str`` and cut to 500 characters.
    """
    def leaf_pairs(node: dict, path: str):
        for name, item in node.items():
            dotted = f"{path}.{name}" if path else name
            if isinstance(item, dict):
                yield from leaf_pairs(item, dotted)
            else:
                yield dotted, str(item)[:500]  # MLflow param limit

    return dict(leaf_pairs(config, prefix))
|
|
458
|
+
```
|
|
459
|
+
|
|
460
|
+
### Best Practices for ML Research Tracking
|
|
461
|
+
|
|
462
|
+
1. **One project per research question**: Do not mix unrelated experiments in one tracking project. Each hypothesis or research direction gets its own project.
|
|
463
|
+
2. **Group by experiment phase**: Tag runs as exploration, refinement, ablation, final, or baseline. Filter by phase to avoid noise.
|
|
464
|
+
3. **Log at step granularity during exploration, epoch during final**: Step-level logging during exploration helps diagnose training dynamics; epoch-level during final runs reduces storage.
|
|
465
|
+
4. **Remove failed runs from comparison views promptly**: Runs that crash, diverge, or hit NaN add noise to comparison views. Mark them as failed, then archive or delete them.
|
|
466
|
+
5. **Pin your best runs**: Use W&B pinning or MLflow tags to mark the current best result so it is always visible in comparison views.
|