claude-turing 2.2.1 → 2.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,611 @@
1
+ #!/usr/bin/env python3
2
+ """Live training monitor for the autoresearch pipeline.
3
+
4
+ Streams metrics during a training run with early-warning alerts:
5
+ loss spikes, NaN values, overfitting onset (widening train/val gap),
6
+ and metric plateaus. Catches problems early instead of at the end.
7
+
8
+ Can tail a run.log file or read completed logs for post-hoc analysis.
9
+
10
+ Usage:
11
+ python scripts/training_monitor.py # Monitor run.log
12
+ python scripts/training_monitor.py --log custom.log # Custom log file
13
+ python scripts/training_monitor.py --alerts # Show only alerts
14
+ python scripts/training_monitor.py --interval 5 # Check every 5 seconds
15
+ python scripts/training_monitor.py --analyze run.log # Post-hoc analysis
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import json
22
+ import math
23
+ import sys
24
+ import time
26
+ from datetime import datetime, timezone
27
+ from pathlib import Path
28
+
29
+ import yaml
30
+
33
+ DEFAULT_LOG_PATH = "run.log"
34
+ DEFAULT_ALERT_CONFIG = "config/watch_alerts.yaml"
35
+ DEFAULT_INTERVAL = 10 # seconds
36
+ ROLLING_WINDOW = 10 # epochs for rolling statistics
37
+
38
+
39
+ # --- Metric Parsing ---
40
+
41
+
42
+ def parse_epoch_metrics(line: str) -> dict | None:
43
+ """Parse a single log line into epoch metrics.
44
+
45
+ Supports formats:
46
+ JSON: {"epoch": 1, "loss": 0.5, "val_loss": 0.6, ...}
47
+ KV: epoch=1 loss=0.5 val_loss=0.6
48
+ Note: CSV rows are not parsed; unrecognized lines are skipped.
49
+ """
50
+ line = line.strip()
51
+ if not line:
52
+ return None
53
+
54
+ # JSON format
55
+ if line.startswith("{"):
56
+ try:
57
+ data = json.loads(line)
58
+ if "epoch" in data:
59
+ return data
60
+ except json.JSONDecodeError:
61
+ pass
62
+
63
+ # Key=value format
64
+ if "epoch=" in line or "epoch:" in line:
65
+ metrics = {}
66
+ # Handle both = and : separators
67
+ parts = line.replace(":", "=").split()
68
+ for part in parts:
69
+ if "=" in part:
70
+ key, _, val = part.partition("=")
71
+ key = key.strip()
72
+ val = val.strip()
73
+ try:
74
+ metrics[key] = int(val) if key == "epoch" else float(val)
75
+ except ValueError:
76
+ metrics[key] = val
77
+ if "epoch" in metrics:
78
+ return metrics
79
+
80
+ return None
81
+
82
+
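For reference, a minimal sketch of how parse_epoch_metrics handles the two supported line formats (the metric values are made up; the import path is assumed from the usage docstring):

    from scripts.training_monitor import parse_epoch_metrics  # assumed module path

    # Both calls return the same dict; any other line shape yields None.
    print(parse_epoch_metrics('{"epoch": 3, "loss": 0.41, "val_loss": 0.47}'))
    print(parse_epoch_metrics("epoch=3 loss=0.41 val_loss=0.47"))
    # -> {'epoch': 3, 'loss': 0.41, 'val_loss': 0.47}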
83
+ def parse_log_file(log_path: str) -> list[dict]:
84
+ """Parse all epoch metrics from a log file."""
85
+ path = Path(log_path)
86
+ if not path.exists():
87
+ return []
88
+
89
+ metrics = []
90
+ with open(path) as f:
91
+ for line in f:
92
+ parsed = parse_epoch_metrics(line)
93
+ if parsed is not None:
94
+ metrics.append(parsed)
95
+ return metrics
96
+
97
+
98
+ # --- Rolling Statistics ---
99
+
100
+
101
+ def compute_rolling_stats(
102
+ history: list[dict],
103
+ metric: str,
104
+ window: int = ROLLING_WINDOW,
105
+ ) -> dict:
106
+ """Compute rolling statistics for a metric.
107
+
108
+ Returns:
109
+ Dict with mean, std, trend (slope), min, max over the window.
110
+ """
111
+ values = [
112
+ h.get(metric) for h in history[-window:]
113
+ if h.get(metric) is not None and not (isinstance(h.get(metric), float) and math.isnan(h.get(metric)))
114
+ ]
115
+ if not values:
116
+ return {}
117
+
118
+ n = len(values)
119
+ mean = sum(values) / n
120
+
121
+ if n >= 2:
122
+ variance = sum((v - mean) ** 2 for v in values) / (n - 1)
123
+ std = math.sqrt(variance)
124
+
125
+ # Simple linear trend (slope)
126
+ x_mean = (n - 1) / 2
127
+ numerator = sum((i - x_mean) * (v - mean) for i, v in enumerate(values))
128
+ denominator = sum((i - x_mean) ** 2 for i in range(n))
129
+ trend = numerator / denominator if denominator > 0 else 0.0
130
+ else:
131
+ std = 0.0
132
+ trend = 0.0
133
+
134
+ return {
135
+ "mean": mean,
136
+ "std": std,
137
+ "trend": trend,
138
+ "min": min(values),
139
+ "max": max(values),
140
+ "n": n,
141
+ }
142
+
143
+
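As a quick sanity check on the rolling statistics (values chosen by hand; floating-point output may differ slightly in the last digits):

    from scripts.training_monitor import compute_rolling_stats  # assumed module path

    stats = compute_rolling_stats([{"loss": 1.0}, {"loss": 0.9}, {"loss": 0.8}], "loss")
    # mean ~0.9, sample std ~0.1, least-squares slope ~-0.1 per epoch, n = 3
    print(stats["mean"], stats["std"], stats["trend"], stats["n"])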
144
+ # --- Alert Rules ---
145
+
146
+
147
+ def load_alert_config(config_path: str = DEFAULT_ALERT_CONFIG) -> dict:
148
+ """Load alert configuration from YAML."""
149
+ path = Path(config_path)
150
+ if not path.exists():
151
+ return default_alert_config()
152
+
153
+ with open(path) as f:
154
+ data = yaml.safe_load(f)
155
+ return data if isinstance(data, dict) else default_alert_config()
156
+
157
+
158
+ def default_alert_config() -> dict:
159
+ """Return default alert rules."""
160
+ return {
161
+ "alerts": {
162
+ "loss_spike": {
163
+ "condition": "loss_spike",
164
+ "multiplier": 3.0,
165
+ "severity": "warning",
166
+ "message": "Loss spike at epoch {epoch}: {value} vs rolling mean {mean:.4f}",
167
+ },
168
+ "nan_detected": {
169
+ "condition": "nan_detected",
170
+ "severity": "critical",
171
+ "action": "pause",
172
+ "message": "NaN detected in {metric} at epoch {epoch}",
173
+ },
174
+ "overfitting_onset": {
175
+ "condition": "overfitting",
176
+ "gap_ratio": 0.5,
177
+ "consecutive": 3,
178
+ "severity": "warning",
179
+ "message": "Overfitting detected — train/val gap widening since epoch {onset}",
180
+ },
181
+ "plateau": {
182
+ "condition": "plateau",
183
+ "min_improvement": 0.001,
184
+ "consecutive": 5,
185
+ "severity": "info",
186
+ "message": "Metric plateaued — consider early stopping or LR reduction",
187
+ },
188
+ },
189
+ }
190
+
191
+
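For illustration, a config/watch_alerts.yaml that mirrors these defaults might look like the following (thresholds are examples, not recommendations):

    alerts:
      loss_spike:
        condition: loss_spike
        multiplier: 3.0
        severity: warning
        message: "Loss spike at epoch {epoch}: {value} vs rolling mean {mean:.4f}"
      nan_detected:
        condition: nan_detected
        severity: critical
        action: pause
        message: "NaN detected in {metric} at epoch {epoch}"
      plateau:
        condition: plateau
        min_improvement: 0.001
        consecutive: 5
        severity: info
        message: "Metric plateaued - consider early stopping or LR reduction"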
192
+ def evaluate_alerts(
193
+ current: dict,
194
+ history: list[dict],
195
+ alert_config: dict,
196
+ ) -> list[dict]:
197
+ """Evaluate all alert rules against current state.
198
+
199
+ Args:
200
+ current: Current epoch metrics.
201
+ history: All previous epoch metrics.
202
+ alert_config: Alert configuration dict.
203
+
204
+ Returns:
205
+ List of triggered alert dicts.
206
+ """
207
+ alerts_config = alert_config.get("alerts", {})
208
+ triggered = []
209
+
210
+ for name, rule in alerts_config.items():
211
+ condition = rule.get("condition", name)
212
+
213
+ if condition == "loss_spike":
214
+ alert = _check_loss_spike(current, history, rule)
215
+ elif condition == "nan_detected":
216
+ alert = _check_nan(current, rule)
217
+ elif condition == "overfitting":
218
+ alert = _check_overfitting(current, history, rule)
219
+ elif condition == "plateau":
220
+ alert = _check_plateau(history, rule)
221
+ else:
222
+ continue
223
+
224
+ if alert:
225
+ alert["name"] = name
226
+ alert["severity"] = rule.get("severity", "info")
227
+ alert["action"] = rule.get("action")
228
+ triggered.append(alert)
229
+
230
+ return triggered
231
+
232
+
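A small, self-contained illustration of the rule evaluation (history values are fabricated): a steadily falling loss followed by a sudden jump should trip the loss_spike rule and nothing else:

    from scripts.training_monitor import evaluate_alerts, default_alert_config  # assumed module path

    history = [{"epoch": e, "loss": 1.0 - 0.05 * e} for e in range(1, 11)]
    current = {"epoch": 11, "loss": 3.0}
    for alert in evaluate_alerts(current, history, default_alert_config()):
        print(alert["name"], alert["severity"], "-", alert["message"])
    # expect a single 'loss_spike' warning referencing epoch 11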
233
+ def _check_loss_spike(current: dict, history: list[dict], rule: dict) -> dict | None:
234
+ """Check for sudden loss spikes."""
235
+ loss = current.get("loss") or current.get("train_loss")
236
+ if loss is None:
237
+ return None
238
+
239
+ rolling = compute_rolling_stats(history, "loss")
240
+ if not rolling or rolling["n"] < 3:
241
+ # Also try train_loss
242
+ rolling = compute_rolling_stats(history, "train_loss")
243
+ if not rolling or rolling["n"] < 3:
244
+ return None
245
+
246
+ multiplier = rule.get("multiplier", 3.0)
247
+ mean = rolling["mean"]
248
+
249
+ if mean > 0 and loss > multiplier * mean:
250
+ msg = rule.get("message", "Loss spike detected").format(
251
+ epoch=current.get("epoch", "?"),
252
+ value=loss,
253
+ mean=mean,
254
+ )
255
+ return {"message": msg, "epoch": current.get("epoch"), "value": loss, "mean": mean}
256
+
257
+ return None
258
+
259
+
260
+ def _check_nan(current: dict, rule: dict) -> dict | None:
261
+ """Check for NaN values in any metric."""
262
+ for key, val in current.items():
263
+ if key == "epoch":
264
+ continue
265
+ if isinstance(val, float) and math.isnan(val):
266
+ msg = rule.get("message", "NaN detected").format(
267
+ metric=key,
268
+ epoch=current.get("epoch", "?"),
269
+ )
270
+ return {"message": msg, "epoch": current.get("epoch"), "metric": key}
271
+ return None
272
+
273
+
274
+ def _check_overfitting(current: dict, history: list[dict], rule: dict) -> dict | None:
275
+ """Check for train/val gap widening."""
276
+ gap_ratio = rule.get("gap_ratio", 0.5)
277
+ consecutive_required = rule.get("consecutive", 3)
278
+
279
+ # Compute train/val gap over recent history
280
+ gaps = []
281
+ for entry in history + [current]:
282
+ train_loss = entry.get("loss") or entry.get("train_loss")
283
+ val_loss = entry.get("val_loss")
284
+ if train_loss is not None and val_loss is not None:
285
+ gaps.append({
286
+ "epoch": entry.get("epoch"),
287
+ "gap": val_loss - train_loss,
288
+ "ratio": train_loss / val_loss if val_loss != 0 else 0,
289
+ })
290
+
291
+ if len(gaps) < consecutive_required + 1:
292
+ return None
293
+
294
+ # Check if gap is widening for N consecutive epochs
295
+ recent = gaps[-consecutive_required:]
296
+ widening = all(
297
+ recent[i]["gap"] > recent[i - 1]["gap"]
298
+ for i in range(1, len(recent))
299
+ )
300
+
301
+ if widening and recent[-1]["ratio"] < gap_ratio:
302
+ onset = recent[0]["epoch"]
303
+ msg = rule.get("message", "Overfitting detected").format(
304
+ onset=onset,
305
+ )
306
+ return {"message": msg, "onset": onset, "current_gap": recent[-1]["gap"]}
307
+
308
+ return None
309
+
310
+
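To make the trigger condition concrete, here is a fabricated trajectory in which the gap widens for three consecutive entries while train loss falls well below half of val loss:

    from scripts.training_monitor import _check_overfitting  # assumed module path

    history = [
        {"epoch": 1, "loss": 0.60, "val_loss": 0.65},
        {"epoch": 2, "loss": 0.40, "val_loss": 0.90},
        {"epoch": 3, "loss": 0.30, "val_loss": 1.00},
    ]
    current = {"epoch": 4, "loss": 0.20, "val_loss": 1.20}
    rule = {"gap_ratio": 0.5, "consecutive": 3}
    print(_check_overfitting(current, history, rule))
    # returns an alert dict with onset epoch 2 and the latest gap (1.0)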
311
+ def _check_plateau(history: list[dict], rule: dict) -> dict | None:
312
+ """Check for metric plateau."""
313
+ min_improvement = rule.get("min_improvement", 0.001)
314
+ consecutive_required = rule.get("consecutive", 5)
315
+
316
+ if len(history) < consecutive_required:
317
+ return None
318
+
319
+ # Check val_loss or accuracy for plateau
320
+ for metric in ("val_loss", "val_accuracy", "accuracy", "loss"):
321
+ values = [h.get(metric) for h in history[-consecutive_required:] if h.get(metric) is not None]
322
+ if len(values) < consecutive_required:
323
+ continue
324
+
325
+ improvements = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
326
+ if all(imp < min_improvement for imp in improvements):
327
+ msg = rule.get("message", "Metric plateaued")
328
+ return {"message": msg, "metric": metric, "n_flat_epochs": len(values)}
329
+
330
+ return None
331
+
332
+
333
+ # --- Dashboard Formatting ---
334
+
335
+
336
+ def format_dashboard_line(current: dict, rolling_loss: dict, alerts: list[dict]) -> str:
337
+ """Format a compact single-line dashboard.
338
+
339
+ Example: Epoch 23/100 | loss: 0.3420 ↓ | val_loss: 0.3600 | gap: 0.0180 | info: plateau
340
+ """
341
+ epoch = current.get("epoch", "?")
342
+ total_epochs = current.get("total_epochs") or current.get("n_epochs", "?")
343
+
344
+ parts = [f"Epoch {epoch}/{total_epochs}"]
345
+
346
+ # Loss with trend arrow
347
+ loss = current.get("loss") or current.get("train_loss")
348
+ if loss is not None:
349
+ arrow = ""
350
+ if rolling_loss and rolling_loss.get("trend") is not None:
351
+ arrow = " ↓" if rolling_loss["trend"] < 0 else " ↑" if rolling_loss["trend"] > 0 else ""
352
+ if math.isnan(loss):
353
+ parts.append("loss: NaN")
354
+ else:
355
+ parts.append(f"loss: {loss:.4f}{arrow}")
356
+
357
+ # Accuracy/val metric
358
+ for metric in ("accuracy", "val_accuracy", "val_loss"):
359
+ val = current.get(metric)
360
+ if val is not None and not (isinstance(val, float) and math.isnan(val)):
361
+ parts.append(f"{metric}: {val:.4f}")
362
+
363
+ # Train/val gap
364
+ train_loss = current.get("loss") or current.get("train_loss")
365
+ val_loss = current.get("val_loss")
366
+ if train_loss is not None and val_loss is not None:
367
+ if not (math.isnan(train_loss) or math.isnan(val_loss)):
368
+ gap = val_loss - train_loss
369
+ parts.append(f"gap: {gap:.4f}")
370
+
371
+ # Alert indicators
372
+ for alert in alerts:
373
+ severity = alert.get("severity", "info")
374
+ name = alert.get("name", "alert")
375
+ if severity == "critical":
376
+ parts.append(f"CRITICAL: {name}")
377
+ elif severity == "warning":
378
+ parts.append(f"WARNING: {name}")
379
+ else:
380
+ parts.append(f"info: {name}")
381
+
382
+ return " | ".join(parts)
383
+
384
+
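A usage sketch (numbers invented) showing the shape of the line this produces:

    from scripts.training_monitor import format_dashboard_line  # assumed module path

    current = {"epoch": 7, "total_epochs": 50, "loss": 0.512, "val_loss": 0.601}
    rolling = {"trend": -0.004}
    print(format_dashboard_line(current, rolling, [{"name": "plateau", "severity": "info"}]))
    # -> Epoch 7/50 | loss: 0.5120 ↓ | val_loss: 0.6010 | gap: 0.0890 | info: plateau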
385
+ # --- Analysis Report ---
386
+
387
+
388
+ def analyze_training_log(
389
+ log_path: str,
390
+ alert_config_path: str = DEFAULT_ALERT_CONFIG,
391
+ config_path: str = "config.yaml",
392
+ ) -> dict:
393
+ """Analyze a completed training log for issues.
394
+
395
+ Returns a structured report with all alerts that would have
396
+ been triggered during training.
397
+ """
398
+ metrics = parse_log_file(log_path)
399
+ if not metrics:
400
+ return {"error": f"No metrics found in {log_path}", "log_path": log_path}
401
+
402
+ alert_config = load_alert_config(alert_config_path)
403
+
404
+ all_alerts = []
405
+ for i, current in enumerate(metrics):
406
+ history = metrics[:i]
407
+ alerts = evaluate_alerts(current, history, alert_config)
408
+ all_alerts.extend(alerts)
409
+
410
+ # Compute overall statistics
411
+ loss_values = [m.get("loss") or m.get("train_loss") for m in metrics]
412
+ loss_values = [v for v in loss_values if v is not None and not math.isnan(v)]
413
+
414
+ val_loss_values = [m.get("val_loss") for m in metrics]
415
+ val_loss_values = [v for v in val_loss_values if v is not None and not math.isnan(v)]
416
+
417
+ report = {
418
+ "log_path": log_path,
419
+ "analyzed_at": datetime.now(timezone.utc).isoformat(),
420
+ "total_epochs": len(metrics),
421
+ "alerts": all_alerts,
422
+ "alert_summary": {
423
+ "total": len(all_alerts),
424
+ "critical": len([a for a in all_alerts if a.get("severity") == "critical"]),
425
+ "warning": len([a for a in all_alerts if a.get("severity") == "warning"]),
426
+ "info": len([a for a in all_alerts if a.get("severity") == "info"]),
427
+ },
428
+ "training_stats": {},
429
+ }
430
+
431
+ if loss_values:
432
+ report["training_stats"]["final_loss"] = loss_values[-1]
433
+ report["training_stats"]["min_loss"] = min(loss_values)
434
+ report["training_stats"]["loss_reduction"] = loss_values[0] - loss_values[-1] if len(loss_values) > 1 else 0
435
+
436
+ if val_loss_values:
437
+ report["training_stats"]["final_val_loss"] = val_loss_values[-1]
438
+ report["training_stats"]["min_val_loss"] = min(val_loss_values)
439
+
440
+ if loss_values and val_loss_values:
441
+ report["training_stats"]["final_gap"] = val_loss_values[-1] - loss_values[-1]
442
+
443
+ return report
444
+
445
+
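A sketch of how the post-hoc path might be driven from Python rather than the CLI (the log path is hypothetical):

    from scripts.training_monitor import analyze_training_log, format_analysis_report  # assumed module path

    report = analyze_training_log("experiments/run-042/run.log")
    if "error" not in report:
        print(report["alert_summary"]["total"], "alerts over", report["total_epochs"], "epochs")
        print(format_analysis_report(report))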
446
+ def format_analysis_report(report: dict) -> str:
447
+ """Format analysis report as markdown."""
448
+ if "error" in report:
449
+ return f"ERROR: {report['error']}"
450
+
451
+ lines = [
452
+ "# Training Log Analysis",
453
+ "",
454
+ f"*Analyzed {report.get('analyzed_at', 'N/A')[:19]}*",
455
+ f"*Log: {report.get('log_path', 'N/A')}*",
456
+ "",
457
+ f"## Summary",
458
+ "",
459
+ f"- **Total epochs:** {report.get('total_epochs', 0)}",
460
+ ]
461
+
462
+ stats = report.get("training_stats", {})
463
+ if stats.get("final_loss") is not None:
464
+ lines.append(f"- **Final loss:** {stats['final_loss']:.4f} (min: {stats.get('min_loss', 0):.4f})")
465
+ if stats.get("final_val_loss") is not None:
466
+ lines.append(f"- **Final val_loss:** {stats['final_val_loss']:.4f}")
467
+ if stats.get("final_gap") is not None:
468
+ lines.append(f"- **Train/val gap:** {stats['final_gap']:.4f}")
469
+
470
+ summary = report.get("alert_summary", {})
471
+ lines.extend([
472
+ "",
473
+ "## Alerts",
474
+ "",
475
+ f"- **Critical:** {summary.get('critical', 0)}",
476
+ f"- **Warning:** {summary.get('warning', 0)}",
477
+ f"- **Info:** {summary.get('info', 0)}",
478
+ ])
479
+
480
+ alerts = report.get("alerts", [])
481
+ if alerts:
482
+ lines.extend(["", "### Details", ""])
483
+ for alert in alerts:
484
+ sev = alert.get("severity", "info").upper()
485
+ lines.append(f"- **[{sev}]** {alert.get('message', 'N/A')}")
486
+ else:
487
+ lines.extend(["", "No issues detected during training."])
488
+
489
+ return "\n".join(lines)
490
+
491
+
492
+ def save_analysis_report(report: dict, output_dir: str = "experiments/monitors") -> Path:
493
+ """Save analysis report to YAML."""
494
+ out_path = Path(output_dir)
495
+ out_path.mkdir(parents=True, exist_ok=True)
496
+
497
+ timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
498
+ filepath = out_path / f"analysis-{timestamp}.yaml"
499
+
500
+ with open(filepath, "w") as f:
501
+ yaml.dump(report, f, default_flow_style=False, sort_keys=False)
502
+
503
+ return filepath
504
+
505
+
506
+ def main() -> None:
507
+ """CLI entry point."""
508
+ parser = argparse.ArgumentParser(
509
+ description="Live training monitor with early-warning alerts",
510
+ )
511
+ parser.add_argument(
512
+ "--log", default=DEFAULT_LOG_PATH,
513
+ help=f"Path to training log file (default: {DEFAULT_LOG_PATH})",
514
+ )
515
+ parser.add_argument(
516
+ "--alerts-config", default=DEFAULT_ALERT_CONFIG,
517
+ help=f"Path to alert config YAML (default: {DEFAULT_ALERT_CONFIG})",
518
+ )
519
+ parser.add_argument(
520
+ "--config", default="config.yaml",
521
+ help="Path to config.yaml",
522
+ )
523
+ parser.add_argument(
524
+ "--interval", type=int, default=DEFAULT_INTERVAL,
525
+ help=f"Check interval in seconds (default: {DEFAULT_INTERVAL})",
526
+ )
527
+ parser.add_argument(
528
+ "--alerts", action="store_true",
529
+ help="Show only alerts, suppress normal output",
530
+ )
531
+ parser.add_argument(
532
+ "--analyze", metavar="LOG_FILE",
533
+ help="Post-hoc analysis of a completed training log",
534
+ )
535
+ parser.add_argument(
536
+ "--json", action="store_true",
537
+ help="Output raw JSON instead of formatted report",
538
+ )
539
+ args = parser.parse_args()
540
+
541
+ if args.analyze:
542
+ # Post-hoc analysis mode
543
+ report = analyze_training_log(
544
+ args.analyze,
545
+ alert_config_path=args.alerts_config,
546
+ config_path=args.config,
547
+ )
548
+
549
+ if "error" not in report:
550
+ filepath = save_analysis_report(report)
551
+ print(f"Saved to {filepath}", file=sys.stderr)
552
+
553
+ if args.json:
554
+ print(json.dumps(report, indent=2, default=str))
555
+ else:
556
+ print(format_analysis_report(report))
557
+
558
+ if report.get("alert_summary", {}).get("critical", 0) > 0:
559
+ sys.exit(1)
560
+ return
561
+
562
+ # Live monitoring mode
563
+ log_path = Path(args.log)
564
+ alert_config = load_alert_config(args.alerts_config)
565
+
566
+ print(f"Monitoring {log_path} (interval: {args.interval}s)", file=sys.stderr)
567
+ print("Press Ctrl+C to stop.", file=sys.stderr)
568
+ print(file=sys.stderr)
569
+
570
+ history: list[dict] = []
571
+ last_line_count = 0
572
+
573
+ try:
574
+ while True:
575
+ if not log_path.exists():
576
+ time.sleep(args.interval)
577
+ continue
578
+
579
+ with open(log_path) as f:
580
+ lines = f.readlines()
581
+
582
+ new_lines = lines[last_line_count:]
583
+ last_line_count = len(lines)
584
+
585
+ for line in new_lines:
586
+ parsed = parse_epoch_metrics(line)
587
+ if parsed is None:
588
+ continue
589
+
590
+ alerts = evaluate_alerts(parsed, history, alert_config)
591
+ rolling = compute_rolling_stats(history, "loss")
592
+
593
+ if not args.alerts or alerts:
594
+ dashboard = format_dashboard_line(parsed, rolling, alerts)
595
+ print(dashboard)
596
+
597
+ # Handle critical alerts with pause action
598
+ for alert in alerts:
599
+ if alert.get("action") == "pause":
600
+ print(f"\nCRITICAL: {alert['message']}", file=sys.stderr)
601
+ print("Training should be paused.", file=sys.stderr)
602
+
603
+ history.append(parsed)
604
+
605
+ time.sleep(args.interval)
606
+ except KeyboardInterrupt:
607
+ print(f"\nMonitoring stopped. {len(history)} epochs observed.", file=sys.stderr)
608
+
609
+
610
+ if __name__ == "__main__":
611
+ main()
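For completeness, one way a training loop could emit lines this monitor understands is to append one JSON object per epoch. The helper below is purely illustrative and not part of the package; the path and metric names are assumptions:

    import json

    def log_epoch_metrics(epoch: int, loss: float, val_loss: float, path: str = "run.log") -> None:
        """Append a JSON metrics line that parse_epoch_metrics() can read."""
        with open(path, "a") as f:
            f.write(json.dumps({"epoch": epoch, "loss": loss, "val_loss": val_loss}) + "\n")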