claude-turing 1.3.0 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +7 -2
- package/commands/ablate.md +47 -0
- package/commands/checkpoint.md +47 -0
- package/commands/diagnose.md +52 -0
- package/commands/frontier.md +45 -0
- package/commands/profile.md +43 -0
- package/commands/turing.md +10 -0
- package/package.json +1 -1
- package/src/install.js +1 -0
- package/src/verify.js +5 -0
- package/templates/scripts/__pycache__/ablation_study.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/checkpoint_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/diagnose_errors.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/pareto_frontier.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/profile_training.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/ablation_study.py +487 -0
- package/templates/scripts/checkpoint_manager.py +449 -0
- package/templates/scripts/diagnose_errors.py +601 -0
- package/templates/scripts/generate_brief.py +74 -1
- package/templates/scripts/pareto_frontier.py +470 -0
- package/templates/scripts/profile_training.py +533 -0
- package/templates/scripts/scaffold.py +11 -0
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Smart checkpoint manager for ML experiments.
|
|
3
|
+
|
|
4
|
+
Manages model checkpoints based on Pareto dominance rather than
|
|
5
|
+
simple "keep last K". Supports listing, Pareto-based pruning,
|
|
6
|
+
checkpoint averaging, and resume from any point.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/checkpoint_manager.py list # List checkpoints
|
|
10
|
+
python scripts/checkpoint_manager.py prune # Remove dominated
|
|
11
|
+
python scripts/checkpoint_manager.py average [--top 3] # Average top-K
|
|
12
|
+
python scripts/checkpoint_manager.py resume exp-042 # Resume from checkpoint
|
|
13
|
+
python scripts/checkpoint_manager.py stats # Disk usage summary
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import argparse
|
|
19
|
+
import json
|
|
20
|
+
import shutil
|
|
21
|
+
import sys
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
import yaml
|
|
26
|
+
|
|
27
|
+
from scripts.turing_io import load_config, load_experiments
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Directory scanned for checkpoints by default; main() exposes it as --checkpoint-dir.
DEFAULT_CHECKPOINT_DIR = "experiments/checkpoints"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# File extensions recognised as single-file checkpoints.
_SINGLE_FILE_SUFFIXES = (".joblib", ".pkl", ".pt", ".pth", ".h5", ".onnx")


def _as_dict(value) -> dict:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _load_checkpoint_metadata(entry: Path) -> dict:
    """Load ``metadata.yaml`` (preferred) or ``metadata.json`` from a checkpoint dir.

    Returns ``{}`` when neither file exists, when the file's top-level value is
    not a mapping, or when it cannot be parsed (a warning is printed to stderr).
    Treating unreadable metadata as empty keeps the checkpoint in the
    "incomplete" set, which the Pareto logic never prunes.
    """
    yaml_path = entry / "metadata.yaml"
    json_path = entry / "metadata.json"
    try:
        if yaml_path.exists():
            with open(yaml_path) as f:
                return _as_dict(yaml.safe_load(f))
        if json_path.exists():
            with open(json_path) as f:
                return _as_dict(json.load(f))
    except (yaml.YAMLError, json.JSONDecodeError, UnicodeDecodeError) as err:
        print(f"Warning: unreadable checkpoint metadata in {entry}: {err}", file=sys.stderr)
    return {}


def _checkpoint_record(entry: Path, experiment_id, metadata: dict, size_bytes: int) -> dict:
    """Build the checkpoint record dict shared by directory and single-file checkpoints."""
    return {
        "path": str(entry),
        "name": entry.name,
        "experiment_id": experiment_id,
        # `or`-style normalisation: a null metrics/config in metadata must not
        # leak through as None, since callers call `.get` on these.
        "metrics": _as_dict(metadata.get("metrics")),
        "config": _as_dict(metadata.get("config")),
        "size_bytes": size_bytes,
        "size_mb": round(size_bytes / 1024**2, 1),
        "created": metadata.get("timestamp") or metadata.get("created_at"),
    }


def scan_checkpoints(checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR) -> list[dict]:
    """Scan ``checkpoint_dir`` and return one metadata record per checkpoint.

    Two layouts are recognised:

    * a sub-directory, optionally containing ``metadata.yaml`` / ``metadata.json``
      (its size is the total size of all files beneath it);
    * a single file whose suffix is one of ``_SINGLE_FILE_SUFFIXES``; its
      experiment ID is derived from the filename stem with any
      ``-checkpoint`` / ``_checkpoint`` marker removed.

    Other entries are ignored. Entries are visited in sorted path order.

    Args:
        checkpoint_dir: Directory to scan. A missing path, or a path that is not
            a directory, yields an empty list.

    Returns:
        List of dicts with keys ``path``, ``name``, ``experiment_id``,
        ``metrics``, ``config``, ``size_bytes``, ``size_mb`` and ``created``.
    """
    ckpt_path = Path(checkpoint_dir)
    # is_dir (not exists): iterdir() on a regular file would raise.
    if not ckpt_path.is_dir():
        return []

    checkpoints = []
    for entry in sorted(ckpt_path.iterdir()):
        if entry.is_dir():
            metadata = _load_checkpoint_metadata(entry)
            size_bytes = sum(f.stat().st_size for f in entry.rglob("*") if f.is_file())
            experiment_id = metadata.get("experiment_id")
            if experiment_id is None:
                experiment_id = entry.name
        elif entry.is_file() and entry.suffix in _SINGLE_FILE_SUFFIXES:
            metadata = {}
            size_bytes = entry.stat().st_size
            experiment_id = entry.stem.replace("-checkpoint", "").replace("_checkpoint", "")
        else:
            continue
        checkpoints.append(_checkpoint_record(entry, experiment_id, metadata, size_bytes))

    return checkpoints
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def enrich_checkpoints_from_log(
    checkpoints: list[dict],
    experiments: list[dict],
) -> list[dict]:
    """Fill gaps in checkpoint records using the experiment log.

    Checkpoints are matched to log entries by ``experiment_id``. If the log has
    several entries with the same ID, the last one wins. Each field
    (``metrics``, ``config``, ``created``) is filled independently and only when
    the checkpoint's own value is empty. Metadata stored next to the
    checkpoint therefore always takes precedence over the log.

    Args:
        checkpoints: Records as produced by ``scan_checkpoints``; mutated in place.
        experiments: Experiment log entries (dicts with ``experiment_id`` and
            optionally ``metrics``, ``config`` and ``timestamp``).

    Returns:
        The same ``checkpoints`` list, for chaining.
    """
    exp_map = {
        e.get("experiment_id"): e
        for e in experiments
        if e.get("experiment_id") is not None
    }

    for ckpt in checkpoints:
        exp = exp_map.get(ckpt["experiment_id"])
        if exp is None:
            continue
        # `or {}`: a log entry with a null metrics/config must not store None,
        # because downstream code calls `.get` on these fields.
        if not ckpt["metrics"]:
            ckpt["metrics"] = exp.get("metrics") or {}
        if not ckpt["config"]:
            ckpt["config"] = exp.get("config") or {}
        # Filled independently of metrics: a checkpoint whose metadata has
        # metrics but no timestamp still needs `created` for prune_dominated's
        # keep-latest protection.
        if not ckpt["created"]:
            ckpt["created"] = exp.get("timestamp")

    return checkpoints
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def compute_pareto_checkpoints(
|
|
110
|
+
checkpoints: list[dict],
|
|
111
|
+
metrics: list[str],
|
|
112
|
+
directions: dict[str, str],
|
|
113
|
+
) -> tuple[list[dict], list[dict]]:
|
|
114
|
+
"""Separate checkpoints into Pareto-optimal and dominated sets.
|
|
115
|
+
|
|
116
|
+
Returns (pareto_optimal, dominated) tuple.
|
|
117
|
+
"""
|
|
118
|
+
if not checkpoints or not metrics:
|
|
119
|
+
return checkpoints, []
|
|
120
|
+
|
|
121
|
+
# Filter to checkpoints that have all requested metrics
|
|
122
|
+
complete = []
|
|
123
|
+
incomplete = []
|
|
124
|
+
for ckpt in checkpoints:
|
|
125
|
+
if all(ckpt["metrics"].get(m) is not None for m in metrics):
|
|
126
|
+
complete.append(ckpt)
|
|
127
|
+
else:
|
|
128
|
+
incomplete.append(ckpt)
|
|
129
|
+
|
|
130
|
+
if not complete:
|
|
131
|
+
return checkpoints, []
|
|
132
|
+
|
|
133
|
+
pareto = []
|
|
134
|
+
dominated = []
|
|
135
|
+
|
|
136
|
+
for i, candidate in enumerate(complete):
|
|
137
|
+
is_dominated = False
|
|
138
|
+
for j, other in enumerate(complete):
|
|
139
|
+
if i == j:
|
|
140
|
+
continue
|
|
141
|
+
all_at_least = True
|
|
142
|
+
strictly_better = False
|
|
143
|
+
for m in metrics:
|
|
144
|
+
c_val = float(candidate["metrics"][m])
|
|
145
|
+
o_val = float(other["metrics"][m])
|
|
146
|
+
direction = directions.get(m, "higher")
|
|
147
|
+
if direction == "higher":
|
|
148
|
+
if o_val < c_val:
|
|
149
|
+
all_at_least = False
|
|
150
|
+
break
|
|
151
|
+
if o_val > c_val:
|
|
152
|
+
strictly_better = True
|
|
153
|
+
else:
|
|
154
|
+
if o_val > c_val:
|
|
155
|
+
all_at_least = False
|
|
156
|
+
break
|
|
157
|
+
if o_val < c_val:
|
|
158
|
+
strictly_better = True
|
|
159
|
+
if all_at_least and strictly_better:
|
|
160
|
+
is_dominated = True
|
|
161
|
+
break
|
|
162
|
+
if is_dominated:
|
|
163
|
+
dominated.append(candidate)
|
|
164
|
+
else:
|
|
165
|
+
pareto.append(candidate)
|
|
166
|
+
|
|
167
|
+
# Incomplete checkpoints are kept (can't determine dominance)
|
|
168
|
+
pareto.extend(incomplete)
|
|
169
|
+
return pareto, dominated
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _created_sort_key(ckpt: dict) -> tuple[str, float]:
    """Sort key approximating how recent a checkpoint is.

    ``created`` may be a string (from JSON metadata or the experiment log) or a
    date/datetime object (presumably produced by the YAML loader for unquoted
    timestamps). Both are normalised to ISO-8601 text, so comparing mixed
    values cannot raise TypeError. Checkpoints without a timestamp sort before
    those with one. The filesystem modification time breaks ties, which gives
    a meaningful "latest" even when no checkpoint has a timestamp.
    """
    created = ckpt.get("created")
    if not created:
        text = ""
    elif hasattr(created, "isoformat"):
        text = created.isoformat()
    else:
        text = str(created)
    try:
        mtime = Path(ckpt["path"]).stat().st_mtime
    except OSError:
        mtime = 0.0
    return text, mtime


def prune_dominated(
    checkpoints: list[dict],
    dominated: list[dict],
    keep_latest: bool = True,
    dry_run: bool = False,
) -> dict:
    """Remove dominated checkpoints from disk.

    Args:
        checkpoints: All checkpoints (used to find the latest one).
        dominated: Dominated checkpoints to remove.
        keep_latest: Never remove the most recent checkpoint (for resume
            safety), even if it is dominated.
        dry_run: If True, report what would be pruned without deleting.

    Returns:
        Dict with ``pruned_count``, ``bytes_saved``, ``mb_saved``, ``pruned``
        (name / experiment_id / size_mb per removed checkpoint) and ``dry_run``.
        Checkpoints whose path no longer exists are skipped and not counted.
    """
    if keep_latest and checkpoints:
        latest_path = max(checkpoints, key=_created_sort_key)["path"]
        dominated = [d for d in dominated if d["path"] != latest_path]

    pruned = []
    bytes_saved = 0

    for ckpt in dominated:
        path = Path(ckpt["path"])
        if not path.exists():
            continue
        pruned.append({
            "name": ckpt["name"],
            "experiment_id": ckpt["experiment_id"],
            "size_mb": ckpt["size_mb"],
        })
        bytes_saved += ckpt["size_bytes"]
        if dry_run:
            continue
        # A symlink to a directory is removed as a link: rmtree refuses symlinks.
        if path.is_dir() and not path.is_symlink():
            shutil.rmtree(path)
        else:
            path.unlink()

    return {
        "pruned_count": len(pruned),
        "bytes_saved": bytes_saved,
        "mb_saved": round(bytes_saved / 1024**2, 1),
        "pruned": pruned,
        "dry_run": dry_run,
    }
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def compute_disk_stats(checkpoints: list[dict]) -> dict:
    """Summarise the disk usage of a set of checkpoints.

    The same keys are returned for an empty list as for a non-empty one. This
    keeps the ``--json`` output of the stats/prune commands on a stable schema.

    Args:
        checkpoints: Records with ``name``, ``size_bytes`` and ``size_mb``.

    Returns:
        Dict with ``total_count``, ``total_size_mb``, ``total_size_gb``,
        ``avg_size_mb``, ``largest`` (name of the biggest checkpoint, or None)
        and ``largest_mb``.
    """
    if not checkpoints:
        return {
            "total_count": 0,
            "total_size_mb": 0,
            "total_size_gb": 0,
            "avg_size_mb": 0,
            "largest": None,
            "largest_mb": 0,
        }

    total_bytes = sum(c["size_bytes"] for c in checkpoints)
    return {
        "total_count": len(checkpoints),
        "total_size_mb": round(total_bytes / 1024**2, 1),
        "total_size_gb": round(total_bytes / 1024**3, 2),
        "avg_size_mb": round(total_bytes / len(checkpoints) / 1024**2, 1),
        "largest": max(checkpoints, key=lambda c: c["size_bytes"])["name"],
        "largest_mb": max(c["size_mb"] for c in checkpoints),
    }
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def format_checkpoint_list(
    checkpoints: list[dict],
    pareto_ids: set[str],
    primary_metric: str,
) -> str:
    """Render checkpoints as a markdown table.

    Args:
        checkpoints: Records with ``name``, ``path``, ``experiment_id``,
            ``metrics`` and ``size_mb``.
        pareto_ids: Paths and/or names of Pareto-optimal checkpoints. A row is
            flagged ``YES`` if either its path or its name is present.
        primary_metric: Metric shown in the table; non-numeric or missing
            values are rendered as ``N/A``.

    Returns:
        The markdown text, or ``"No checkpoints found."`` for an empty list.
    """
    if not checkpoints:
        return "No checkpoints found."

    # The metric column's separator scales loosely with the header width.
    metric_rule = "---" * max(len(primary_metric) // 3, 1)
    header = [
        "# Checkpoints",
        "",
        f"| Name | Experiment | {primary_metric} | Size | Pareto |",
        f"|------|------------|{metric_rule}--|------|--------|",
    ]

    def render_row(ckpt: dict) -> str:
        value = ckpt["metrics"].get(primary_metric)
        shown = "N/A"
        if isinstance(value, (int, float)):
            shown = f"{value:.4f}"
        on_frontier = ckpt["path"] in pareto_ids or ckpt["name"] in pareto_ids
        marker = "YES" if on_frontier else ""
        return (
            f"| {ckpt['name']} | {ckpt['experiment_id']} "
            f"| {shown} | {ckpt['size_mb']:.1f} MB | {marker} |"
        )

    return "\n".join(header + [render_row(ckpt) for ckpt in checkpoints])
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def format_prune_report(prune_result: dict, stats_before: dict, stats_after: dict) -> str:
    """Render the outcome of ``prune_dominated`` as markdown.

    Args:
        prune_result: Dict returned by ``prune_dominated``.
        stats_before: Disk stats before pruning (``total_count`` required,
            ``total_size_gb`` optional).
        stats_after: Disk stats after pruning; only shown for a real
            (non-dry) run.

    Returns:
        The report text. It ends with a "Removed" list when anything was (or
        would have been) pruned.
    """
    report = [
        "# Checkpoint Pruning",
        "",
        f"Before: {stats_before['total_count']} checkpoints, {stats_before.get('total_size_gb', 0):.1f} GB",
    ]

    if not prune_result["dry_run"]:
        report += [
            f"Pruned: {prune_result['pruned_count']} dominated checkpoints ({prune_result['mb_saved']:.1f} MB)",
            f"After: {stats_after['total_count']} checkpoints, {stats_after.get('total_size_gb', 0):.1f} GB",
        ]
    else:
        report += [
            f"Would prune: {prune_result['pruned_count']} dominated checkpoints ({prune_result['mb_saved']:.1f} MB)",
            "",
            "*Dry run — no files deleted. Run without --dry-run to prune.*",
        ]

    removed = prune_result["pruned"]
    if removed:
        report += ["", "## Removed", ""]
        report += [
            f"- {item['name']} ({item['experiment_id']}, {item['size_mb']:.1f} MB)"
            for item in removed
        ]

    return "\n".join(report)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def format_stats_report(stats: dict, checkpoints: list[dict]) -> str:
    """Render disk-usage statistics (as from ``compute_disk_stats``) as markdown.

    When the checkpoints span more than one ``config.model_type``, a
    per-type breakdown is appended. Checkpoints with no config, a non-dict
    config, or an empty/null model type are grouped under ``unknown``.

    Args:
        stats: Stats dict with ``total_count``, ``total_size_mb`` and
            ``avg_size_mb``; ``total_size_gb``, ``largest`` and ``largest_mb``
            are optional.
        checkpoints: Records with ``config`` and ``size_mb``.

    Returns:
        The report text.
    """
    lines = [
        "# Checkpoint Storage",
        "",
        f"- **Total checkpoints:** {stats['total_count']}",
        f"- **Total size:** {stats.get('total_size_gb', 0):.2f} GB ({stats['total_size_mb']:.1f} MB)",
        f"- **Average size:** {stats['avg_size_mb']:.1f} MB",
    ]
    if stats.get("largest"):
        lines.append(f"- **Largest:** {stats['largest']} ({stats['largest_mb']:.1f} MB)")

    # Size by model type
    by_type: dict[str, list[dict]] = {}
    for ckpt in checkpoints:
        # Normalise to a non-empty string key: `config` may be None and
        # `model_type` may be null or non-str in hand-written metadata, which
        # would otherwise crash `.get` or the `sorted()` below.
        config = ckpt.get("config") or {}
        model_type = config.get("model_type") if isinstance(config, dict) else None
        key = "unknown" if model_type is None or model_type == "" else str(model_type)
        by_type.setdefault(key, []).append(ckpt)

    if len(by_type) > 1:
        lines.extend(["", "## By Model Type", ""])
        for mt, ckpts in sorted(by_type.items()):
            total_mb = sum(c["size_mb"] for c in ckpts)
            lines.append(f"- **{mt}:** {len(ckpts)} checkpoints, {total_mb:.1f} MB")

    return "\n".join(lines)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def save_checkpoint_report(report: dict, output_dir: str = "experiments/checkpoints") -> Path:
    """Write ``report`` as YAML to ``<output_dir>/checkpoint-report.yaml``.

    The output directory (and any missing parents) is created first. Keys
    are written in insertion order, and nested collections use YAML block
    style. An existing report file is overwritten.

    Returns:
        Path of the written report file.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    report_file = target_dir / "checkpoint-report.yaml"
    with report_file.open("w") as handle:
        yaml.dump(report, handle, default_flow_style=False, sort_keys=False)
    return report_file
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# Non-primary metrics for which a smaller value is better.
_LOWER_IS_BETTER_METRICS = {"train_seconds", "latency", "mse", "rmse", "mae", "loss"}


def _metric_directions(
    checkpoints: list[dict],
    primary_metric: str,
    lower_is_better: bool,
) -> tuple[list[str], dict[str, str]]:
    """Pick the metrics used for Pareto dominance and their optimisation direction.

    The primary metric is always used. ``train_seconds`` is added as a second
    objective when any checkpoint reports it, unless it is already the
    primary metric.
    """
    metrics_to_check = [primary_metric]
    reported = set().union(*(c["metrics"].keys() for c in checkpoints if c["metrics"]))
    if "train_seconds" in reported and primary_metric != "train_seconds":
        metrics_to_check.append("train_seconds")

    directions = {}
    for m in metrics_to_check:
        if m == primary_metric:
            directions[m] = "lower" if lower_is_better else "higher"
        elif m in _LOWER_IS_BETTER_METRICS:
            directions[m] = "lower"
        else:
            directions[m] = "higher"
    return metrics_to_check, directions


def _run_prune(checkpoints: list[dict], dominated: list[dict], checkpoint_dir: str,
               dry_run: bool, as_json: bool) -> None:
    """Prune dominated checkpoints and print a before/after report."""
    stats_before = compute_disk_stats(checkpoints)
    result = prune_dominated(checkpoints, dominated, dry_run=dry_run)

    if dry_run:
        # Simulate the outcome exactly instead of subtracting rounded MB
        # figures; names are unique within one checkpoint directory.
        pruned_names = {p["name"] for p in result["pruned"]}
        remaining = [c for c in checkpoints if c["name"] not in pruned_names]
    else:
        remaining = scan_checkpoints(checkpoint_dir)
    stats_after = compute_disk_stats(remaining)

    if as_json:
        print(json.dumps({"prune": result, "before": stats_before, "after": stats_after},
                         indent=2, default=str))
    else:
        print(format_prune_report(result, stats_before, stats_after))


def _run_average(checkpoints: list[dict], primary_metric: str, lower_is_better: bool,
                 top: int) -> None:
    """List the top-K checkpoints by the primary metric as averaging candidates."""
    import math

    def usable(c: dict) -> bool:
        # Only finite-comparable numbers can be ranked and formatted with :.4f;
        # None, strings and NaN would break the sort or the output.
        value = (c["metrics"] or {}).get(primary_metric)
        return isinstance(value, (int, float)) and not math.isnan(value)

    with_metric = [c for c in checkpoints if usable(c)]
    with_metric.sort(
        key=lambda c: c["metrics"][primary_metric],
        reverse=not lower_is_better,
    )
    top_k = with_metric[:top]

    if not top_k:
        print("No checkpoints with metric data found.", file=sys.stderr)
        sys.exit(1)

    print(f"Top {len(top_k)} checkpoints for averaging:", file=sys.stderr)
    for c in top_k:
        print(f"  {c['experiment_id']}: {primary_metric}={c['metrics'][primary_metric]:.4f}", file=sys.stderr)
    print("\nNote: Weight averaging requires model-specific implementation.", file=sys.stderr)
    print("The checkpoint paths are:", file=sys.stderr)
    for c in top_k:
        print(f"  {c['path']}", file=sys.stderr)


def _run_resume(checkpoints: list[dict], exp_id: str | None) -> None:
    """Locate the checkpoint for ``exp_id`` (by experiment ID, then by name) and describe it."""
    if not exp_id:
        print("Specify experiment ID: checkpoint resume <exp-id>", file=sys.stderr)
        sys.exit(1)

    # str(): metadata may store numeric IDs while the CLI argument is a string.
    target = next((c for c in checkpoints if str(c["experiment_id"]) == exp_id), None)
    if target is None:
        target = next((c for c in checkpoints if c["name"] == exp_id), None)

    if target is None:
        print(f"No checkpoint found for {exp_id}", file=sys.stderr)
        available = [str(c["experiment_id"]) for c in checkpoints]
        if available:
            print(f"Available: {', '.join(available[:10])}", file=sys.stderr)
        sys.exit(1)

    print(f"Resume checkpoint: {target['path']}", file=sys.stderr)
    print(f"Experiment: {target['experiment_id']}", file=sys.stderr)
    if target["metrics"]:
        print(f"Metrics: {target['metrics']}", file=sys.stderr)


def main() -> None:
    """CLI entry point: parse arguments, load checkpoints, dispatch the action."""
    parser = argparse.ArgumentParser(description="Smart checkpoint manager")
    parser.add_argument(
        "action",
        choices=["list", "prune", "average", "resume", "stats"],
        help="Action to perform",
    )
    parser.add_argument("exp_id", nargs="?", default=None, help="Experiment ID (for resume)")
    parser.add_argument("--checkpoint-dir", default=DEFAULT_CHECKPOINT_DIR, help="Checkpoint directory")
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default="experiments/log.jsonl", help="Path to experiment log")
    parser.add_argument("--top", type=int, default=3, help="Top K for averaging")
    parser.add_argument("--dry-run", action="store_true", help="Don't actually delete (prune)")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")
    args = parser.parse_args()

    if args.action == "average" and args.top < 1:
        parser.error("--top must be a positive integer")

    config = load_config(args.config)
    # `or {}`: an explicit `evaluation: null` in config must not crash .get.
    eval_cfg = config.get("evaluation") or {}
    primary_metric = eval_cfg.get("primary_metric", "accuracy")
    lower_is_better = eval_cfg.get("lower_is_better", False)

    experiments = load_experiments(args.log)
    checkpoints = scan_checkpoints(args.checkpoint_dir)
    checkpoints = enrich_checkpoints_from_log(checkpoints, experiments)

    metrics_to_check, directions = _metric_directions(checkpoints, primary_metric, lower_is_better)
    pareto, dominated = compute_pareto_checkpoints(checkpoints, metrics_to_check, directions)
    pareto_paths = {c["path"] for c in pareto}

    if args.action == "list":
        if args.json:
            print(json.dumps(checkpoints, indent=2, default=str))
        else:
            print(format_checkpoint_list(checkpoints, pareto_paths, primary_metric))

    elif args.action == "stats":
        stats = compute_disk_stats(checkpoints)
        if args.json:
            print(json.dumps(stats, indent=2))
        else:
            print(format_stats_report(stats, checkpoints))

    elif args.action == "prune":
        _run_prune(checkpoints, dominated, args.checkpoint_dir, args.dry_run, args.json)

    elif args.action == "average":
        _run_average(checkpoints, primary_metric, lower_is_better, args.top)

    elif args.action == "resume":
        _run_resume(checkpoints, args.exp_id)


if __name__ == "__main__":
    main()
|