@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Merge Trajectories from Multiple Workers
|
|
4
|
+
|
|
5
|
+
Combines trajectory files from parallel generation workers into a single
|
|
6
|
+
output directory, handling deduplication and validation.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/merge_trajectories.py ./training-data-output
|
|
10
|
+
python scripts/merge_trajectories.py ./training-data-output --output ./merged
|
|
11
|
+
python scripts/merge_trajectories.py ./training-data-output --validate
|
|
12
|
+
|
|
13
|
+
Requirements:
|
|
14
|
+
- Trajectory JSON files from generate_dataset.sh
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import argparse
import hashlib
import json
import logging
import os
import shutil
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
28
|
+
|
|
29
|
+
logging.basicConfig(
|
|
30
|
+
level=logging.INFO,
|
|
31
|
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
|
32
|
+
)
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class MergeStats:
    """Statistics accumulated during a merge operation.

    Counters map one-to-one to the lines of the merge summary printed by
    ``main``; ``archetypes`` maps archetype name -> merged trajectory count.
    """
    total_files_found: int = 0        # every JSON file discovered under source_dir
    valid_trajectories: int = 0       # parsed OK (and passed validation, if enabled)
    duplicate_trajectories: int = 0   # skipped: content hash already seen
    invalid_trajectories: int = 0     # bad JSON or failed structural validation
    merged_trajectories: int = 0      # trajectories kept in the merged output
    # field(default_factory=dict) is the idiomatic dataclass way to get a
    # fresh mutable default per instance (instead of a None sentinel that
    # __post_init__ has to patch up).
    archetypes: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Defensive: normalise an explicit ``archetypes=None`` argument so
        # callers relying on the previous None-sentinel API keep working.
        if self.archetypes is None:
            self.archetypes = {}

    def record_archetype(self, archetype: str) -> None:
        """Increment the counter for *archetype*."""
        self.archetypes[archetype] = self.archetypes.get(archetype, 0) + 1
|
53
|
+
|
|
54
|
+
def generate_content_hash(data: Dict) -> str:
    """Return a stable deduplication key for *data*.

    The explicit ``trajectoryId`` field wins when present; otherwise the
    key is the first 16 hex characters of a SHA-256 digest over the
    canonical (key-sorted) JSON serialisation of the record.
    """
    # EAFP: most records carry an explicit ID.
    try:
        return data["trajectoryId"]
    except KeyError:
        pass

    serialised = json.dumps(data, sort_keys=True)
    digest = hashlib.sha256(serialised.encode())
    return digest.hexdigest()[:16]
|
+
|
|
64
|
+
|
|
65
|
+
def validate_trajectory(data: Dict) -> tuple[bool, List[str]]:
    """Check that a trajectory record has the minimum required structure.

    Accepts either a bare trajectory dict or the wrapped form
    ``{"trajectory": {...}}``.

    Returns:
        ``(is_valid, issues)`` where ``issues`` lists every problem found.
    """
    problems: List[str] = []

    # Unwrap the {"trajectory": {...}} envelope if present.
    record = data["trajectory"] if "trajectory" in data else data

    # Required identifying fields must be present and truthy.
    for required_field in ("trajectoryId", "agentId"):
        if not record.get(required_field):
            problems.append(f"Missing field: {required_field}")

    # stepsJson may be an embedded JSON string or an already-parsed list.
    steps = record.get("stepsJson", "[]")
    if isinstance(steps, str):
        try:
            steps = json.loads(steps)
        except json.JSONDecodeError:
            problems.append("Invalid stepsJson")
            steps = []

    if len(steps) == 0:
        problems.append("No steps in trajectory")

    return len(problems) == 0, problems
97
|
+
|
|
98
|
+
|
|
99
|
+
def extract_archetype(data: Dict) -> str:
    """Determine the archetype label for a trajectory record.

    Resolution order: the top-level ``archetype`` field (when set and not
    the placeholder ``"default"``), then the first step whose action
    parameters carry an ``archetype``; falls back to ``"default"``.
    """
    record = data["trajectory"] if "trajectory" in data else data

    top_level = record.get("archetype")
    if top_level and top_level != "default":
        return top_level

    # stepsJson may be a JSON string or an already-parsed list.
    steps = record.get("stepsJson", "[]")
    if isinstance(steps, str):
        try:
            steps = json.loads(steps)
        except json.JSONDecodeError:
            return "default"

    # Fall back to the first per-step archetype parameter, if any.
    for step in steps:
        step_archetype = step.get("action", {}).get("parameters", {}).get("archetype")
        if step_archetype:
            return step_archetype

    return "default"
123
|
+
|
|
124
|
+
|
|
125
|
+
def find_trajectory_files(source_dir: Path) -> List[Path]:
    """Locate trajectory JSON files under *source_dir*.

    Search order (the first layout that matches wins):
      1. ``batch_N/trajectories/*.json`` worker output directories
      2. a single ``trajectories/`` subdirectory
      3. JSON files directly inside *source_dir*

    Returns:
        The matching paths, sorted for deterministic processing order.
    """
    batch_trajectory_dirs = list(source_dir.glob("batch_*/trajectories"))

    if batch_trajectory_dirs:
        found = [
            path
            for directory in batch_trajectory_dirs
            for path in directory.glob("*.json")
        ]
    else:
        nested = source_dir / "trajectories"
        search_root = nested if nested.exists() else source_dir
        found = list(search_root.glob("*.json"))

    return sorted(found)
|
|
146
|
+
|
|
147
|
+
def merge_trajectories(
    source_dir: Path,
    output_dir: Path,
    validate: bool = True,
    dry_run: bool = False,
) -> MergeStats:
    """
    Merge trajectory files produced by parallel workers into one directory.

    Args:
        source_dir: Directory containing batch_N subdirectories
        output_dir: Output directory for merged trajectories
        validate: Whether to validate trajectories before merging
        dry_run: If True, don't actually copy files

    Returns:
        Merge statistics
    """
    stats = MergeStats()
    known_hashes: Set[str] = set()

    candidates = find_trajectory_files(source_dir)
    stats.total_files_found = len(candidates)

    if not candidates:
        logger.warning(f"No trajectory files found in {source_dir}")
        return stats

    logger.info(f"Found {stats.total_files_found} trajectory files")

    destination = output_dir / "trajectories"
    if not dry_run:
        destination.mkdir(parents=True, exist_ok=True)

    for source_file in candidates:
        # Unreadable JSON counts as invalid and is skipped.
        try:
            with open(source_file, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except json.JSONDecodeError as exc:
            logger.warning(f"Invalid JSON in {source_file}: {exc}")
            stats.invalid_trajectories += 1
            continue

        # Optional structural validation.
        if validate:
            ok, problems = validate_trajectory(payload)
            if not ok:
                logger.debug(f"Invalid trajectory {source_file}: {problems}")
                stats.invalid_trajectories += 1
                continue

        stats.valid_trajectories += 1

        # Deduplicate by trajectory ID / content hash across all workers.
        fingerprint = generate_content_hash(payload)
        if fingerprint in known_hashes:
            stats.duplicate_trajectories += 1
            continue
        known_hashes.add(fingerprint)

        stats.record_archetype(extract_archetype(payload))

        if not dry_run:
            target = destination / source_file.name
            # Distinct trajectories from different batches may share a file
            # name; resolve collisions with an incrementing numeric suffix.
            collision_index = 1
            while target.exists():
                target = destination / f"{source_file.stem}_{collision_index}{source_file.suffix}"
                collision_index += 1
            shutil.copy2(source_file, target)

        stats.merged_trajectories += 1

    return stats
|
233
|
+
|
|
234
|
+
def main():
    """CLI entry point: parse arguments, merge, and report a summary.

    Exits non-zero when the source directory is missing or when more than
    10% of the discovered trajectory files are invalid.
    """
    parser = argparse.ArgumentParser(
        description="Merge trajectories from multiple generation workers"
    )
    parser.add_argument(
        "source_dir",
        type=Path,
        help="Source directory containing batch_N subdirectories",
    )
    parser.add_argument(
        "--output", "-o",
        type=Path,
        default=None,
        help="Output directory (default: source_dir/merged)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate trajectories before merging",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be done without copying files",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Verbose output",
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    if not args.source_dir.exists():
        logger.error(f"Source directory not found: {args.source_dir}")
        sys.exit(1)

    output_dir = args.output or (args.source_dir / "merged")

    if args.dry_run:
        logger.info("DRY RUN MODE - No files will be copied")

    logger.info(f"Merging trajectories from {args.source_dir}")
    logger.info(f"Output directory: {output_dir}")

    stats = merge_trajectories(
        source_dir=args.source_dir,
        output_dir=output_dir,
        validate=args.validate,
        dry_run=args.dry_run,
    )

    # Human-readable summary of the merge outcome.
    divider = "=" * 50
    print("\n" + divider)
    print("MERGE SUMMARY")
    print(divider)
    print(f"Total files found: {stats.total_files_found}")
    print(f"Valid trajectories: {stats.valid_trajectories}")
    print(f"Invalid trajectories: {stats.invalid_trajectories}")
    print(f"Duplicate trajectories: {stats.duplicate_trajectories}")
    print(f"Merged trajectories: {stats.merged_trajectories}")

    if stats.archetypes:
        print("\nArchetype distribution:")
        total_merged = stats.merged_trajectories
        for archetype, count in sorted(stats.archetypes.items()):
            share = (count / total_merged * 100) if total_merged > 0 else 0
            print(f"  {archetype}: {count} ({share:.1f}%)")

    if not args.dry_run:
        print(f"\nMerged trajectories saved to: {output_dir / 'trajectories'}")

    # Fail the pipeline when the invalid fraction exceeds 10%.
    if stats.invalid_trajectories > stats.total_files_found * 0.1:
        logger.warning("More than 10% of trajectories are invalid")
        sys.exit(1)

    sys.exit(0)


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
A/B Test Runner for Babylon Agent Training
|
|
4
|
+
|
|
5
|
+
Compares a trained model against a baseline model using standardized evaluation scenarios.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
# Compare trained model against baseline
|
|
9
|
+
python scripts/run_ab_test.py \
|
|
10
|
+
--model-a Qwen/Qwen3-30B \
|
|
11
|
+
--model-b ./trained_models/final_model \
|
|
12
|
+
--archetypes trader degen
|
|
13
|
+
|
|
14
|
+
# Quick test with fewer scenarios
|
|
15
|
+
python scripts/run_ab_test.py \
|
|
16
|
+
--model-a Qwen/Qwen2.5-0.5B-Instruct \
|
|
17
|
+
--model-b ./trained_models/final_model \
|
|
18
|
+
--num-runs 1
|
|
19
|
+
|
|
20
|
+
# Full test suite
|
|
21
|
+
python scripts/run_ab_test.py \
|
|
22
|
+
--model-a Qwen/Qwen3-30B \
|
|
23
|
+
--model-b ./trained_models/final_model \
|
|
24
|
+
--num-runs 5 \
|
|
25
|
+
--output-dir ./ab_results
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import argparse
|
|
29
|
+
import asyncio
|
|
30
|
+
import logging
|
|
31
|
+
import sys
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
|
|
34
|
+
# Add src to path for imports
|
|
35
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
36
|
+
|
|
37
|
+
from src.training.ab_testing import ABTestRunner, EVAL_SCENARIOS
|
|
38
|
+
|
|
39
|
+
logging.basicConfig(
|
|
40
|
+
level=logging.INFO,
|
|
41
|
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
|
42
|
+
)
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the A/B test runner."""
    parser = argparse.ArgumentParser(
        description="Run A/B test comparing two models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--model-a",
        required=True,
        help="Baseline model (path or HuggingFace name)",
    )
    parser.add_argument(
        "--model-b",
        required=True,
        help="Trained model to compare (path or HuggingFace name)",
    )
    parser.add_argument(
        "--archetypes",
        nargs="+",
        choices=list(EVAL_SCENARIOS.keys()),
        help="Archetypes to test (default: all)",
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=3,
        help="Number of runs per scenario for statistical significance (default: 3)",
    )
    parser.add_argument(
        "--vllm-url",
        default="http://localhost:9001/v1",
        help="vLLM server URL (default: http://localhost:9001/v1)",
    )
    parser.add_argument(
        "--output-dir",
        default="./ab_test_results",
        help="Directory to save results (default: ./ab_test_results)",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging",
    )
    return parser


async def main():
    """Parse CLI arguments, run the A/B comparison, and return an exit code.

    Returns 0 when the trained model (B) wins or ties, 1 when the
    baseline (A) wins.
    """
    args = _build_parser().parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Restrict the evaluation scenarios to the requested archetypes, if any.
    if args.archetypes:
        scenarios = {
            name: scenario
            for name, scenario in EVAL_SCENARIOS.items()
            if name in args.archetypes
        }
        logger.info(f"Testing archetypes: {', '.join(args.archetypes)}")
    else:
        scenarios = EVAL_SCENARIOS
        logger.info(f"Testing all archetypes: {', '.join(EVAL_SCENARIOS.keys())}")

    logger.info(f"Model A (baseline): {args.model_a}")
    logger.info(f"Model B (trained): {args.model_b}")
    logger.info(f"Runs per scenario: {args.num_runs}")

    runner = ABTestRunner(
        model_a=args.model_a,
        model_b=args.model_b,
        scenarios=scenarios,
        vllm_url=args.vllm_url,
        num_runs_per_scenario=args.num_runs,
        output_dir=args.output_dir,
    )

    logger.info("Starting A/B test...")
    result = await runner.run()

    print()
    print(result.summary())

    # Guard-style exit-code selection.
    if result.model_b_wins > result.model_a_wins:
        logger.info("Trained model outperforms baseline!")
        return 0
    if result.model_a_wins > result.model_b_wins:
        logger.warning("Baseline model outperforms trained model")
        return 1
    logger.info("Models performed equally")
    return 0


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)