claude-turing 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. package/.claude-plugin/plugin.json +34 -0
  2. package/LICENSE +21 -0
  3. package/README.md +457 -0
  4. package/agents/ml-evaluator.md +43 -0
  5. package/agents/ml-researcher.md +74 -0
  6. package/bin/cli.js +46 -0
  7. package/bin/turing-init.sh +57 -0
  8. package/commands/brief.md +83 -0
  9. package/commands/compare.md +24 -0
  10. package/commands/design.md +97 -0
  11. package/commands/init.md +123 -0
  12. package/commands/logbook.md +51 -0
  13. package/commands/mode.md +43 -0
  14. package/commands/poster.md +89 -0
  15. package/commands/preflight.md +75 -0
  16. package/commands/report.md +97 -0
  17. package/commands/rules/loop-protocol.md +91 -0
  18. package/commands/status.md +24 -0
  19. package/commands/suggest.md +95 -0
  20. package/commands/sweep.md +45 -0
  21. package/commands/train.md +66 -0
  22. package/commands/try.md +63 -0
  23. package/commands/turing.md +54 -0
  24. package/commands/validate.md +34 -0
  25. package/config/defaults.yaml +45 -0
  26. package/config/experiment_archetypes.yaml +127 -0
  27. package/config/lifecycle.toml +31 -0
  28. package/config/novelty_aliases.yaml +107 -0
  29. package/config/relationships.toml +125 -0
  30. package/config/state.toml +24 -0
  31. package/config/task_taxonomy.yaml +110 -0
  32. package/config/taxonomy.toml +37 -0
  33. package/package.json +54 -0
  34. package/src/claude-md.js +55 -0
  35. package/src/install.js +107 -0
  36. package/src/paths.js +20 -0
  37. package/src/postinstall.js +22 -0
  38. package/src/verify.js +109 -0
  39. package/templates/MEMORY.md +36 -0
  40. package/templates/README.md +93 -0
  41. package/templates/__pycache__/evaluate.cpython-314.pyc +0 -0
  42. package/templates/__pycache__/prepare.cpython-314.pyc +0 -0
  43. package/templates/config.yaml +48 -0
  44. package/templates/evaluate.py +237 -0
  45. package/templates/features/__init__.py +0 -0
  46. package/templates/features/__pycache__/__init__.cpython-314.pyc +0 -0
  47. package/templates/features/__pycache__/featurizers.cpython-314.pyc +0 -0
  48. package/templates/features/featurizers.py +138 -0
  49. package/templates/prepare.py +171 -0
  50. package/templates/program.md +216 -0
  51. package/templates/pyproject.toml +8 -0
  52. package/templates/requirements.txt +8 -0
  53. package/templates/scripts/__init__.py +0 -0
  54. package/templates/scripts/__pycache__/__init__.cpython-314.pyc +0 -0
  55. package/templates/scripts/__pycache__/check_convergence.cpython-314.pyc +0 -0
  56. package/templates/scripts/__pycache__/classify_task.cpython-314.pyc +0 -0
  57. package/templates/scripts/__pycache__/critique_hypothesis.cpython-314.pyc +0 -0
  58. package/templates/scripts/__pycache__/experiment_index.cpython-314.pyc +0 -0
  59. package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
  60. package/templates/scripts/__pycache__/generate_logbook.cpython-314.pyc +0 -0
  61. package/templates/scripts/__pycache__/log_experiment.cpython-314.pyc +0 -0
  62. package/templates/scripts/__pycache__/manage_hypotheses.cpython-314.pyc +0 -0
  63. package/templates/scripts/__pycache__/novelty_guard.cpython-314.pyc +0 -0
  64. package/templates/scripts/__pycache__/parse_metrics.cpython-314.pyc +0 -0
  65. package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
  66. package/templates/scripts/__pycache__/show_experiment_tree.cpython-314.pyc +0 -0
  67. package/templates/scripts/__pycache__/show_families.cpython-314.pyc +0 -0
  68. package/templates/scripts/__pycache__/statistical_compare.cpython-314.pyc +0 -0
  69. package/templates/scripts/__pycache__/suggest_next.cpython-314.pyc +0 -0
  70. package/templates/scripts/__pycache__/sweep.cpython-314.pyc +0 -0
  71. package/templates/scripts/__pycache__/synthesize_decision.cpython-314.pyc +0 -0
  72. package/templates/scripts/__pycache__/turing_io.cpython-314.pyc +0 -0
  73. package/templates/scripts/__pycache__/update_state.cpython-314.pyc +0 -0
  74. package/templates/scripts/__pycache__/verify_placeholders.cpython-314.pyc +0 -0
  75. package/templates/scripts/check_convergence.py +230 -0
  76. package/templates/scripts/compare_runs.py +124 -0
  77. package/templates/scripts/critique_hypothesis.py +350 -0
  78. package/templates/scripts/experiment_index.py +288 -0
  79. package/templates/scripts/generate_brief.py +389 -0
  80. package/templates/scripts/generate_logbook.py +423 -0
  81. package/templates/scripts/log_experiment.py +243 -0
  82. package/templates/scripts/manage_hypotheses.py +543 -0
  83. package/templates/scripts/novelty_guard.py +343 -0
  84. package/templates/scripts/parse_metrics.py +139 -0
  85. package/templates/scripts/post-train-hook.sh +74 -0
  86. package/templates/scripts/preflight.py +549 -0
  87. package/templates/scripts/scaffold.py +409 -0
  88. package/templates/scripts/show_environment.py +92 -0
  89. package/templates/scripts/show_experiment_tree.py +144 -0
  90. package/templates/scripts/show_families.py +133 -0
  91. package/templates/scripts/show_metrics.py +157 -0
  92. package/templates/scripts/statistical_compare.py +259 -0
  93. package/templates/scripts/stop-hook.sh +34 -0
  94. package/templates/scripts/suggest_next.py +301 -0
  95. package/templates/scripts/sweep.py +276 -0
  96. package/templates/scripts/synthesize_decision.py +300 -0
  97. package/templates/scripts/turing_io.py +76 -0
  98. package/templates/scripts/update_state.py +296 -0
  99. package/templates/scripts/validate_stability.py +167 -0
  100. package/templates/scripts/verify_placeholders.py +119 -0
  101. package/templates/sweep_config.yaml +14 -0
  102. package/templates/tests/__init__.py +0 -0
  103. package/templates/tests/conftest.py +91 -0
  104. package/templates/train.py +240 -0
package/templates/scripts/validate_stability.py
@@ -0,0 +1,167 @@
+ """Stability validation for experiment metrics.
+ 
+ Runs the training pipeline N times and computes the coefficient of variation.
+ If CV > 5%, the metric is too noisy for reliable single-run comparison.
+ 
+ Usage:
+     python scripts/validate_stability.py            # Check stability (5 runs)
+     python scripts/validate_stability.py --runs 10  # Custom run count
+     python scripts/validate_stability.py --auto     # Auto-fix if unstable
+ """
+ 
+ from __future__ import annotations
+ 
+ import argparse
+ import re
+ import statistics
+ import subprocess
+ from pathlib import Path
+ 
+ import yaml
+ 
+ from scripts.turing_io import load_config
+ 
+ DEFAULT_RUNS = 5
+ CV_THRESHOLD = 5.0  # Percent
+ 
+ 
+ def run_experiment(seed: int) -> float | None:
+     """Run a single training pass and extract the primary metric."""
+     config = load_config()
+     eval_cfg = config.get("evaluation", {})
+     primary_metric = eval_cfg.get("primary_metric", "accuracy")
+ 
+     # A run that exceeds the timeout counts as a failed run instead of
+     # crashing the whole stability loop.
+     try:
+         result = subprocess.run(
+             ["python", "train.py", "--seed", str(seed)],
+             capture_output=True,
+             text=True,
+             timeout=600,
+         )
+     except subprocess.TimeoutExpired:
+         return None
+ 
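+     # Illustrative shape of the stdout block parsed below -- the actual key
+     # names come from evaluate.format_metrics, so treat them as assumptions:
+     #   ---
+     #   accuracy: 0.9132
+     #   f1_weighted: 0.9078
+     #   ---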
+     # Parse metric from stdout (between --- delimiters)
+     output = result.stdout
+     in_block = False
+     for line in output.splitlines():
+         if line.strip() == "---":
+             in_block = not in_block
+             continue
+         if in_block:
+             match = re.match(rf"^{re.escape(primary_metric)}:\s*(.+)$", line.strip())
+             if match:
+                 try:
+                     return float(match.group(1))
+                 except ValueError:
+                     pass
+     return None
+ 
+ 
+ def check_stability(n_runs: int = DEFAULT_RUNS) -> dict:
+     """Run N experiments and compute stability metrics."""
+     config = load_config()
+     eval_cfg = config.get("evaluation", {})
+     primary_metric = eval_cfg.get("primary_metric", "accuracy")
+ 
+     print(f"Running {n_runs} stability validation passes...")
+     print(f"Primary metric: {primary_metric}")
+     print()
+ 
+     results = []
+     for i in range(n_runs):
+         seed = 42 + i  # Deterministic but different seeds
+         print(f"  Run {i + 1}/{n_runs} (seed={seed})...", end=" ", flush=True)
+         value = run_experiment(seed)
+         if value is not None:
+             results.append(value)
+             print(f"{primary_metric}={value:.4f}")
+         else:
+             print("FAILED (no metric)")
+ 
+     if len(results) < 2:
+         return {
+             "stable": False,
+             "reason": f"Only {len(results)} successful runs out of {n_runs}",
+             "results": results,
+             "recommendation": "Fix pipeline errors before validating stability",
+         }
+ 
+     mean = statistics.mean(results)
+     stdev = statistics.stdev(results)
+     cv = (stdev / abs(mean) * 100) if mean != 0 else float("inf")
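+     # Worked example (illustrative): results of [0.91, 0.93, 0.92, 0.90, 0.94]
+     # give mean 0.92 and sample stdev ~0.0158, so CV ~1.7% -- comfortably under
+     # the 5% threshold. CV is scale-free, which is why one percentage threshold
+     # works regardless of the primary metric's units.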
+ 
+     stable = cv < CV_THRESHOLD
+ 
+     return {
+         "stable": stable,
+         "n_runs": len(results),
+         "mean": round(mean, 4),
+         "stdev": round(stdev, 4),
+         "cv_percent": round(cv, 2),
+         "min": round(min(results), 4),
+         "max": round(max(results), 4),
+         "results": results,
+         "recommendation": (
+             "Metric is stable — single-run evaluation is reliable."
+             if stable
+             else f"CV={cv:.1f}% exceeds {CV_THRESHOLD}% threshold. "
+             f"Recommend setting evaluation.n_runs: 3 in config.yaml "
+             f"to use median of 3 runs per experiment."
+         ),
+     }
+ 
+ 
+ def auto_fix_config() -> bool:
+     """Set evaluation.n_runs: 3 in config.yaml if not already set."""
+     config_path = Path("config.yaml")
+     if not config_path.exists():
+         return False
+ 
+     with open(config_path) as f:
+         config = yaml.safe_load(f) or {}
+ 
+     eval_cfg = config.setdefault("evaluation", {})
+     if eval_cfg.get("n_runs", 1) >= 3:
+         return False  # Already configured
+ 
+     eval_cfg["n_runs"] = 3
+     with open(config_path, "w") as f:
+         yaml.dump(config, f, default_flow_style=False, sort_keys=False)
+ 
+     return True
+ 
+ 
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Validate experiment metric stability")
+     parser.add_argument("--runs", type=int, default=DEFAULT_RUNS, help=f"Number of runs (default: {DEFAULT_RUNS})")
+     parser.add_argument("--auto", action="store_true", help="Auto-configure multi-run if unstable")
+     args = parser.parse_args()
+ 
+     result = check_stability(args.runs)
+ 
+     print()
+     print("=" * 50)
+     print(f"  Stability: {'STABLE' if result['stable'] else 'UNSTABLE'}")
+     if "mean" in result:
+         print(f"  Mean:  {result['mean']}")
+         print(f"  Stdev: {result['stdev']}")
+         print(f"  CV:    {result['cv_percent']}%")
+         print(f"  Range: [{result['min']}, {result['max']}]")
+     print(f"  Runs: {result.get('n_runs', 0)}")
+     print()
+     print(f"  {result['recommendation']}")
+     print("=" * 50)
+ 
+     if args.auto and not result["stable"]:
+         if auto_fix_config():
+             print()
+             print("  AUTO-FIX: Set evaluation.n_runs: 3 in config.yaml")
+             print("  Each experiment will now run 3 times and report the median.")
+         else:
+             print()
+             print("  evaluation.n_runs already >= 3, no changes made.")
+ 
+ 
+ if __name__ == "__main__":
+     main()
package/templates/scripts/verify_placeholders.py
@@ -0,0 +1,119 @@
+ #!/usr/bin/env python3
+ """Post-scaffolding placeholder verification.
+ 
+ Scans all scaffolded files for unreplaced {{PLACEHOLDER}} markers and
+ reports any that remain. Returns exit code 1 if any are found.
+ 
+ This converts silent-wrong-behavior failures into loud-and-immediate errors.
+ A missed placeholder like {{TARGET_METRIC}} produces valid Python that silently
+ does the wrong thing — this script catches it before the first experiment runs.
+ 
+ Usage:
+     python scripts/verify_placeholders.py [directory]
+ 
+     directory: path to scan (default: current directory)
+ 
+ Exit codes:
+     0 = all placeholders replaced
+     1 = unreplaced placeholders found
+ """
+ 
+ from __future__ import annotations
+ 
+ import re
+ import sys
+ from pathlib import Path
+ 
+ # Files to scan for placeholders (relative to project root)
+ SCANNABLE_EXTENSIONS = {".py", ".yaml", ".yml", ".md", ".toml", ".sh", ".txt"}
+ 
+ # Known placeholder pattern
+ PLACEHOLDER_RE = re.compile(r"\{\{([A-Z_]+)\}\}")
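+ # e.g. this pattern matches "{{TARGET_METRIC}}" inside a scaffolded line
+ # such as: primary = metrics.get("{{TARGET_METRIC}}")  (illustrative)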
+ 
+ # Known valid placeholders (from config/defaults.yaml)
+ KNOWN_PLACEHOLDERS = {
+     "PROJECT_NAME",
+     "TARGET_METRIC",
+     "TASK_DESCRIPTION",
+     "ML_DIR",
+     "DATA_SOURCE",
+     "METRIC_DIRECTION",
+ }
+ 
+ 
+ def scan_file(path: Path) -> list[tuple[int, str, str]]:
+     """Scan a single file for unreplaced placeholders.
+ 
+     Returns list of (line_number, placeholder_name, line_content) tuples.
+     """
+     findings: list[tuple[int, str, str]] = []
+     try:
+         text = path.read_text(encoding="utf-8")
+     except (UnicodeDecodeError, PermissionError):
+         return findings
+ 
+     for i, line in enumerate(text.splitlines(), start=1):
+         for match in PLACEHOLDER_RE.finditer(line):
+             placeholder = match.group(1)
+             if placeholder in KNOWN_PLACEHOLDERS:
+                 findings.append((i, placeholder, line.strip()))
+ 
+     return findings
+ 
+ 
+ def scan_directory(root: Path) -> dict[Path, list[tuple[int, str, str]]]:
+     """Scan all files in directory for unreplaced placeholders.
+ 
+     Returns dict mapping file paths to their findings.
+     """
+     results: dict[Path, list[tuple[int, str, str]]] = {}
+ 
+     for path in sorted(root.rglob("*")):
+         if not path.is_file():
+             continue
+         if path.suffix not in SCANNABLE_EXTENSIONS:
+             continue
+         # Skip venv, node_modules, git
+         parts = path.parts
+         if any(p in (".venv", "node_modules", ".git", "__pycache__") for p in parts):
+             continue
+ 
+         findings = scan_file(path)
+         if findings:
+             results[path] = findings
+ 
+     return results
+ 
+ 
+ def main() -> None:
+     """CLI entry point."""
+     root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".")
+ 
+     if not root.is_dir():
+         print(f"Error: {root} is not a directory", file=sys.stderr)
+         sys.exit(1)
+ 
+     results = scan_directory(root)
+ 
+     if not results:
+         print("All placeholders replaced successfully.")
+         sys.exit(0)
+ 
+     # Report findings
+     total = sum(len(findings) for findings in results.values())
+     print(f"Found {total} unreplaced placeholder(s) in {len(results)} file(s):\n")
+ 
+     for path, findings in results.items():
+         rel_path = path.relative_to(root) if path.is_relative_to(root) else path
+         print(f"  {rel_path}:")
+         for line_num, placeholder, line_content in findings:
+             print(f"    line {line_num}: {{{{{placeholder}}}}} — {line_content[:80]}")
+         print()
+ 
+     print("Fix: replace all {{PLACEHOLDER}} markers with project-specific values.")
+     print("See config.yaml and program.md for placeholder descriptions.")
+     sys.exit(1)
+ 
+ 
+ if __name__ == "__main__":
+     main()
package/templates/sweep_config.yaml
@@ -0,0 +1,14 @@
+ # Hyperparameter sweep configuration for {{PROJECT_NAME}}.
+ # Each key under 'sweep' maps to a config.yaml dotted path.
+ # Values are lists -- sweep.py generates the cartesian product.
+ 
+ sweep:
+   model.hyperparams.n_estimators: [50, 100, 200]
+   model.hyperparams.max_depth: [3, 4, 6, 8]
+   model.hyperparams.learning_rate: [0.01, 0.05, 0.1]
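+   # The three lists above expand to a 3 x 4 x 3 = 36-combination grid;
+   # each combination is queued as one run overlaid onto base_config.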
+ 
+ # Base config to overlay sweep values onto
+ base_config: "config.yaml"
+ 
+ # Output queue file
+ output: "experiments/queue.yaml"
package/templates/tests/__init__.py (file without changes)
package/templates/tests/conftest.py
@@ -0,0 +1,91 @@
+ """Shared test fixtures for the {{PROJECT_NAME}} ML pipeline tests.
+ 
+ These fixtures provide deterministic sample data for testing the
+ pipeline components. Customize with records matching your actual
+ data schema — the autoresearch agent uses these when running tests.
+ """
+ 
+ from __future__ import annotations
+ 
+ import json
+ from pathlib import Path
+ 
+ import pytest
+ 
+ 
+ @pytest.fixture
+ def sample_training_data() -> list[dict]:
+     """Sample training records for testing.
+ 
+     Replace these with records matching your data schema.
+     Include a representative mix of all label/target values.
+     """
+     # TODO: Replace with your actual data schema
+     return [
+         {"feature_1": 1.0, "feature_2": "category_a", "label": "positive"},
+         {"feature_1": 2.0, "feature_2": "category_b", "label": "positive"},
+         {"feature_1": 3.0, "feature_2": "category_a", "label": "negative"},
+         {"feature_1": 4.0, "feature_2": "category_b", "label": "negative"},
+         {"feature_1": 5.0, "feature_2": "category_a", "label": "positive"},
+         {"feature_1": 6.0, "feature_2": "category_b", "label": "negative"},
+         {"feature_1": 7.0, "feature_2": "category_a", "label": "positive"},
+         {"feature_1": 8.0, "feature_2": "category_b", "label": "negative"},
+         {"feature_1": 9.0, "feature_2": "category_a", "label": "positive"},
+         {"feature_1": 10.0, "feature_2": "category_b", "label": "negative"},
+     ]
+ 
+ 
+ @pytest.fixture
+ def sample_jsonl_file(tmp_path: Path, sample_training_data: list[dict]) -> Path:
+     """Write sample training data to a temporary JSONL file."""
+     path = tmp_path / "test_data.jsonl"
+     with open(path, "w") as f:
+         for record in sample_training_data:
+             f.write(json.dumps(record) + "\n")
+     return path
+ 
+ 
+ @pytest.fixture
+ def sample_config(tmp_path: Path) -> Path:
+     """Write a minimal config.yaml for testing."""
+     import yaml
+ 
+     config = {
+         "data": {
+             "source": str(tmp_path / "test_data.jsonl"),
+             "splits_dir": str(tmp_path / "splits"),
+             "target_column": "label",
+             "split_ratios": {"train": 0.70, "val": 0.15, "test": 0.15},
+             "random_state": 42,
+         },
+         "evaluation": {
+             "primary_metric": "accuracy",
+             "metrics": ["accuracy", "f1_weighted"],
+             "lower_is_better": False,
+         },
+         "convergence": {
+             "patience": 3,
+             "improvement_threshold": 0.005,
+         },
+         "model": {
+             "type": "xgboost",
+             "hyperparams": {
+                 "n_estimators": 10,
+                 "max_depth": 3,
+                 "verbosity": 0,
+             },
+         },
+         "output": {
+             "models_dir": str(tmp_path / "models"),
+             "best_model_dir": str(tmp_path / "models" / "best"),
+             "archive_dir": str(tmp_path / "models" / "archive"),
+             "experiment_log": str(tmp_path / "experiments" / "log.jsonl"),
+             "results_tsv": str(tmp_path / "experiments" / "results.tsv"),
+         },
+     }
+ 
+     config_path = tmp_path / "config.yaml"
+     with open(config_path, "w") as f:
+         yaml.dump(config, f)
+ 
+     return config_path
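+ 
+ 
+ # Example test using these fixtures (an illustrative sketch -- this test is
+ # not shipped with the package):
+ #
+ #     def test_jsonl_roundtrip(sample_jsonl_file, sample_training_data):
+ #         lines = sample_jsonl_file.read_text().splitlines()
+ #         assert len(lines) == len(sample_training_data)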
package/templates/train.py
@@ -0,0 +1,240 @@
+ # AGENT-EDITABLE — HYPOTHESIS SPACE.
+ #
+ # This is the ONLY code file the autoresearch agent should modify.
+ # Training code for {{PROJECT_NAME}}.
+ # Default implementation: XGBoost classifier.
+ #
+ # The agent iteratively modifies this file to test hypotheses about
+ # model architecture, hyperparameters, and feature usage. The evaluation
+ # harness (evaluate.py) remains immutable — ensuring all experiments
+ # are measured by the same yardstick.
+ #
+ # See program.md for the full experiment loop protocol.
+ 
+ from __future__ import annotations
+ 
+ import argparse
+ import hashlib
+ import json
+ import os
+ import platform
+ import random
+ import sys
+ import time
+ from pathlib import Path
+ 
+ import joblib
+ import numpy as np
+ import pandas as pd
+ from xgboost import XGBClassifier
+ 
+ from evaluate import evaluate_model, format_metrics
+ from features.featurizers import get_default_featurizer
+ from prepare import load_config, load_splits
+ 
+ 
+ def capture_environment(seed: int) -> dict:
+     """Capture the full runtime environment for reproducibility.
+ 
+     Records everything needed to reproduce this exact experiment:
+     python version, package versions, hardware, OS, and all random seeds.
+     """
+     env = {
+         "python_version": platform.python_version(),
+         "platform": platform.platform(),
+         "machine": platform.machine(),
+         "os": platform.system(),
+         "seeds": {
+             "random_state": seed,
+             "PYTHONHASHSEED": os.environ.get("PYTHONHASHSEED", "not set"),
+         },
+         "packages": {},
+     }
+ 
+     # Key package versions
+     for pkg_name in ["xgboost", "lightgbm", "sklearn", "numpy", "pandas", "joblib", "scipy"]:
+         try:
+             mod = __import__(pkg_name)
+             env["packages"][pkg_name] = getattr(mod, "__version__", "unknown")
+         except ImportError:
+             pass
+ 
+     # GPU info (if torch is available)
+     try:
+         import torch
+         env["packages"]["torch"] = torch.__version__
+         if torch.cuda.is_available():
+             env["gpu"] = {
+                 "name": torch.cuda.get_device_name(0),
+                 "cuda_version": torch.version.cuda,
+                 "device_count": torch.cuda.device_count(),
+             }
+             env["seeds"]["cuda_deterministic"] = True
+     except ImportError:
+         pass
+ 
+     # Config file hash for drift detection
+     config_path = Path("config.yaml")
+     if config_path.exists():
+         env["config_hash"] = hashlib.sha256(config_path.read_bytes()).hexdigest()[:16]
+ 
+     return env
+ 
+ 
+ def pin_all_seeds(seed: int) -> None:
+     """Pin all random seeds for full reproducibility.
+ 
+     Covers: stdlib random, numpy, PYTHONHASHSEED.
+     Torch/CUDA seeds are pinned only if torch is available.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+ 
+     try:
+         import torch
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed_all(seed)
+             torch.backends.cudnn.deterministic = True
+             torch.backends.cudnn.benchmark = False
+     except ImportError:
+         pass
+ 
+ 
+ def train_model(
+     config_path: str = "config.yaml",
+     output_dir: str | None = None,
+     seed_override: int | None = None,
+ ) -> None:
+     """Train a model and print metrics.
+ 
+     Flow:
+         1. Load config and pin seeds
+         2. Load train/val splits
+         3. Apply featurizer to get feature matrices
+         4. Train model with hyperparams from config
+         5. Predict on train and val sets
+         6. Evaluate using evaluate_model()
+         7. Print formatted metrics
+         8. Save model to models/ directory
+         9. Write training metadata to train_metadata.json
+ 
+     Customize this function for your specific ML task.
+     """
+     start_time = time.time()
+ 
+     # 1. Load config and pin all seeds
+     config = load_config(config_path)
+     eval_cfg = config.get("evaluation", {})
+     primary_metric = eval_cfg.get("primary_metric", "{{TARGET_METRIC}}")
+     random_state = seed_override if seed_override is not None else config["data"].get("random_state", 42)
+     pin_all_seeds(random_state)
+ 
+     # 2. Load splits
+     splits_dir = config["data"]["splits_dir"]
+     train_df, val_df, _test_df = load_splits(splits_dir)
+ 
+     # 3. Apply featurizer
+     featurizer = get_default_featurizer()
+     target_column = config["data"].get("target_column", "label")
+ 
+     X_train = featurizer.fit_transform(train_df)
+     X_val = featurizer.transform(val_df)
+ 
+     # Fill NaN values that may arise from feature engineering
+     X_train = X_train.fillna(0)
+     X_val = X_val.fillna(0)
+ 
+     y_train = train_df[target_column].values
+     y_val = val_df[target_column].values
+ 
+     # 4. Train model (hyperparams from config.yaml)
+     model_config = config.get("model", {})
+     hyperparams = model_config.get("hyperparams", {})
+     model = XGBClassifier(
+         **hyperparams,
+         random_state=random_state,
+     )
+     model.fit(X_train, y_train)
+ 
+     # 5. Predict on train and val sets
+     y_pred_val = model.predict(X_val)
+     y_pred_train = model.predict(X_train)
+ 
+     # 6. Evaluate on val (primary) and train (gap monitoring)
+     metrics = evaluate_model(y_pred_val, y_val, config)
+     train_metrics = evaluate_model(y_pred_train, y_train, config)
+ 
+     # Compute train/val gap for overfitting detection
+     # (primary_metric was already resolved in step 1)
+     val_score = metrics.get(primary_metric)
+     train_score = train_metrics.get(primary_metric)
+     if val_score is not None and train_score is not None:
+         metrics["train_" + primary_metric] = train_score
+         metrics["overfit_gap"] = round(train_score - val_score, 4)
+ 
+     train_seconds = time.time() - start_time
+ 
+     # Add metadata for parseable output
+     metrics["model_type"] = model_config.get("type", "xgboost")
+     metrics["train_seconds"] = round(train_seconds, 1)
+
183
+ # 7. Print formatted metrics
184
+ print(format_metrics(metrics))
185
+
186
+ # 8. Save model
187
+ models_dir = output_dir or config["output"]["models_dir"]
188
+ models_path = Path(models_dir)
189
+ models_path.mkdir(parents=True, exist_ok=True)
190
+
191
+ model_file = models_path / "model.joblib"
192
+ joblib.dump(
193
+ {
194
+ "model": model,
195
+ "featurizer": featurizer,
196
+ "config": config,
197
+ },
198
+ model_file,
199
+ )
200
+ print(f"\nModel saved to {model_file}")
201
+
202
+ # 9. Write training metadata for behavioral validation + reproducibility
203
+ metadata = {
204
+ "train_time_sec": round(train_seconds, 1),
205
+ "model_size_bytes": model_file.stat().st_size if model_file.exists() else 0,
206
+ "n_train_samples": len(X_train),
207
+ "n_features": X_train.shape[1] if hasattr(X_train, 'shape') else len(X_train.columns),
208
+ "predictions_unique": int(len(set(y_pred_val.tolist()))),
209
+ "environment": capture_environment(random_state),
210
+ }
211
+ metadata_path = Path("train_metadata.json")
212
+ with open(metadata_path, "w") as f:
213
+ json.dump(metadata, f, indent=2)
214
+
215
+
216
+ def main() -> None:
217
+ """CLI entry point."""
218
+ parser = argparse.ArgumentParser(description="Train {{PROJECT_NAME}} model")
219
+ parser.add_argument(
220
+ "--config",
221
+ default="config.yaml",
222
+ help="Path to config override (default: config.yaml)",
223
+ )
224
+ parser.add_argument(
225
+ "--output-dir",
226
+ default=None,
227
+ help="Model output directory (default: from config)",
228
+ )
229
+ parser.add_argument(
230
+ "--seed",
231
+ type=int,
232
+ default=None,
233
+ help="Override random_state for multi-run statistical comparison",
234
+ )
235
+ args = parser.parse_args()
236
+ train_model(config_path=args.config, output_dir=args.output_dir, seed_override=args.seed)
237
+
238
+
239
+ if __name__ == "__main__":
240
+ main()