claude-turing 3.0.0 → 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,519 @@
+ #!/usr/bin/env python3
+ """Pre-training sanity checks for the autoresearch pipeline.
+
+ Runs a battery of fast checks before committing to a full training run:
+ initial loss validation, single-batch overfit, gradient flow, output
+ validation, data pipeline check, and config consistency.
+
+ Usage:
+     python scripts/sanity_checks.py
+     python scripts/sanity_checks.py --quick
+     python scripts/sanity_checks.py --verbose
+     python scripts/sanity_checks.py --json
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import math
+ import sys
+ from datetime import datetime, timezone
+ from pathlib import Path
+
+ import numpy as np
+ import yaml
+
+ from scripts.turing_io import load_config
+
+ DEFAULT_OVERFIT_STEPS = 50  # Step budget for the CLI overfit probe (probe not yet wired up; see main())
+ DEFAULT_OVERFIT_THRESHOLD = 0.1  # Loss should drop below this fraction of initial
+
+
+ # --- Individual Checks ---
+
+
+ def check_initial_loss(
+     initial_loss: float,
+     num_classes: int | None = None,
+     task_type: str = "classification",
+ ) -> dict:
+     """Check if initial loss matches theoretical expectation.
+
+     For classification with cross-entropy: expected = -log(1/num_classes).
+     """
+     check = {
+         "check": "initial_loss",
+         "severity": "high",
+         "initial_loss": round(initial_loss, 4),
+     }
+
+     if math.isnan(initial_loss) or math.isinf(initial_loss):
+         check["status"] = "fail"
+         check["reason"] = f"Initial loss is {initial_loss} — model is broken before training starts"
+         return check
+
+     if task_type == "classification" and num_classes and num_classes > 1:
+         expected = -math.log(1.0 / num_classes)
+         ratio = initial_loss / expected if expected > 0 else float("inf")
+         check["expected_loss"] = round(expected, 4)
+         check["ratio"] = round(ratio, 2)
+
+         if ratio > 3.0:
+             check["status"] = "fail"
+             check["reason"] = f"Initial loss {initial_loss:.4f} is {ratio:.1f}x expected ({expected:.4f}) — likely misconfigured loss function"
+         elif ratio > 2.0:
+             check["status"] = "warn"
+             check["reason"] = f"Initial loss {initial_loss:.4f} is {ratio:.1f}x expected ({expected:.4f}) — investigate"
+         else:
+             check["status"] = "pass"
+             check["reason"] = f"Initial loss {initial_loss:.4f} matches expected {expected:.4f} (ratio: {ratio:.2f})"
+     else:
+         # For regression, just check it's finite and reasonable
+         if initial_loss < 0:
+             check["status"] = "warn"
+             check["reason"] = f"Negative initial loss ({initial_loss:.4f}) — unusual for most loss functions"
+         else:
+             check["status"] = "pass"
+             check["reason"] = f"Initial loss {initial_loss:.4f} is finite and non-negative"
+
+     return check
+
+
+ def check_single_batch_overfit(
+     loss_history: list[float],
+     threshold: float = DEFAULT_OVERFIT_THRESHOLD,
+ ) -> dict:
+     """Check if model can overfit a single batch.
+
+     If loss doesn't approach zero after N steps, something is broken.
+     Passes when the final loss drops below ``threshold`` * the initial loss.
+     """
+     check = {
+         "check": "single_batch_overfit",
+         "severity": "critical",
+     }
+
+     if not loss_history:
+         check["status"] = "skip"
+         check["reason"] = "No loss history provided"
+         return check
+
+     # Catch NaN before computing ratios, so the stats below stay finite.
+     if any(math.isnan(x) for x in loss_history):
+         check["status"] = "fail"
+         check["reason"] = "NaN in loss during overfit test — numerical instability"
+         return check
+
+     initial = loss_history[0]
+     final = loss_history[-1]
+     n_steps = len(loss_history)
+
+     if initial <= 0:
+         check["status"] = "warn"
+         check["reason"] = f"Initial loss is {initial:.4f} (non-positive), cannot assess overfit"
+         return check
+
+     reduction = 1 - (final / initial)
+
+     check["initial_loss"] = round(initial, 4)
+     check["final_loss"] = round(final, 4)
+     check["n_steps"] = n_steps
+     check["reduction"] = round(reduction, 4)
+
+     # Pass when the final loss drops below `threshold` of the initial (default 0.1).
+     if reduction > 1 - threshold:
+         check["status"] = "pass"
+         check["reason"] = f"Loss reduced by {reduction:.0%} in {n_steps} steps — model can memorize"
+     elif reduction > 0.5:
+         check["status"] = "warn"
+         check["reason"] = f"Loss reduced by only {reduction:.0%} — model is learning but slowly. Check learning rate."
+     else:
+         check["status"] = "fail"
+         check["reason"] = f"Loss stuck (reduced only {reduction:.0%} in {n_steps} steps) — model cannot memorize 1 batch. Check: architecture, learning rate, loss function"
+
+     return check
+
+
+ def check_gradient_flow(
+     gradient_stats: list[dict],
+ ) -> dict:
+     """Check that gradients are non-zero and non-exploding for every parameter.
+
+     Args:
+         gradient_stats: List of {name, mean, max, min, std} per parameter group.
+     """
+     check = {
+         "check": "gradient_flow",
+         "severity": "high",
+     }
+
+     if not gradient_stats:
+         check["status"] = "skip"
+         check["reason"] = "No gradient statistics provided"
+         return check
+
+     dead_layers = []
+     exploding_layers = []
+     total = len(gradient_stats)
+
+     mean_grad = np.mean([abs(g.get("mean", 0)) for g in gradient_stats])
+
+     for g in gradient_stats:
+         name = g.get("name", "?")
+         grad_mean = abs(g.get("mean", 0))
+         grad_max = abs(g.get("max", 0))
+
+         if grad_mean == 0 and grad_max == 0:
+             dead_layers.append(name)
+         elif mean_grad > 0 and grad_max > 100 * mean_grad:
+             exploding_layers.append(name)
+
+     check["total_params"] = total
+     check["dead_layers"] = dead_layers
+     check["exploding_layers"] = exploding_layers
+
+     if dead_layers and exploding_layers:
+         check["status"] = "fail"
+         check["reason"] = f"{len(dead_layers)} dead layer(s) and {len(exploding_layers)} exploding layer(s)"
+     elif dead_layers:
+         check["status"] = "warn"
+         check["reason"] = f"{len(dead_layers)} dead layer(s) with zero gradients: {', '.join(dead_layers[:3])}"
+     elif exploding_layers:
+         check["status"] = "warn"
+         check["reason"] = f"{len(exploding_layers)} layer(s) with exploding gradients: {', '.join(exploding_layers[:3])}"
+     else:
+         check["status"] = "pass"
+         check["reason"] = f"All {total} parameter groups have non-zero, stable gradients"
+
+     return check
+
+
+ def check_output_validation(
+     outputs: np.ndarray | list,
+     task_type: str = "classification",
+ ) -> dict:
+     """Check that model outputs are valid (non-NaN, non-constant, reasonable range)."""
+     check = {
+         "check": "output_validation",
+         "severity": "high",
+     }
+
+     arr = np.asarray(outputs, dtype=float)
+
+     if arr.size == 0:
+         check["status"] = "skip"
+         check["reason"] = "No outputs to validate"
+         return check
+
+     has_nan = bool(np.any(np.isnan(arr)))
+     has_inf = bool(np.any(np.isinf(arr)))
+     is_constant = bool(np.std(arr) == 0)
+     out_min = float(np.nanmin(arr))
+     out_max = float(np.nanmax(arr))
+
+     check["range"] = [round(out_min, 4), round(out_max, 4)]
+     check["has_nan"] = has_nan
+     check["has_inf"] = has_inf
+     check["is_constant"] = is_constant
+
+     issues = []
+     if has_nan:
+         issues.append("NaN values in outputs")
+     if has_inf:
+         issues.append("Inf values in outputs")
+     if is_constant:
+         issues.append("All outputs identical (constant predictions)")
+     if task_type == "classification" and max(abs(out_min), abs(out_max)) > 100:
+         issues.append(f"Extreme output range [{out_min:.1f}, {out_max:.1f}] — consider clamping")
+
+     if has_nan or has_inf:
+         check["status"] = "fail"
+         check["reason"] = "; ".join(issues)
+     elif is_constant:
+         check["status"] = "fail"
+         check["reason"] = "Constant predictions — model is not differentiating inputs"
+     elif issues:
+         check["status"] = "warn"
+         check["reason"] = "; ".join(issues)
+     else:
+         check["status"] = "pass"
+         check["reason"] = f"Outputs valid: range [{out_min:.4f}, {out_max:.4f}], no NaN/Inf"
+
+     return check
+
+
+ def check_data_pipeline(
+     batch_shapes: dict | None = None,
+     has_nan: bool = False,
+     has_inf: bool = False,
+     loads_ok: bool = True,
+ ) -> dict:
+     """Check that the data pipeline produces valid batches."""
+     check = {
+         "check": "data_pipeline",
+         "severity": "critical",
+     }
+
+     if not loads_ok:
+         check["status"] = "fail"
+         check["reason"] = "Data pipeline failed to load first batch"
+         return check
+
+     issues = []
+     if has_nan:
+         issues.append("NaN values in input data")
+     if has_inf:
+         issues.append("Inf values in input data")
+
+     if batch_shapes:
+         check["shapes"] = batch_shapes
+
+     if issues:
+         check["status"] = "fail"
+         check["reason"] = "; ".join(issues)
+     elif batch_shapes:
+         shapes_str = ", ".join(f"{k}: {v}" for k, v in batch_shapes.items())
+         check["status"] = "pass"
+         check["reason"] = f"Batch loads cleanly ({shapes_str})"
+     else:
+         check["status"] = "pass"
+         check["reason"] = "Data pipeline functional"
+
+     return check
+
+
+ def check_config_consistency(config: dict) -> dict:
+     """Check that config values are in reasonable ranges."""
+     check = {
+         "check": "config_consistency",
+         "severity": "medium",
+     }
+
+     issues = []
+     hyperparams = config.get("model", {}).get("hyperparams", {})
+
+     lr = hyperparams.get("learning_rate", hyperparams.get("lr"))
+     if lr is not None:
+         if lr > 1.0:
+             issues.append(f"Learning rate {lr} > 1.0 — unusually high")
+         elif lr < 1e-8:
+             issues.append(f"Learning rate {lr} < 1e-8 — effectively zero")
+
+     batch_size = hyperparams.get("batch_size")
+     if batch_size is not None:
+         if batch_size < 1:
+             issues.append(f"Batch size {batch_size} < 1 — invalid")
+         elif batch_size > 100000:
+             issues.append(f"Batch size {batch_size} > 100K — unusually large")
+
+     n_estimators = hyperparams.get("n_estimators")
+     if n_estimators is not None and n_estimators < 1:
+         issues.append(f"n_estimators {n_estimators} < 1 — invalid")
+
+     if issues:
+         check["status"] = "warn"
+         check["reason"] = "; ".join(issues)
+         check["issues"] = issues
+     else:
+         check["status"] = "pass"
+         check["reason"] = "Config values in reasonable ranges"
+
+     return check
+
+
+ # --- Full Sanity Check ---
+
+
+ def run_sanity_checks(
+     config_path: str = "config.yaml",
+     quick: bool = False,
+     initial_loss: float | None = None,
+     loss_history: list[float] | None = None,
+     gradient_stats: list[dict] | None = None,
+     outputs: list | None = None,
+     batch_shapes: dict | None = None,
+     data_has_nan: bool = False,
+     data_has_inf: bool = False,
+     data_loads_ok: bool = True,
+     num_classes: int | None = None,
+ ) -> dict:
+     """Run all sanity checks and produce a report.
+
+     In CLI mode, most inputs come from running a quick training probe.
+     In test mode, values are provided directly.
+     """
+     config = load_config(config_path)
+     task_type = config.get("task", {}).get("type", "classification")
+
+     checks = []
+
+     # Data pipeline
+     checks.append(check_data_pipeline(batch_shapes, data_has_nan, data_has_inf, data_loads_ok))
+
+     # Initial loss
+     if initial_loss is not None:
+         checks.append(check_initial_loss(initial_loss, num_classes, task_type))
+
+     # Gradient flow
+     if gradient_stats is not None:
+         checks.append(check_gradient_flow(gradient_stats))
+
+     # Single-batch overfit (skip in quick mode)
+     if not quick and loss_history is not None:
+         checks.append(check_single_batch_overfit(loss_history))
+
+     # Output validation
+     if outputs is not None:
+         checks.append(check_output_validation(outputs, task_type))
+
+     # Config consistency
+     checks.append(check_config_consistency(config))
+
+     # Compute verdict
+     n_pass = sum(1 for c in checks if c["status"] == "pass")
+     n_fail = sum(1 for c in checks if c["status"] == "fail")
+     n_warn = sum(1 for c in checks if c["status"] == "warn")
+     n_skip = sum(1 for c in checks if c["status"] == "skip")
+
+     if n_fail > 0:
+         verdict = "fail"
+     elif n_warn > 2:
+         verdict = "warn"
+     elif n_warn > 0:
+         verdict = "pass_with_warnings"
+     else:
+         verdict = "pass"
+
+     return {
+         "checked_at": datetime.now(timezone.utc).isoformat(),
+         "quick_mode": quick,
+         "task_type": task_type,
+         "checks": checks,
+         "score": {
+             "pass": n_pass,
+             "fail": n_fail,
+             "warn": n_warn,
+             "skip": n_skip,
+             "total": len(checks),
+         },
+         "verdict": verdict,
+     }
+
+
+ # --- Report Formatting ---
+
+
+ def save_sanity_report(report: dict, output_dir: str = "experiments/sanity") -> Path:
+     """Save sanity report to YAML."""
+     out_path = Path(output_dir)
+     out_path.mkdir(parents=True, exist_ok=True)
+
+     timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+     filepath = out_path / f"sanity-{timestamp}.yaml"
+
+     with open(filepath, "w") as f:
+         yaml.dump(report, f, default_flow_style=False, sort_keys=False)
+
+     return filepath
+
+
+ def format_sanity_report(report: dict) -> str:
+     """Format sanity report as markdown."""
+     if "error" in report:
+         return f"ERROR: {report['error']}"
+
+     verdict = report.get("verdict", "?")
+     score = report.get("score", {})
+     quick = report.get("quick_mode", False)
+
+     verdict_labels = {
+         "pass": "PASS — Safe to proceed with training",
+         "pass_with_warnings": "PASS (with warnings) — Review before training",
+         "warn": "WARNINGS — Multiple issues detected",
+         "fail": "FAIL — Do not proceed to full training",
+     }
+
+     lines = [
+         "# Sanity Check Report",
+         "",
+         f"*Checked {report.get('checked_at', 'N/A')[:19]}*",
+         f"*Mode: {'quick' if quick else 'full'}*",
+         "",
+         f"**{verdict_labels.get(verdict, verdict.upper())}**",
+         "",
+     ]
+
+     status_markers = {"pass": "PASS", "fail": "FAIL", "warn": "WARN", "skip": "SKIP"}
+
+     for c in report.get("checks", []):
+         status = c.get("status", "?")
+         marker = status_markers.get(status, status.upper())
+         lines.append(f"- **[{marker}]** {c.get('check', '?')}: {c.get('reason', 'N/A')}")
+
+     lines.extend([
+         "",
+         f"**Score:** {score.get('pass', 0)}/{score.get('total', 0)} pass, "
+         f"{score.get('fail', 0)} fail, {score.get('warn', 0)} warn",
+     ])
+
+     return "\n".join(lines)
+
+
+ def main() -> None:
+     """CLI entry point."""
+     parser = argparse.ArgumentParser(
+         description="Pre-training sanity checks",
+     )
+     parser.add_argument(
+         "--config", default="config.yaml",
+         help="Path to config.yaml",
+     )
+     parser.add_argument(
+         "--quick", action="store_true",
+         help="Quick mode: skip single-batch overfit test",
+     )
+     parser.add_argument(
+         "--verbose", action="store_true",
+         help="Show detailed check output",
+     )
+     parser.add_argument(
+         "--json", action="store_true",
+         help="Output raw JSON instead of formatted report",
+     )
+     args = parser.parse_args()
+
+     # In CLI mode, we'd run actual probes. For now, report with config check only.
+     report = run_sanity_checks(
+         config_path=args.config,
+         quick=args.quick,
+     )
+
+     if "error" not in report:
+         filepath = save_sanity_report(report)
+         print(f"Saved to {filepath}", file=sys.stderr)
+
+     if args.json:
+         print(json.dumps(report, indent=2, default=str))
+     else:
+         print(format_sanity_report(report))
+
+     if report.get("verdict") == "fail":
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
@@ -118,6 +118,9 @@ TEMPLATE_DIRS = {
  "model_distiller.py",
  "knowledge_transfer.py",
  "methodology_audit.py",
+ "sanity_checks.py",
+ "generate_baselines.py",
+ "leakage_detector.py",
  ],
  "tests": ["__init__.py", "conftest.py"],
  }
@@ -148,6 +151,9 @@ DIRECTORIES_TO_CREATE = [
  "experiments/distillations",
  "experiments/transfers",
  "experiments/audits",
+ "experiments/sanity",
+ "experiments/baselines",
+ "experiments/leakage",
  "experiments/logs",
  "models/best",
  "models/archive",