@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,318 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Merge Trajectories from Multiple Workers
4
+
5
+ Combines trajectory files from parallel generation workers into a single
6
+ output directory, handling deduplication and validation.
7
+
8
+ Usage:
9
+ python scripts/merge_trajectories.py ./training-data-output
10
+ python scripts/merge_trajectories.py ./training-data-output --output ./merged
11
+ python scripts/merge_trajectories.py ./training-data-output --validate
12
+
13
+ Requirements:
14
+ - Trajectory JSON files from generate_dataset.sh
15
+ """
16
+
17
+ import argparse
18
+ import hashlib
19
+ import json
20
+ import logging
21
+ import os
22
+ import shutil
23
+ import sys
24
+ from dataclasses import dataclass
25
+ from datetime import datetime
26
+ from pathlib import Path
27
+ from typing import Dict, List, Set
28
+
29
# Configure root logging at INFO with timestamped records; ``main`` raises the
# verbosity to DEBUG when --verbose is given.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
34
+
35
+
36
@dataclass
class MergeStats:
    """Counters gathered while merging trajectory files.

    ``archetypes`` is a per-instance tally of archetype name -> count,
    incremented through ``record_archetype``. It starts as an empty dict
    when not supplied.
    """
    total_files_found: int = 0
    valid_trajectories: int = 0
    duplicate_trajectories: int = 0
    invalid_trajectories: int = 0
    merged_trajectories: int = 0
    archetypes: Dict[str, int] = None

    def __post_init__(self):
        # Dataclasses reject mutable defaults such as ``{}``, so each
        # instance gets its own dict here unless the caller passed one.
        self.archetypes = {} if self.archetypes is None else self.archetypes

    def record_archetype(self, archetype: str):
        """Add one to the tally for *archetype*."""
        tally = self.archetypes
        tally[archetype] = tally.setdefault(archetype, 0) + 1
52
+
53
+
54
def generate_content_hash(data: Dict) -> str:
    """Return a deduplication key for a trajectory record.

    Records in the wrapped format (``{"trajectory": {...}, ...}``) are looked
    at through their inner ``trajectory`` object, the same way
    ``validate_trajectory`` and ``extract_archetype`` read them. This lets the
    same trajectory be recognised even when wrapper metadata differs.

    If that record has a non-empty ``trajectoryId``, the id is returned as a
    string. Otherwise the key is the first 16 hex characters of the SHA-256
    of the whole input serialised as canonical JSON (sorted keys). A null or
    empty id therefore no longer makes unrelated records look identical.
    """
    inner = data.get("trajectory") if isinstance(data, dict) else None
    record = inner if isinstance(inner, dict) else data

    if isinstance(record, dict):
        trajectory_id = record.get("trajectoryId")
        if trajectory_id:
            return str(trajectory_id)

    # No usable id: fall back to hashing the full content.
    content = json.dumps(data, sort_keys=True)
    return hashlib.sha256(content.encode()).hexdigest()[:16]
63
+
64
+
65
def validate_trajectory(data: Dict) -> tuple[bool, List[str]]:
    """
    Validate trajectory data structure.

    Accepts either a bare trajectory object or the wrapped format
    ``{"trajectory": {...}}``. A trajectory is valid when it has non-empty
    ``trajectoryId`` and ``agentId`` fields and a non-empty list of steps in
    ``stepsJson``. ``stepsJson`` may be a JSON-encoded string or an
    already-decoded list; missing or null is treated as no steps.

    Malformed shapes (non-object input, non-object wrapper, non-list steps)
    are reported as issues instead of raising.

    Returns:
        (is_valid, list_of_issues)
    """
    if not isinstance(data, dict):
        return False, ["Trajectory is not a JSON object"]

    # Handle wrapped format
    if "trajectory" in data:
        data = data["trajectory"]
        if not isinstance(data, dict):
            return False, ["Wrapped 'trajectory' is not a JSON object"]

    issues: List[str] = []

    # Check required fields
    for field in ("trajectoryId", "agentId"):
        if not data.get(field):
            issues.append(f"Missing field: {field}")

    # Check steps
    steps = data.get("stepsJson")
    if steps is None:
        steps = []
    elif isinstance(steps, str):
        try:
            steps = json.loads(steps)
        except json.JSONDecodeError:
            issues.append("Invalid stepsJson")
            steps = []

    # Downstream consumers iterate steps as a list of objects, so any other
    # decoded shape (dict, number, ...) is rejected here.
    if not isinstance(steps, list):
        issues.append("stepsJson is not a list")
    elif not steps:
        issues.append("No steps in trajectory")

    return not issues, issues
97
+
98
+
99
def extract_archetype(data: Dict) -> str:
    """Return the archetype label for a trajectory record, or ``"default"``.

    Unwraps ``{"trajectory": {...}}`` first. The record's own ``archetype``
    field wins unless it is empty or the literal ``"default"``. Otherwise the
    first truthy ``action.parameters.archetype`` found in ``stepsJson`` is
    used; ``stepsJson`` may be a JSON string or an already-decoded list.

    Malformed pieces are skipped rather than raising, so one odd record cannot
    abort a merge. This covers a non-object record, undecodable or non-list
    steps, and steps, actions or parameters that are null or not objects.
    """
    if isinstance(data, dict) and "trajectory" in data:
        data = data["trajectory"]
    if not isinstance(data, dict):
        return "default"

    archetype = data.get("archetype")
    if archetype and archetype != "default":
        return archetype

    # Try to extract from steps
    steps = data.get("stepsJson", "[]")
    if isinstance(steps, str):
        try:
            steps = json.loads(steps)
        except json.JSONDecodeError:
            return "default"
    if not isinstance(steps, list):
        return "default"

    for step in steps:
        action = step.get("action") if isinstance(step, dict) else None
        params = action.get("parameters") if isinstance(action, dict) else None
        if isinstance(params, dict) and params.get("archetype"):
            return params["archetype"]

    return "default"
123
+
124
+
125
def find_trajectory_files(source_dir: Path) -> List[Path]:
    """Collect trajectory ``*.json`` files under *source_dir*, sorted by path.

    Layouts are tried in priority order; the first one present is used:
      1. ``source_dir/batch_*/trajectories/*.json`` (parallel worker output)
      2. ``source_dir/trajectories/*.json``
      3. ``source_dir/*.json``
    """
    worker_dirs = list(source_dir.glob("batch_*/trajectories"))
    if worker_dirs:
        return sorted(path for worker in worker_dirs for path in worker.glob("*.json"))

    single_dir = source_dir / "trajectories"
    search_root = single_dir if single_dir.exists() else source_dir
    return sorted(search_root.glob("*.json"))
145
+
146
+
147
def merge_trajectories(
    source_dir: Path,
    output_dir: Path,
    validate: bool = True,
    dry_run: bool = False,
) -> MergeStats:
    """
    Merge trajectories from multiple workers.

    Files are discovered with ``find_trajectory_files`` (sorted by path), then:
      * files that cannot be read or decoded, or whose top-level JSON is not
        an object, are counted as invalid and skipped;
      * if ``validate`` is set, files failing ``validate_trajectory`` are
        counted as invalid and skipped;
      * duplicates, keyed by ``generate_content_hash``, are skipped. The first
        occurrence in path order wins;
      * the rest are copied into ``output_dir / "trajectories"``. A name clash
        is resolved by appending ``_1``, ``_2``, ... to the file stem.

    Args:
        source_dir: Directory containing batch_N subdirectories
        output_dir: Output directory for merged trajectories
        validate: Whether to validate trajectories before merging
        dry_run: If True, don't actually copy files (stats are still computed)

    Returns:
        Merge statistics
    """
    stats = MergeStats()
    seen_hashes: Set[str] = set()

    # Find all trajectory files
    files = find_trajectory_files(source_dir)
    stats.total_files_found = len(files)

    if stats.total_files_found == 0:
        logger.warning(f"No trajectory files found in {source_dir}")
        return stats

    logger.info(f"Found {stats.total_files_found} trajectory files")

    # Create output directory
    output_traj_dir = output_dir / "trajectories"
    if not dry_run:
        output_traj_dir.mkdir(parents=True, exist_ok=True)

    for file_path in files:
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON in {file_path}: {e}")
            stats.invalid_trajectories += 1
            continue
        except (OSError, UnicodeDecodeError) as e:
            # Unreadable or non-UTF-8 file: count it instead of aborting the
            # whole merge part-way through.
            logger.warning(f"Could not read {file_path}: {e}")
            stats.invalid_trajectories += 1
            continue

        # The helpers below read fields with dict access; a top-level list or
        # scalar would make them raise, even with validation disabled.
        if not isinstance(data, dict):
            logger.warning(f"Skipping {file_path}: top-level JSON is not an object")
            stats.invalid_trajectories += 1
            continue

        # Validate if requested
        if validate:
            is_valid, issues = validate_trajectory(data)
            if not is_valid:
                logger.debug(f"Invalid trajectory {file_path}: {issues}")
                stats.invalid_trajectories += 1
                continue

        stats.valid_trajectories += 1

        # Check for duplicates
        content_hash = generate_content_hash(data)
        if content_hash in seen_hashes:
            stats.duplicate_trajectories += 1
            continue
        seen_hashes.add(content_hash)

        stats.record_archetype(extract_archetype(data))

        if not dry_run:
            # Different batches can reuse file names; pick the first free one.
            output_file = output_traj_dir / file_path.name
            counter = 1
            while output_file.exists():
                output_file = output_traj_dir / f"{file_path.stem}_{counter}{file_path.suffix}"
                counter += 1
            shutil.copy2(file_path, output_file)

        stats.merged_trajectories += 1

    return stats
232
+
233
+
234
def main():
    """Command-line entry point for merging worker trajectory outputs.

    Parses arguments, runs ``merge_trajectories`` and prints a summary.
    Validation is opt-in on the command line (``--validate``).

    Exits with status 1 when the source path is not an existing directory, or
    when more than 10% of the discovered files were counted as invalid.
    Otherwise exits with status 0.
    """
    parser = argparse.ArgumentParser(
        description="Merge trajectories from multiple generation workers"
    )
    parser.add_argument(
        "source_dir",
        type=Path,
        help="Source directory containing batch_N subdirectories",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=None,
        help="Output directory (default: source_dir/merged)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate trajectories before merging",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be done without copying files",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Verbose output",
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # is_dir() rather than exists(): a regular file passed here would
    # otherwise be globbed as a directory, match nothing, and exit 0 silently.
    if not args.source_dir.is_dir():
        logger.error(f"Source directory not found: {args.source_dir}")
        sys.exit(1)

    output_dir = args.output or (args.source_dir / "merged")

    if args.dry_run:
        logger.info("DRY RUN MODE - No files will be copied")

    logger.info(f"Merging trajectories from {args.source_dir}")
    logger.info(f"Output directory: {output_dir}")

    stats = merge_trajectories(
        source_dir=args.source_dir,
        output_dir=output_dir,
        validate=args.validate,
        dry_run=args.dry_run,
    )

    # Print summary
    print("\n" + "=" * 50)
    print("MERGE SUMMARY")
    print("=" * 50)
    print(f"Total files found: {stats.total_files_found}")
    print(f"Valid trajectories: {stats.valid_trajectories}")
    print(f"Invalid trajectories: {stats.invalid_trajectories}")
    print(f"Duplicate trajectories: {stats.duplicate_trajectories}")
    print(f"Merged trajectories: {stats.merged_trajectories}")

    if stats.archetypes:
        print("\nArchetype distribution:")
        for archetype, count in sorted(stats.archetypes.items()):
            pct = (count / stats.merged_trajectories * 100) if stats.merged_trajectories > 0 else 0
            print(f" {archetype}: {count} ({pct:.1f}%)")

    if not args.dry_run:
        print(f"\nMerged trajectories saved to: {output_dir / 'trajectories'}")

    if stats.invalid_trajectories > stats.total_files_found * 0.1:
        logger.warning("More than 10% of trajectories are invalid")
        sys.exit(1)

    sys.exit(0)
313
+
314
+
315
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
317
+
318
+
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ A/B Test Runner for Babylon Agent Training
4
+
5
+ Compares a trained model against a baseline model using standardized evaluation scenarios.
6
+
7
+ Usage:
8
+ # Compare trained model against baseline
9
+ python scripts/run_ab_test.py \
10
+ --model-a Qwen/Qwen3-30B \
11
+ --model-b ./trained_models/final_model \
12
+ --archetypes trader degen
13
+
14
+ # Quick test with fewer scenarios
15
+ python scripts/run_ab_test.py \
16
+ --model-a Qwen/Qwen2.5-0.5B-Instruct \
17
+ --model-b ./trained_models/final_model \
18
+ --num-runs 1
19
+
20
+ # Full test suite
21
+ python scripts/run_ab_test.py \
22
+ --model-a Qwen/Qwen3-30B \
23
+ --model-b ./trained_models/final_model \
24
+ --num-runs 5 \
25
+ --output-dir ./ab_results
26
+ """
27
+
28
+ import argparse
29
+ import asyncio
30
+ import logging
31
+ import sys
32
+ from pathlib import Path
33
+
34
+ # Add src to path for imports
35
+ sys.path.insert(0, str(Path(__file__).parent.parent))
36
+
37
+ from src.training.ab_testing import ABTestRunner, EVAL_SCENARIOS
38
+
39
# Root logging at INFO with logger names in each record; ``main`` raises the
# verbosity to DEBUG when --verbose is given.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
44
+
45
+
46
async def main():
    """Parse CLI arguments, run the A/B comparison, and return an exit code.

    Returns:
        0 when model B (trained) wins more often than model A, or the two tie;
        1 when model A (baseline) wins more often.
    """
    parser = argparse.ArgumentParser(
        description="Run A/B test comparing two models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    parser.add_argument(
        "--model-a",
        required=True,
        help="Baseline model (path or HuggingFace name)",
    )
    parser.add_argument(
        "--model-b",
        required=True,
        help="Trained model to compare (path or HuggingFace name)",
    )
    parser.add_argument(
        "--archetypes",
        nargs="+",
        choices=list(EVAL_SCENARIOS.keys()),
        help="Archetypes to test (default: all)",
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=3,
        help="Number of runs per scenario for statistical significance (default: 3)",
    )
    parser.add_argument(
        "--vllm-url",
        default="http://localhost:9001/v1",
        help="vLLM server URL (default: http://localhost:9001/v1)",
    )
    parser.add_argument(
        "--output-dir",
        default="./ab_test_results",
        help="Directory to save results (default: ./ab_test_results)",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    # Zero or negative run counts leave nothing to compare. Reject them up
    # front as a normal argparse usage error (exit status 2).
    if args.num_runs < 1:
        parser.error("--num-runs must be at least 1")

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Filter scenarios by archetype if specified
    scenarios = EVAL_SCENARIOS
    if args.archetypes:
        scenarios = {k: v for k, v in EVAL_SCENARIOS.items() if k in args.archetypes}
        logger.info(f"Testing archetypes: {', '.join(args.archetypes)}")
    else:
        logger.info(f"Testing all archetypes: {', '.join(EVAL_SCENARIOS.keys())}")

    logger.info(f"Model A (baseline): {args.model_a}")
    logger.info(f"Model B (trained): {args.model_b}")
    logger.info(f"Runs per scenario: {args.num_runs}")

    runner = ABTestRunner(
        model_a=args.model_a,
        model_b=args.model_b,
        scenarios=scenarios,
        vllm_url=args.vllm_url,
        num_runs_per_scenario=args.num_runs,
        output_dir=args.output_dir,
    )

    logger.info("Starting A/B test...")
    result = await runner.run()

    print()
    print(result.summary())

    # Exit code: non-zero only when the baseline beats the trained model.
    if result.model_b_wins > result.model_a_wins:
        logger.info("Trained model outperforms baseline!")
        return 0
    elif result.model_a_wins > result.model_b_wins:
        logger.warning("Baseline model outperforms trained model")
        return 1
    else:
        logger.info("Models performed equally")
        return 0
136
+
137
+
138
if __name__ == "__main__":
    # Drive the async entry point and propagate its return value as the
    # process exit status.
    sys.exit(asyncio.run(main()))
141
+
142
+
143
+