claude-turing 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. package/.claude-plugin/plugin.json +34 -0
  2. package/LICENSE +21 -0
  3. package/README.md +457 -0
  4. package/agents/ml-evaluator.md +43 -0
  5. package/agents/ml-researcher.md +74 -0
  6. package/bin/cli.js +46 -0
  7. package/bin/turing-init.sh +57 -0
  8. package/commands/brief.md +83 -0
  9. package/commands/compare.md +24 -0
  10. package/commands/design.md +97 -0
  11. package/commands/init.md +123 -0
  12. package/commands/logbook.md +51 -0
  13. package/commands/mode.md +43 -0
  14. package/commands/poster.md +89 -0
  15. package/commands/preflight.md +75 -0
  16. package/commands/report.md +97 -0
  17. package/commands/rules/loop-protocol.md +91 -0
  18. package/commands/status.md +24 -0
  19. package/commands/suggest.md +95 -0
  20. package/commands/sweep.md +45 -0
  21. package/commands/train.md +66 -0
  22. package/commands/try.md +63 -0
  23. package/commands/turing.md +54 -0
  24. package/commands/validate.md +34 -0
  25. package/config/defaults.yaml +45 -0
  26. package/config/experiment_archetypes.yaml +127 -0
  27. package/config/lifecycle.toml +31 -0
  28. package/config/novelty_aliases.yaml +107 -0
  29. package/config/relationships.toml +125 -0
  30. package/config/state.toml +24 -0
  31. package/config/task_taxonomy.yaml +110 -0
  32. package/config/taxonomy.toml +37 -0
  33. package/package.json +54 -0
  34. package/src/claude-md.js +55 -0
  35. package/src/install.js +107 -0
  36. package/src/paths.js +20 -0
  37. package/src/postinstall.js +22 -0
  38. package/src/verify.js +109 -0
  39. package/templates/MEMORY.md +36 -0
  40. package/templates/README.md +93 -0
  41. package/templates/__pycache__/evaluate.cpython-314.pyc +0 -0
  42. package/templates/__pycache__/prepare.cpython-314.pyc +0 -0
  43. package/templates/config.yaml +48 -0
  44. package/templates/evaluate.py +237 -0
  45. package/templates/features/__init__.py +0 -0
  46. package/templates/features/__pycache__/__init__.cpython-314.pyc +0 -0
  47. package/templates/features/__pycache__/featurizers.cpython-314.pyc +0 -0
  48. package/templates/features/featurizers.py +138 -0
  49. package/templates/prepare.py +171 -0
  50. package/templates/program.md +216 -0
  51. package/templates/pyproject.toml +8 -0
  52. package/templates/requirements.txt +8 -0
  53. package/templates/scripts/__init__.py +0 -0
  54. package/templates/scripts/__pycache__/__init__.cpython-314.pyc +0 -0
  55. package/templates/scripts/__pycache__/check_convergence.cpython-314.pyc +0 -0
  56. package/templates/scripts/__pycache__/classify_task.cpython-314.pyc +0 -0
  57. package/templates/scripts/__pycache__/critique_hypothesis.cpython-314.pyc +0 -0
  58. package/templates/scripts/__pycache__/experiment_index.cpython-314.pyc +0 -0
  59. package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
  60. package/templates/scripts/__pycache__/generate_logbook.cpython-314.pyc +0 -0
  61. package/templates/scripts/__pycache__/log_experiment.cpython-314.pyc +0 -0
  62. package/templates/scripts/__pycache__/manage_hypotheses.cpython-314.pyc +0 -0
  63. package/templates/scripts/__pycache__/novelty_guard.cpython-314.pyc +0 -0
  64. package/templates/scripts/__pycache__/parse_metrics.cpython-314.pyc +0 -0
  65. package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
  66. package/templates/scripts/__pycache__/show_experiment_tree.cpython-314.pyc +0 -0
  67. package/templates/scripts/__pycache__/show_families.cpython-314.pyc +0 -0
  68. package/templates/scripts/__pycache__/statistical_compare.cpython-314.pyc +0 -0
  69. package/templates/scripts/__pycache__/suggest_next.cpython-314.pyc +0 -0
  70. package/templates/scripts/__pycache__/sweep.cpython-314.pyc +0 -0
  71. package/templates/scripts/__pycache__/synthesize_decision.cpython-314.pyc +0 -0
  72. package/templates/scripts/__pycache__/turing_io.cpython-314.pyc +0 -0
  73. package/templates/scripts/__pycache__/update_state.cpython-314.pyc +0 -0
  74. package/templates/scripts/__pycache__/verify_placeholders.cpython-314.pyc +0 -0
  75. package/templates/scripts/check_convergence.py +230 -0
  76. package/templates/scripts/compare_runs.py +124 -0
  77. package/templates/scripts/critique_hypothesis.py +350 -0
  78. package/templates/scripts/experiment_index.py +288 -0
  79. package/templates/scripts/generate_brief.py +389 -0
  80. package/templates/scripts/generate_logbook.py +423 -0
  81. package/templates/scripts/log_experiment.py +243 -0
  82. package/templates/scripts/manage_hypotheses.py +543 -0
  83. package/templates/scripts/novelty_guard.py +343 -0
  84. package/templates/scripts/parse_metrics.py +139 -0
  85. package/templates/scripts/post-train-hook.sh +74 -0
  86. package/templates/scripts/preflight.py +549 -0
  87. package/templates/scripts/scaffold.py +409 -0
  88. package/templates/scripts/show_environment.py +92 -0
  89. package/templates/scripts/show_experiment_tree.py +144 -0
  90. package/templates/scripts/show_families.py +133 -0
  91. package/templates/scripts/show_metrics.py +157 -0
  92. package/templates/scripts/statistical_compare.py +259 -0
  93. package/templates/scripts/stop-hook.sh +34 -0
  94. package/templates/scripts/suggest_next.py +301 -0
  95. package/templates/scripts/sweep.py +276 -0
  96. package/templates/scripts/synthesize_decision.py +300 -0
  97. package/templates/scripts/turing_io.py +76 -0
  98. package/templates/scripts/update_state.py +296 -0
  99. package/templates/scripts/validate_stability.py +167 -0
  100. package/templates/scripts/verify_placeholders.py +119 -0
  101. package/templates/sweep_config.yaml +14 -0
  102. package/templates/tests/__init__.py +0 -0
  103. package/templates/tests/conftest.py +91 -0
  104. package/templates/train.py +240 -0
package/templates/evaluate.py
@@ -0,0 +1,237 @@
+"""Evaluation harness for the {{PROJECT_NAME}} ML pipeline.
+
+HIDDEN — MEASUREMENT APPARATUS.
+
+This file is hidden from the autoresearch agent. The agent cannot
+read, modify, or reference this file. This prevents metric gaming,
+seed exploitation, and evaluation function reverse-engineering.
+
+The platform runs this file automatically. The agent knows only:
+- The primary metric name (from config.yaml)
+- Whether higher or lower is better (from config.yaml)
+- The metric value (from parsed run.log output)
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+
+import numpy as np
+from sklearn.metrics import (
+    accuracy_score,
+    f1_score,
+    mean_absolute_error,
+    mean_squared_error,
+    precision_score,
+    recall_score,
+    roc_auc_score,
+)
+
+
+def evaluate_model(
+    predictions: np.ndarray,
+    ground_truth: np.ndarray,
+    config: dict | None = None,
+) -> dict:
+    """Compute evaluation metrics from predictions vs ground truth.
+
+    Customize this function for your specific ML task. Add or remove
+    metrics as needed. The autoresearch agent reads these via format_metrics().
+
+    Args:
+        predictions: Model predictions (numpy array).
+        ground_truth: Ground truth values (numpy array).
+        config: Optional config dict for metric selection.
+
+    Returns:
+        Dict with metric names and values.
+    """
+    if len(predictions) != len(ground_truth):
+        raise ValueError(
+            f"Length mismatch: {len(predictions)} predictions vs "
+            f"{len(ground_truth)} ground truth"
+        )
+
+    # Determine which metrics to compute from config
+    eval_cfg = config.get("evaluation", {}) if config else {}
+    metric_names = eval_cfg.get("metrics", ["accuracy", "f1_weighted"])
+
+    results = {}
+
+    for metric_name in metric_names:
+        if metric_name == "accuracy":
+            results["accuracy"] = round(float(accuracy_score(ground_truth, predictions)), 4)
+        elif metric_name == "f1_weighted":
+            results["f1_weighted"] = round(float(f1_score(ground_truth, predictions, average="weighted")), 4)
+        elif metric_name == "f1_macro":
+            results["f1_macro"] = round(float(f1_score(ground_truth, predictions, average="macro")), 4)
+        elif metric_name == "f1_micro":
+            results["f1_micro"] = round(float(f1_score(ground_truth, predictions, average="micro")), 4)
+        elif metric_name == "precision":
+            results["precision"] = round(float(precision_score(ground_truth, predictions, average="weighted")), 4)
+        elif metric_name == "recall":
+            results["recall"] = round(float(recall_score(ground_truth, predictions, average="weighted")), 4)
+        elif metric_name == "mae":
+            results["mae"] = round(float(mean_absolute_error(ground_truth, predictions)), 4)
+        elif metric_name == "mse":
+            results["mse"] = round(float(mean_squared_error(ground_truth, predictions)), 4)
+        elif metric_name == "rmse":
+            results["rmse"] = round(float(np.sqrt(mean_squared_error(ground_truth, predictions))), 4)
+
+    return results
+
+
+def evaluate_detailed(
+    predictions: np.ndarray,
+    ground_truth: np.ndarray,
+    config: dict | None = None,
+) -> dict:
+    """Compute detailed evaluation metrics including per-class breakdown.
+
+    Extends evaluate_model with per-class precision/recall/F1 and a
+    confusion matrix. The agent uses this to understand WHERE the model
+    fails, not just that it fails.
+
+    Args:
+        predictions: Model predictions (numpy array).
+        ground_truth: Ground truth values (numpy array).
+        config: Optional config dict for metric selection.
+
+    Returns:
+        Dict with 'aggregate' (same as evaluate_model), 'per_class'
+        (dict of class -> {precision, recall, f1, support}), and
+        'confusion_matrix' (dict representation).
+    """
+    aggregate = evaluate_model(predictions, ground_truth, config)
+
+    # Per-class breakdown
+    classes = sorted(set(ground_truth.tolist()))
+    per_class = {}
+
+    for cls in classes:
+        cls_mask = ground_truth == cls
+        n_support = int(cls_mask.sum())
+        cls_preds = predictions[cls_mask]
+
+        tp = int((cls_preds == cls).sum())
+        precision_denom = int((predictions == cls).sum())
+        cls_precision = round(tp / precision_denom, 4) if precision_denom > 0 else 0.0
+        cls_recall = round(tp / n_support, 4) if n_support > 0 else 0.0
+
+        if cls_precision + cls_recall > 0:
+            cls_f1 = round(2 * cls_precision * cls_recall / (cls_precision + cls_recall), 4)
+        else:
+            cls_f1 = 0.0
+
+        per_class[str(cls)] = {
+            "precision": cls_precision,
+            "recall": cls_recall,
+            "f1": cls_f1,
+            "support": n_support,
+        }
+
+    # Confusion matrix as dict
+    confusion = {}
+    for true_cls in classes:
+        row = {}
+        for pred_cls in classes:
+            row[str(pred_cls)] = int(((ground_truth == true_cls) & (predictions == pred_cls)).sum())
+        confusion[str(true_cls)] = row
+
+    return {
+        "aggregate": aggregate,
+        "per_class": per_class,
+        "confusion_matrix": confusion,
+    }
+
+
+def format_metrics(metrics: dict) -> str:
+    """Format metrics in a parseable delimited format.
+
+    Output format (for the autoresearch agent to parse):
+        ---
+        metric_name: value
+        ...
+        ---
+
+    The agent reads metrics by grepping between --- delimiters.
+
+    Args:
+        metrics: Dict with metric names and values.
+
+    Returns:
+        Formatted string.
+    """
+    # Separate known metadata keys from actual metrics
+    metadata_keys = {"model_type", "train_seconds"}
+    metric_keys = [k for k in metrics if k not in metadata_keys]
+    all_keys = metric_keys + [k for k in metadata_keys if k in metrics]
+
+    lines = ["---"]
+    for key in all_keys:
+        padding = " " * max(1, 15 - len(key))
+        lines.append(f"{key}:{padding}{metrics[key]}")
+    lines.append("---")
+    return "\n".join(lines)
+
+
+def validate_training_behavior(config: dict | None = None) -> tuple[bool, str]:
+    """Validate that real training work was performed.
+
+    HIDDEN — this function runs automatically. The agent cannot see it.
+
+    Checks:
+    1. Training took a minimum amount of time (prevents task avoidance)
+    2. Model artifact has non-trivial size (prevents empty model saves)
+    3. Predictions have diversity (prevents constant-prediction shortcuts)
+
+    Returns:
+        (passed, message) tuple.
+    """
+    constraints = (config or {}).get("constraints", {})
+    min_train_time = constraints.get("min_train_time", 5)
+    min_model_size = constraints.get("min_model_size_bytes", 100)
+
+    meta_path = Path("train_metadata.json")
+    if not meta_path.exists():
+        return True, "no metadata — skipping behavioral checks"
+
+    with open(meta_path) as f:
+        meta = json.load(f)
+
+    train_time = meta.get("train_time_sec", 0)
+    if train_time < min_train_time:
+        return False, f"PROBE FAIL: train_time={train_time:.1f}s < minimum {min_train_time}s — training may have been skipped"
+
+    model_size = meta.get("model_size_bytes", 0)
+    if model_size < min_model_size:
+        return False, f"PROBE FAIL: model_size={model_size} bytes < minimum {min_model_size} — model may be empty"
+
+    pred_unique = meta.get("predictions_unique", 0)
+    if pred_unique <= 1:
+        return False, f"PROBE FAIL: predictions_unique={pred_unique} — model may predict a constant value"
+
+    return True, f"behavioral_ok: train_time={train_time:.1f}s, model_size={model_size}, pred_diversity={pred_unique}"
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Evaluate {{PROJECT_NAME}} model predictions"
+    )
+    parser.add_argument("predictions", help="Path to predictions JSONL file")
+    parser.add_argument("ground_truth", help="Path to ground truth JSONL file")
+    args = parser.parse_args()
+
+    with open(args.predictions) as f:
+        preds = [json.loads(line) for line in f if line.strip()]
+    with open(args.ground_truth) as f:
+        truth = [json.loads(line) for line in f if line.strip()]
+
+    pred_values = np.array([p.get("prediction") for p in preds])
+    truth_values = np.array([t.get("label") for t in truth])
+
+    result = evaluate_model(pred_values, truth_values)
+    print(format_metrics(result))
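
For reference, the delimited block that format_metrics() emits can be read back with a few lines of standard-library Python. The sketch below is illustrative only and is not part of the package (the package's own parser lives in templates/scripts/parse_metrics.py); the helper name and the run.log path are assumptions.

    # Illustrative only: parse the "---" delimited metrics block emitted by format_metrics().
    # Helper name and log path are hypothetical, not shipped with claude-turing.
    import re

    def parse_metrics_block(text: str) -> dict[str, float]:
        """Return {metric_name: value} from the last ----delimited block in text."""
        blocks = re.findall(r"^---$(.*?)^---$", text, flags=re.MULTILINE | re.DOTALL)
        metrics: dict[str, float] = {}
        if not blocks:
            return metrics
        for line in blocks[-1].strip().splitlines():
            key, _, value = line.partition(":")
            try:
                metrics[key.strip()] = float(value.strip())
            except ValueError:
                pass  # skip non-numeric metadata entries such as model_type
        return metrics

    # Example: parse_metrics_block(open("run.log").read()) -> {"accuracy": 0.8123, ...}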
package/templates/features/featurizers.py
@@ -0,0 +1,138 @@
+"""Feature engineering strategies for the {{PROJECT_NAME}} ML pipeline.
+
+READ-ONLY — INFRASTRUCTURE.
+
+The autoresearch agent does not modify this file directly. Instead, it
+modifies how train.py *uses* the featurizers — composing them differently,
+selecting different column subsets, or adding preprocessing in train.py.
+
+Provides pluggable featurizers following a scikit-learn-like fit/transform
+interface. The CompositeFeaturizer chains multiple featurizers, concatenating
+their output columns.
+
+Exports:
+- BaseFeaturizer: Abstract base class.
+- NumericFeaturizer: Passes through numeric columns.
+- CategoricalFeaturizer: One-hot encodes categorical columns.
+- CompositeFeaturizer: Chains multiple featurizers.
+- get_default_featurizer: Returns the standard composite featurizer.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import pandas as pd
+
+
+class BaseFeaturizer(ABC):
+    """Abstract base class for feature extraction."""
+
+    @abstractmethod
+    def fit(self, df: pd.DataFrame) -> "BaseFeaturizer":
+        """Fit the featurizer to training data.
+
+        Args:
+            df: Training DataFrame.
+
+        Returns:
+            self
+        """
+
+    @abstractmethod
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Transform data into feature DataFrame.
+
+        Args:
+            df: Input DataFrame.
+
+        Returns:
+            DataFrame with extracted features (numeric columns only).
+        """
+
+    def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Fit and transform in one step."""
+        return self.fit(df).transform(df)
+
+
+class NumericFeaturizer(BaseFeaturizer):
+    """Passes through numeric columns as features.
+
+    Customize the column list for your dataset.
+    """
+
+    def __init__(self, columns: list[str] | None = None) -> None:
+        self.columns = columns
+        self._fitted_columns: list[str] = []
+
+    def fit(self, df: pd.DataFrame) -> "NumericFeaturizer":
+        if self.columns:
+            self._fitted_columns = [c for c in self.columns if c in df.columns]
+        else:
+            self._fitted_columns = df.select_dtypes(include=["number"]).columns.tolist()
+        return self
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        return df[self._fitted_columns].copy()
+
+
+class CategoricalFeaturizer(BaseFeaturizer):
+    """One-hot encodes categorical columns.
+
+    Customize the column list for your dataset.
+    """
+
+    def __init__(self, columns: list[str] | None = None) -> None:
+        self.columns = columns
+        self._categories: dict[str, list[str]] = {}
+
+    def fit(self, df: pd.DataFrame) -> "CategoricalFeaturizer":
+        cols = self.columns or df.select_dtypes(include=["object", "category"]).columns.tolist()
+        for col in cols:
+            if col in df.columns:
+                self._categories[col] = sorted(df[col].unique().tolist())
+        return self
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        features = pd.DataFrame(index=df.index)
+        for col, categories in self._categories.items():
+            if col in df.columns:
+                for cat in categories:
+                    features[f"{col}_{cat}"] = (df[col] == cat).astype(int)
+        return features
+
+
+class CompositeFeaturizer(BaseFeaturizer):
+    """Chains multiple featurizers, concatenating their output columns."""
+
+    def __init__(self, featurizers: list[BaseFeaturizer]) -> None:
+        self.featurizers = featurizers
+
+    def fit(self, df: pd.DataFrame) -> "CompositeFeaturizer":
+        for f in self.featurizers:
+            f.fit(df)
+        return self
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        parts = [f.transform(df) for f in self.featurizers]
+        return pd.concat(parts, axis=1)
+
+    def __repr__(self) -> str:
+        names = [type(f).__name__ for f in self.featurizers]
+        return f"CompositeFeaturizer({names})"
+
+
+def get_default_featurizer() -> CompositeFeaturizer:
+    """Return the standard composite featurizer.
+
+    Customize this function for your dataset. Add or remove featurizers
+    to match your feature engineering needs.
+
+    The autoresearch agent calls this from train.py. To experiment with
+    different feature sets, the agent modifies how train.py calls this
+    function or composes featurizers differently.
+    """
+    return CompositeFeaturizer([
+        NumericFeaturizer(),
+        CategoricalFeaturizer(),
+    ])
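
To make the fit/transform contract concrete, here is a minimal usage sketch of train.py-style code restricting the featurizers to an explicit column subset. It is not part of the package: the toy DataFrame and column names are invented, and the import path assumes the scaffolded project exposes the module as features.featurizers.

    # Illustrative usage of the featurizers; DataFrame contents and column names are invented.
    import pandas as pd
    from features.featurizers import CategoricalFeaturizer, CompositeFeaturizer, NumericFeaturizer

    train_df = pd.DataFrame({
        "age": [34, 51, 27],            # numeric column, passed through as-is
        "country": ["de", "us", "de"],  # categorical column, one-hot encoded
        "label": [0, 1, 0],
    })

    # Compose explicitly instead of calling get_default_featurizer(), choosing columns by hand.
    featurizer = CompositeFeaturizer([
        NumericFeaturizer(columns=["age"]),
        CategoricalFeaturizer(columns=["country"]),
    ])
    X_train = featurizer.fit_transform(train_df)
    # X_train columns: age, country_de, country_us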
package/templates/prepare.py
@@ -0,0 +1,171 @@
+"""Data preparation module for the {{PROJECT_NAME}} ML pipeline.
+
+READ-ONLY — MEASUREMENT APPARATUS.
+
+This file is part of the immutable evaluation infrastructure. The autoresearch
+agent MUST NOT modify this file under any circumstances. Consistent data
+preparation across experiments ensures that observed metric changes reflect
+genuine model improvements, not data handling artifacts.
+
+Provides:
+- load_config: Load YAML experiment configuration.
+- load_data: Load training data into a DataFrame.
+- create_splits: Stratified train/val/test split.
+- load_splits: Load pre-created split files.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pandas as pd
+import yaml
+
+
+def load_config(path: str = "config.yaml") -> dict:
+    """Load YAML experiment configuration.
+
+    Args:
+        path: Path to the YAML config file.
+
+    Returns:
+        Configuration dictionary.
+    """
+    with open(path) as f:
+        return yaml.safe_load(f)
+
+
+def load_data(path: str) -> pd.DataFrame:
+    """Load training data into a DataFrame.
+
+    Supports JSONL (.jsonl) and CSV (.csv) formats.
+
+    Args:
+        path: Path to the data file.
+
+    Returns:
+        DataFrame with training data.
+
+    Raises:
+        FileNotFoundError: If path does not exist.
+        ValueError: If file format is unsupported.
+    """
+    p = Path(path)
+    if not p.exists():
+        raise FileNotFoundError(f"Data file not found: {path}")
+
+    if p.suffix == ".jsonl":
+        records = []
+        with open(path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    records.append(json.loads(line))
+        if not records:
+            return pd.DataFrame()
+        return pd.DataFrame(records)
+    elif p.suffix == ".csv":
+        return pd.read_csv(path)
+    else:
+        raise ValueError(
+            f"Unsupported file format: {p.suffix}. Use .jsonl or .csv"
+        )
+
+
+def create_splits(
+    data_path: str,
+    output_dir: str,
+    target_column: str = "label",
+    test_size: float = 0.15,
+    val_size: float = 0.15,
+    random_state: int = 42,
+) -> dict[str, Path]:
+    """Create stratified train/val/test splits from training data.
+
+    Stratifies by target_column to preserve label distribution.
+
+    Args:
+        data_path: Path to the source data file.
+        output_dir: Directory to write train.jsonl, val.jsonl, test.jsonl.
+        target_column: Column to stratify on.
+        test_size: Fraction of data for test set.
+        val_size: Fraction of data for validation set.
+        random_state: Random seed for reproducibility.
+
+    Returns:
+        Dict mapping split name to output file path.
+    """
+    from sklearn.model_selection import train_test_split
+
+    df = load_data(data_path)
+    if df.empty:
+        raise ValueError(f"No data found in {data_path}")
+
+    out = Path(output_dir)
+    out.mkdir(parents=True, exist_ok=True)
+
+    # First split: separate test set
+    stratify_col = df[target_column] if target_column in df.columns else None
+    train_val, test = train_test_split(
+        df,
+        test_size=test_size,
+        random_state=random_state,
+        stratify=stratify_col,
+    )
+
+    # Second split: separate val from train
+    val_relative = val_size / (1.0 - test_size)
+    stratify_col_tv = train_val[target_column] if target_column in train_val.columns else None
+    train, val = train_test_split(
+        train_val,
+        test_size=val_relative,
+        random_state=random_state,
+        stratify=stratify_col_tv,
+    )
+
+    paths = {}
+    for name, split_df in [("train", train), ("val", val), ("test", test)]:
+        path = out / f"{name}.jsonl"
+        with open(path, "w") as f:
+            for _, row in split_df.iterrows():
+                f.write(json.dumps(row.to_dict()) + "\n")
+        paths[name] = path
+
+    return paths
+
+
+def load_splits(splits_dir: str) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    """Load pre-created train/val/test splits.
+
+    Args:
+        splits_dir: Directory containing train.jsonl, val.jsonl, test.jsonl.
+
+    Returns:
+        Tuple of (train_df, val_df, test_df).
+
+    Raises:
+        FileNotFoundError: If any split file is missing.
+    """
+    splits_path = Path(splits_dir)
+    train = load_data(str(splits_path / "train.jsonl"))
+    val = load_data(str(splits_path / "val.jsonl"))
+    test = load_data(str(splits_path / "test.jsonl"))
+    return train, val, test
+
+
+if __name__ == "__main__":
+    config = load_config()
+    data_cfg = config["data"]
+    print(f"Creating splits from {data_cfg['source']}...")
+    paths = create_splits(
+        data_path=data_cfg["source"],
+        output_dir=data_cfg["splits_dir"],
+        target_column=data_cfg.get("target_column", "label"),
+        test_size=data_cfg["split_ratios"]["test"],
+        val_size=data_cfg["split_ratios"]["val"],
+        random_state=data_cfg["random_state"],
+    )
+    for name, path in paths.items():
+        print(f" {name}: {path}")
+    print("Done.")