claude-turing 2.2.1 → 2.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +8 -2
- package/commands/diff.md +48 -0
- package/commands/ensemble.md +54 -0
- package/commands/regress.md +53 -0
- package/commands/stitch.md +49 -0
- package/commands/turing.md +12 -0
- package/commands/warm.md +53 -0
- package/commands/watch.md +60 -0
- package/config/watch_alerts.yaml +36 -0
- package/package.json +1 -1
- package/src/install.js +3 -0
- package/src/verify.js +7 -0
- package/templates/scripts/__pycache__/build_ensemble.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/experiment_diff.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/pipeline_manager.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/regression_gate.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/training_monitor.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/warm_start.cpython-314.pyc +0 -0
- package/templates/scripts/build_ensemble.py +696 -0
- package/templates/scripts/experiment_diff.py +703 -0
- package/templates/scripts/generate_brief.py +79 -0
- package/templates/scripts/pipeline_manager.py +457 -0
- package/templates/scripts/regression_gate.py +536 -0
- package/templates/scripts/scaffold.py +12 -0
- package/templates/scripts/training_monitor.py +611 -0
- package/templates/scripts/warm_start.py +493 -0
|
@@ -0,0 +1,611 @@
|
|
|
1
|
+
#!/usr/bin/env python3
"""Live training monitor for the autoresearch pipeline.

Parses per-epoch metrics from a training log and evaluates early-warning
alert rules against them: loss spikes relative to a rolling mean, NaN metric
values, a widening train/val loss gap (overfitting onset), and metric
plateaus. Rules are read from config/watch_alerts.yaml when it exists,
otherwise built-in defaults are used. Catches problems early instead of at
the end.

Can tail a run.log file or read completed logs for post-hoc analysis.

Usage:
    python scripts/training_monitor.py                     # Monitor run.log
    python scripts/training_monitor.py --log custom.log    # Custom log file
    python scripts/training_monitor.py --alerts            # Show only alerts
    python scripts/training_monitor.py --interval 5        # Check every 5 seconds
    python scripts/training_monitor.py --analyze run.log   # Post-hoc analysis
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations

import argparse
import json
import math
import re
import sys
import time
from collections import deque
from datetime import datetime, timezone
from pathlib import Path

import yaml

from scripts.turing_io import load_config
|
|
32
|
+
|
|
33
|
+
# Live-monitoring defaults (each can be overridden on the command line).
DEFAULT_LOG_PATH = "run.log"  # log file tailed when --log is not given
DEFAULT_INTERVAL = 10  # seconds to sleep between polls of the log file
DEFAULT_ALERT_CONFIG = "config/watch_alerts.yaml"  # alert rules; built-in defaults if absent

# How many of the most recent epochs feed the rolling statistics.
ROLLING_WINDOW = 10
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# --- Metric Parsing ---
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def parse_epoch_metrics(line: str) -> dict | None:
|
|
43
|
+
"""Parse a single log line into epoch metrics.
|
|
44
|
+
|
|
45
|
+
Supports formats:
|
|
46
|
+
JSON: {"epoch": 1, "loss": 0.5, "val_loss": 0.6, ...}
|
|
47
|
+
KV: epoch=1 loss=0.5 val_loss=0.6
|
|
48
|
+
CSV: epoch,loss,val_loss\\n1,0.5,0.6
|
|
49
|
+
"""
|
|
50
|
+
line = line.strip()
|
|
51
|
+
if not line:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
# JSON format
|
|
55
|
+
if line.startswith("{"):
|
|
56
|
+
try:
|
|
57
|
+
data = json.loads(line)
|
|
58
|
+
if "epoch" in data:
|
|
59
|
+
return data
|
|
60
|
+
except json.JSONDecodeError:
|
|
61
|
+
pass
|
|
62
|
+
|
|
63
|
+
# Key=value format
|
|
64
|
+
if "epoch=" in line or "epoch:" in line:
|
|
65
|
+
metrics = {}
|
|
66
|
+
# Handle both = and : separators
|
|
67
|
+
parts = line.replace(":", "=").split()
|
|
68
|
+
for part in parts:
|
|
69
|
+
if "=" in part:
|
|
70
|
+
key, _, val = part.partition("=")
|
|
71
|
+
key = key.strip()
|
|
72
|
+
val = val.strip()
|
|
73
|
+
try:
|
|
74
|
+
metrics[key] = int(val) if key == "epoch" else float(val)
|
|
75
|
+
except ValueError:
|
|
76
|
+
metrics[key] = val
|
|
77
|
+
if "epoch" in metrics:
|
|
78
|
+
return metrics
|
|
79
|
+
|
|
80
|
+
return None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def parse_log_file(log_path: str) -> list[dict]:
    """Parse all epoch metrics from a log file.

    Lines without epoch metrics are skipped. The file is decoded as UTF-8
    with undecodable bytes replaced, so progress-bar glyphs or stray binary
    output in a log cannot abort the parse regardless of the platform locale.

    Args:
        log_path: Path to the training log.

    Returns:
        Parsed metric dicts in file order; empty if the path is not a file.
    """
    path = Path(log_path)
    if not path.is_file():
        return []

    with open(path, encoding="utf-8", errors="replace") as f:
        return [parsed for parsed in map(parse_epoch_metrics, f) if parsed is not None]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# --- Rolling Statistics ---
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def compute_rolling_stats(
    history: list[dict],
    metric: str,
    window: int = ROLLING_WINDOW,
) -> dict:
    """Compute rolling statistics for a metric over the last ``window`` epochs.

    Only finite real numbers are used. Missing values, NaN/inf, bools and
    non-numeric strings (which KV-format log parsing can produce) are
    skipped, so one bad epoch can neither poison the mean nor raise TypeError.

    Args:
        history: Per-epoch metric dicts, oldest first.
        metric: Key to read from each dict.
        window: Number of most recent epochs to consider.

    Returns:
        Dict with ``mean``, ``std`` (sample standard deviation, 0.0 for a
        single value), ``trend`` (least-squares slope per retained value),
        ``min``, ``max`` and ``n``. Empty dict when the window has no usable
        values.
    """
    values = []
    for entry in history[-window:]:
        value = entry.get(metric)
        # bool is an int subclass but is not a metric value.
        if isinstance(value, (int, float)) and not isinstance(value, bool) and math.isfinite(value):
            values.append(value)
    if not values:
        return {}

    n = len(values)
    mean = sum(values) / n

    if n >= 2:
        variance = sum((v - mean) ** 2 for v in values) / (n - 1)
        std = math.sqrt(variance)

        # Simple linear trend (slope), x = position within the window.
        x_mean = (n - 1) / 2
        numerator = sum((i - x_mean) * (v - mean) for i, v in enumerate(values))
        denominator = sum((i - x_mean) ** 2 for i in range(n))
        trend = numerator / denominator if denominator > 0 else 0.0
    else:
        std = 0.0
        trend = 0.0

    return {
        "mean": mean,
        "std": std,
        "trend": trend,
        "min": min(values),
        "max": max(values),
        "n": n,
    }
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# --- Alert Rules ---
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def load_alert_config(config_path: str = DEFAULT_ALERT_CONFIG) -> dict:
    """Load alert configuration from YAML.

    The file is read explicitly as UTF-8. Alert messages may contain
    non-ASCII characters (the built-in defaults use an em dash), and the
    platform default encoding would garble them or fail to decode.

    Args:
        config_path: Path to the alert-rules YAML file.

    Returns:
        The parsed mapping. Returns ``default_alert_config()`` when the path
        is not a file, the file is empty, or its top level is not a mapping.

    Raises:
        yaml.YAMLError: If the file exists but is not valid YAML.
    """
    path = Path(config_path)
    if not path.is_file():
        return default_alert_config()

    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data if isinstance(data, dict) else default_alert_config()
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def default_alert_config() -> dict:
    """Build the built-in alert rules used when no alert config file is available.

    A fresh dict is built on every call, so callers may mutate the result.

    Returns:
        ``{"alerts": {rule_name: rule}}`` with four rules: loss_spike,
        nan_detected, overfitting_onset and plateau.
    """
    loss_spike = dict(
        condition="loss_spike",
        multiplier=3.0,
        severity="warning",
        message="Loss spike at epoch {epoch}: {value} vs rolling mean {mean:.4f}",
    )
    nan_detected = dict(
        condition="nan_detected",
        severity="critical",
        action="pause",
        message="NaN detected in {metric} at epoch {epoch}",
    )
    overfitting_onset = dict(
        condition="overfitting",
        gap_ratio=0.5,
        consecutive=3,
        severity="warning",
        message="Overfitting detected — train/val gap widening since epoch {onset}",
    )
    plateau = dict(
        condition="plateau",
        min_improvement=0.001,
        consecutive=5,
        severity="info",
        message="Metric plateaued — consider early stopping or LR reduction",
    )
    rules = dict(
        loss_spike=loss_spike,
        nan_detected=nan_detected,
        overfitting_onset=overfitting_onset,
        plateau=plateau,
    )
    return dict(alerts=rules)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def evaluate_alerts(
    current: dict,
    history: list[dict],
    alert_config: dict,
) -> list[dict]:
    """Evaluate all alert rules against the current epoch.

    Rules are taken from ``alert_config["alerts"]``. A rule's ``condition``
    (defaulting to the rule's name) selects the check, and unknown conditions
    are ignored. A rule whose value is not a mapping (e.g. ``plateau: null``
    or ``plateau: false`` in YAML) is skipped, which lets a config file
    disable it. A missing, empty, or non-mapping ``alerts`` section yields no
    alerts instead of raising.

    Args:
        current: Current epoch metrics.
        history: All previous epoch metrics, oldest first.
        alert_config: Alert configuration dict.

    Returns:
        List of triggered alert dicts. Each is annotated with the rule's
        ``name``, ``severity`` (default "info") and ``action`` (default None).
    """
    alerts_config = alert_config.get("alerts") or {}
    if not isinstance(alerts_config, dict):
        return []

    triggered = []
    for name, rule in alerts_config.items():
        # A null/false (or otherwise non-mapping) rule means "disabled".
        if not isinstance(rule, dict):
            continue

        condition = rule.get("condition", name)
        if condition == "loss_spike":
            alert = _check_loss_spike(current, history, rule)
        elif condition == "nan_detected":
            alert = _check_nan(current, rule)
        elif condition == "overfitting":
            alert = _check_overfitting(current, history, rule)
        elif condition == "plateau":
            alert = _check_plateau(history, rule)
        else:
            continue  # unknown condition: ignore rather than fail

        if alert:
            alert["name"] = name
            alert["severity"] = rule.get("severity", "info")
            alert["action"] = rule.get("action")
            triggered.append(alert)

    return triggered
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _check_loss_spike(current: dict, history: list[dict], rule: dict) -> dict | None:
|
|
234
|
+
"""Check for sudden loss spikes."""
|
|
235
|
+
loss = current.get("loss") or current.get("train_loss")
|
|
236
|
+
if loss is None:
|
|
237
|
+
return None
|
|
238
|
+
|
|
239
|
+
rolling = compute_rolling_stats(history, "loss")
|
|
240
|
+
if not rolling or rolling["n"] < 3:
|
|
241
|
+
# Also try train_loss
|
|
242
|
+
rolling = compute_rolling_stats(history, "train_loss")
|
|
243
|
+
if not rolling or rolling["n"] < 3:
|
|
244
|
+
return None
|
|
245
|
+
|
|
246
|
+
multiplier = rule.get("multiplier", 3.0)
|
|
247
|
+
mean = rolling["mean"]
|
|
248
|
+
|
|
249
|
+
if mean > 0 and loss > multiplier * mean:
|
|
250
|
+
msg = rule.get("message", "Loss spike detected").format(
|
|
251
|
+
epoch=current.get("epoch", "?"),
|
|
252
|
+
value=loss,
|
|
253
|
+
mean=mean,
|
|
254
|
+
)
|
|
255
|
+
return {"message": msg, "epoch": current.get("epoch"), "value": loss, "mean": mean}
|
|
256
|
+
|
|
257
|
+
return None
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _check_nan(current: dict, rule: dict) -> dict | None:
|
|
261
|
+
"""Check for NaN values in any metric."""
|
|
262
|
+
for key, val in current.items():
|
|
263
|
+
if key == "epoch":
|
|
264
|
+
continue
|
|
265
|
+
if isinstance(val, float) and math.isnan(val):
|
|
266
|
+
msg = rule.get("message", "NaN detected").format(
|
|
267
|
+
metric=key,
|
|
268
|
+
epoch=current.get("epoch", "?"),
|
|
269
|
+
)
|
|
270
|
+
return {"message": msg, "epoch": current.get("epoch"), "metric": key}
|
|
271
|
+
return None
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _check_overfitting(current: dict, history: list[dict], rule: dict) -> dict | None:
|
|
275
|
+
"""Check for train/val gap widening."""
|
|
276
|
+
gap_ratio = rule.get("gap_ratio", 0.5)
|
|
277
|
+
consecutive_required = rule.get("consecutive", 3)
|
|
278
|
+
|
|
279
|
+
# Compute train/val gap over recent history
|
|
280
|
+
gaps = []
|
|
281
|
+
for entry in history + [current]:
|
|
282
|
+
train_loss = entry.get("loss") or entry.get("train_loss")
|
|
283
|
+
val_loss = entry.get("val_loss")
|
|
284
|
+
if train_loss is not None and val_loss is not None:
|
|
285
|
+
gaps.append({
|
|
286
|
+
"epoch": entry.get("epoch"),
|
|
287
|
+
"gap": val_loss - train_loss,
|
|
288
|
+
"ratio": train_loss / val_loss if val_loss != 0 else 0,
|
|
289
|
+
})
|
|
290
|
+
|
|
291
|
+
if len(gaps) < consecutive_required + 1:
|
|
292
|
+
return None
|
|
293
|
+
|
|
294
|
+
# Check if gap is widening for N consecutive epochs
|
|
295
|
+
recent = gaps[-consecutive_required:]
|
|
296
|
+
widening = all(
|
|
297
|
+
recent[i]["gap"] > recent[i - 1]["gap"]
|
|
298
|
+
for i in range(1, len(recent))
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
if widening and recent[-1]["ratio"] < gap_ratio:
|
|
302
|
+
onset = recent[0]["epoch"]
|
|
303
|
+
msg = rule.get("message", "Overfitting detected").format(
|
|
304
|
+
onset=onset,
|
|
305
|
+
)
|
|
306
|
+
return {"message": msg, "onset": onset, "current_gap": recent[-1]["gap"]}
|
|
307
|
+
|
|
308
|
+
return None
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _check_plateau(history: list[dict], rule: dict) -> dict | None:
|
|
312
|
+
"""Check for metric plateau."""
|
|
313
|
+
min_improvement = rule.get("min_improvement", 0.001)
|
|
314
|
+
consecutive_required = rule.get("consecutive", 5)
|
|
315
|
+
|
|
316
|
+
if len(history) < consecutive_required:
|
|
317
|
+
return None
|
|
318
|
+
|
|
319
|
+
# Check val_loss or accuracy for plateau
|
|
320
|
+
for metric in ("val_loss", "val_accuracy", "accuracy", "loss"):
|
|
321
|
+
values = [h.get(metric) for h in history[-consecutive_required:] if h.get(metric) is not None]
|
|
322
|
+
if len(values) < consecutive_required:
|
|
323
|
+
continue
|
|
324
|
+
|
|
325
|
+
improvements = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
|
|
326
|
+
if all(imp < min_improvement for imp in improvements):
|
|
327
|
+
msg = rule.get("message", "Metric plateaued")
|
|
328
|
+
return {"message": msg, "metric": metric, "n_flat_epochs": len(values)}
|
|
329
|
+
|
|
330
|
+
return None
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# --- Dashboard Formatting ---
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def format_dashboard_line(current: dict, rolling_loss: dict, alerts: list[dict]) -> str:
    """Format a compact single-line dashboard for one epoch.

    Example:
        Epoch 23/100 | loss: 0.3420 ↓ | accuracy: 0.8650 | val_loss: 0.3600 | gap: 0.0180 | info: plateau

    Field rules:
    - The arrow after the loss follows the sign of ``rolling_loss["trend"]``:
      ↓ for negative, ↑ for positive, nothing for zero or missing.
    - A NaN loss is shown as "loss: NaN".
    - Other metrics are left out when missing, NaN or non-numeric, so an
      unparsed value such as "N/A" in the log cannot crash the live monitor.

    Args:
        current: Metrics for the epoch being displayed.
        rolling_loss: Rolling loss statistics as returned by
            ``compute_rolling_stats``; may be empty.
        alerts: Alerts triggered for this epoch.

    Returns:
        The dashboard line with fields joined by " | ".
    """
    def _number(value):
        # bool is an int subclass but is not a metric value.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        return None

    epoch = current.get("epoch", "?")
    total_epochs = current.get("total_epochs") or current.get("n_epochs", "?")
    parts = [f"Epoch {epoch}/{total_epochs}"]

    # Training loss (a legitimate 0.0 is kept, not replaced by train_loss).
    loss = _number(current.get("loss"))
    if loss is None:
        loss = _number(current.get("train_loss"))
    if loss is not None:
        if math.isnan(loss):
            parts.append("loss: NaN")
        else:
            trend = (rolling_loss or {}).get("trend")
            arrow = ""
            if trend is not None:
                arrow = " ↓" if trend < 0 else " ↑" if trend > 0 else ""
            parts.append(f"loss: {loss:.4f}{arrow}")

    # Accuracy/val metrics
    for metric in ("accuracy", "val_accuracy", "val_loss"):
        val = _number(current.get(metric))
        if val is not None and not math.isnan(val):
            parts.append(f"{metric}: {val:.4f}")

    # Train/val gap
    val_loss = _number(current.get("val_loss"))
    if loss is not None and val_loss is not None and not (math.isnan(loss) or math.isnan(val_loss)):
        parts.append(f"gap: {val_loss - loss:.4f}")

    # Alert indicators
    labels = {"critical": "CRITICAL", "warning": "WARNING"}
    for alert in alerts:
        label = labels.get(alert.get("severity", "info"), "info")
        parts.append(f"{label}: {alert.get('name', 'alert')}")

    return " | ".join(parts)
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
# --- Analysis Report ---
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def analyze_training_log(
    log_path: str,
    alert_config_path: str = DEFAULT_ALERT_CONFIG,
    config_path: str = "config.yaml",
) -> dict:
    """Analyze a completed training log for issues.

    Replays the log epoch by epoch through the alert rules, just as live
    monitoring would have seen it, and summarizes the loss curves.

    Args:
        log_path: Path to the training log.
        alert_config_path: Path to the alert-rules YAML.
        config_path: Accepted for CLI compatibility; not currently used.

    Returns:
        A report dict with ``log_path``, ``analyzed_at`` (UTC ISO timestamp),
        ``total_epochs``, ``alerts``, ``alert_summary`` (counts per severity)
        and ``training_stats``. Returns ``{"error": ..., "log_path": ...}``
        when no epoch metrics could be parsed.
    """
    metrics = parse_log_file(log_path)
    if not metrics:
        return {"error": f"No metrics found in {log_path}", "log_path": log_path}

    alert_config = load_alert_config(alert_config_path)

    # Grow one history list instead of slicing metrics[:i] for every epoch,
    # which copied the whole prefix each step (quadratic in log length).
    history: list[dict] = []
    all_alerts = []
    for current in metrics:
        all_alerts.extend(evaluate_alerts(current, history, alert_config))
        history.append(current)

    def _usable(value) -> bool:
        # Real numbers only (bool excluded). NaN is dropped; inf is kept so a
        # diverged run still shows in the stats. Strings from KV-format logs
        # would otherwise raise TypeError in math.isnan().
        return (
            isinstance(value, (int, float))
            and not isinstance(value, bool)
            and not math.isnan(value)
        )

    loss_values = []
    val_loss_values = []
    final_gap = None
    for m in metrics:
        train_loss = m.get("loss")
        if not _usable(train_loss):
            train_loss = m.get("train_loss")
        val_loss = m.get("val_loss")
        has_train = _usable(train_loss)
        has_val = _usable(val_loss)
        if has_train:
            loss_values.append(train_loss)
        if has_val:
            val_loss_values.append(val_loss)
        if has_train and has_val:
            # Take both losses from the same epoch, even when val_loss is
            # logged less often than the training loss.
            final_gap = val_loss - train_loss

    severities = [a.get("severity") for a in all_alerts]
    report = {
        "log_path": log_path,
        "analyzed_at": datetime.now(timezone.utc).isoformat(),
        "total_epochs": len(metrics),
        "alerts": all_alerts,
        "alert_summary": {
            "total": len(all_alerts),
            "critical": severities.count("critical"),
            "warning": severities.count("warning"),
            "info": severities.count("info"),
        },
        "training_stats": {},
    }

    stats = report["training_stats"]
    if loss_values:
        stats["final_loss"] = loss_values[-1]
        stats["min_loss"] = min(loss_values)
        stats["loss_reduction"] = loss_values[0] - loss_values[-1] if len(loss_values) > 1 else 0

    if val_loss_values:
        stats["final_val_loss"] = val_loss_values[-1]
        stats["min_val_loss"] = min(val_loss_values)

    if final_gap is not None:
        stats["final_gap"] = final_gap

    return report
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def format_analysis_report(report: dict) -> str:
    """Render an analysis report as markdown.

    An error report collapses to a single "ERROR: ..." line. Otherwise the
    output contains:
    - a summary with the epoch count and whichever loss statistics are present;
    - per-severity alert counts;
    - one bullet per alert, or a "no issues" note when there are none.

    Args:
        report: Report dict as produced by ``analyze_training_log``; every
            key is optional.

    Returns:
        Markdown lines joined with newlines (no trailing newline).
    """
    if "error" in report:
        return f"ERROR: {report['error']}"

    # Keep only the leading YYYY-MM-DDTHH:MM:SS of the ISO timestamp.
    analyzed_at = report.get("analyzed_at", "N/A")[:19]
    out = [
        "# Training Log Analysis",
        "",
        f"*Analyzed {analyzed_at}*",
        f"*Log: {report.get('log_path', 'N/A')}*",
        "",
        "## Summary",
        "",
        f"- **Total epochs:** {report.get('total_epochs', 0)}",
    ]

    stats = report.get("training_stats", {})
    final_loss = stats.get("final_loss")
    if final_loss is not None:
        out.append(f"- **Final loss:** {final_loss:.4f} (min: {stats.get('min_loss', 0):.4f})")
    final_val_loss = stats.get("final_val_loss")
    if final_val_loss is not None:
        out.append(f"- **Final val_loss:** {final_val_loss:.4f}")
    final_gap = stats.get("final_gap")
    if final_gap is not None:
        out.append(f"- **Train/val gap:** {final_gap:.4f}")

    counts = report.get("alert_summary", {})
    out += ["", "## Alerts", ""]
    out += [
        f"- **{label}:** {counts.get(key, 0)}"
        for key, label in (("critical", "Critical"), ("warning", "Warning"), ("info", "Info"))
    ]

    alerts = report.get("alerts", [])
    if not alerts:
        out += ["", "No issues detected during training."]
    else:
        out += ["", "### Details", ""]
        out += [
            f"- **[{alert.get('severity', 'info').upper()}]** {alert.get('message', 'N/A')}"
            for alert in alerts
        ]

    return "\n".join(out)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def save_analysis_report(report: dict, output_dir: str = "experiments/monitors") -> Path:
    """Save an analysis report as YAML under ``output_dir``.

    The file is named ``analysis-YYYYMMDD-HHMMSS.yaml`` from the current UTC
    time. It is created exclusively, so a second report saved within the same
    second gets a numeric suffix (``analysis-YYYYMMDD-HHMMSS-1.yaml``, ...)
    instead of silently overwriting the first.

    Args:
        report: Report dict to serialize; key order is preserved.
        output_dir: Destination directory, created with parents if missing.

    Returns:
        Path of the written file.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    stem = f"analysis-{timestamp}"
    suffix = 0
    while True:
        name = f"{stem}.yaml" if suffix == 0 else f"{stem}-{suffix}.yaml"
        filepath = out_path / name
        try:
            f = open(filepath, "x", encoding="utf-8")
        except FileExistsError:
            suffix += 1
            continue
        with f:
            yaml.dump(report, f, default_flow_style=False, sort_keys=False)
        return filepath
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training monitor."""
    parser = argparse.ArgumentParser(
        description="Live training monitor with early-warning alerts",
    )
    parser.add_argument(
        "--log", default=DEFAULT_LOG_PATH,
        help=f"Path to training log file (default: {DEFAULT_LOG_PATH})",
    )
    parser.add_argument(
        "--alerts-config", default=DEFAULT_ALERT_CONFIG,
        help=f"Path to alert config YAML (default: {DEFAULT_ALERT_CONFIG})",
    )
    parser.add_argument(
        "--config", default="config.yaml",
        help="Path to config.yaml",
    )
    parser.add_argument(
        "--interval", type=int, default=DEFAULT_INTERVAL,
        help=f"Check interval in seconds (default: {DEFAULT_INTERVAL})",
    )
    parser.add_argument(
        "--alerts", action="store_true",
        help="Show only alerts, suppress normal output",
    )
    parser.add_argument(
        "--analyze", metavar="LOG_FILE",
        help="Post-hoc analysis of a completed training log",
    )
    parser.add_argument(
        "--json", action="store_true",
        help="Output raw JSON instead of formatted report",
    )
    return parser


def _run_analysis(args: argparse.Namespace) -> None:
    """Analyze ``args.analyze``, save and print the report; exit 1 on critical alerts."""
    report = analyze_training_log(
        args.analyze,
        alert_config_path=args.alerts_config,
        config_path=args.config,
    )

    if "error" not in report:
        filepath = save_analysis_report(report)
        print(f"Saved to {filepath}", file=sys.stderr)

    if args.json:
        print(json.dumps(report, indent=2, default=str))
    else:
        print(format_analysis_report(report))

    if report.get("alert_summary", {}).get("critical", 0) > 0:
        sys.exit(1)


def _read_new_lines(log_path: Path, offset: int) -> tuple[list[str], int, bool]:
    """Read complete lines appended to ``log_path`` after byte ``offset``.

    A trailing fragment without a newline is left for the next poll, so a
    line that is still being written is never parsed half-finished and then
    skipped. If the file is now shorter than ``offset`` (truncated or
    replaced by a new run), reading restarts from the beginning.

    Returns:
        ``(lines, new_offset, restarted)``.

    Raises:
        FileNotFoundError: If the log file does not exist.
    """
    restarted = log_path.stat().st_size < offset
    if restarted:
        offset = 0
    with open(log_path, "rb") as f:
        f.seek(offset)
        chunk = f.read()
    end = chunk.rfind(b"\n")
    if end < 0:
        return [], offset, restarted
    # Cutting at a newline never splits a multi-byte UTF-8 character.
    text = chunk[: end + 1].decode("utf-8", errors="replace")
    return text.splitlines(), offset + end + 1, restarted


def _monitor(args: argparse.Namespace) -> None:
    """Poll ``args.log`` and print a dashboard line and alerts for each new epoch."""
    log_path = Path(args.log)
    alert_config = load_alert_config(args.alerts_config)

    print(f"Monitoring {log_path} (interval: {args.interval}s)", file=sys.stderr)
    print("Press Ctrl+C to stop.", file=sys.stderr)
    print(file=sys.stderr)

    history: list[dict] = []
    offset = 0

    try:
        while True:
            try:
                new_lines, offset, restarted = _read_new_lines(log_path, offset)
            except FileNotFoundError:
                # Not created yet, or briefly missing while being replaced.
                time.sleep(args.interval)
                continue

            if restarted:
                history.clear()
                print("Log file was truncated or replaced; restarting from the beginning.", file=sys.stderr)

            for line in new_lines:
                parsed = parse_epoch_metrics(line)
                if parsed is None:
                    continue

                alerts = evaluate_alerts(parsed, history, alert_config)
                rolling = compute_rolling_stats(history, "loss")

                if not args.alerts or alerts:
                    # Flush so the dashboard stays live when stdout is piped.
                    print(format_dashboard_line(parsed, rolling, alerts), flush=True)

                # The monitor only reports "pause" alerts; it does not stop
                # the training process itself.
                for alert in alerts:
                    if alert.get("action") == "pause":
                        print(f"\nCRITICAL: {alert['message']}", file=sys.stderr)
                        print("Training should be paused.", file=sys.stderr)

                history.append(parsed)

            time.sleep(args.interval)
    except KeyboardInterrupt:
        print(f"\nMonitoring stopped. {len(history)} epochs observed.", file=sys.stderr)


def main() -> None:
    """CLI entry point: post-hoc analysis with --analyze, otherwise live monitoring."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # time.sleep() raises ValueError on a negative interval.
    if args.interval < 0:
        parser.error("--interval must be non-negative")

    if args.analyze:
        _run_analysis(args)
        return

    _monitor(args)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|