@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,41 @@
1
+ [project]
2
+ name = "elizaos-training"
3
+ version = "0.2.0"
4
+ description = "RL training pipeline for ElizaOS agents using Atropos"
5
+ authors = [
6
+ {name = "ElizaOS Contributors"}
7
+ ]
8
+ requires-python = ">=3.10"
9
+ dependencies = [
10
+ "atroposlib>=0.1.0",
11
+ "asyncpg>=0.29.0",
12
+ "python-dotenv>=1.0.0",
13
+ "pydantic>=2.0.0",
14
+ "numpy>=1.24.0",
15
+ "openai>=1.0.0",
16
+ "litellm>=1.0.0",
17
+ "tabulate>=0.9.0",
18
+ "vllm>=0.4.0",
19
+ "transformers>=4.40.0",
20
+ ]
21
+
22
+ [project.optional-dependencies]
23
+ dev = [
24
+ "pytest>=7.0.0",
25
+ "pytest-asyncio>=0.21.0",
26
+ "black>=23.0.0",
27
+ "ruff>=0.1.0",
28
+ ]
29
+
30
+ [project.scripts]
31
+ train-mmo = "scripts.train_mmo:main"
32
+ check-windows = "scripts.check_windows:main"
33
+ run-migrations = "scripts.run_migrations:main"
34
+
35
+ [tool.ruff]
36
+ line-length = 100
37
+ target-version = "py310"
38
+
39
+ [tool.black]
40
+ line-length = 100
41
+ target-version = ["py310"]
@@ -0,0 +1,31 @@
1
+ # Lightweight dependencies for CI test runs (avoids GPU / vLLM installs)
2
+ atroposlib>=0.3.0
3
+
4
+ # Database (integration tests)
5
+ asyncpg>=0.29.0
6
+ psycopg2-binary>=2.9.9
7
+
8
+ # HTTP/API
9
+ httpx>=0.26.0
10
+ aiohttp>=3.9.0
11
+ requests>=2.31.0
12
+ openai>=1.0.0
13
+
14
+ # Config / typing
15
+ python-dotenv>=1.0.0
16
+ pydantic>=2.5.0
17
+ pyyaml>=6.0.1
18
+
19
+ # Testing
20
+ pytest>=7.4.0
21
+ pytest-asyncio>=0.21.0
22
+
23
+ # Utilities used by the training pipeline
24
+ wandb>=0.16.0
25
+ tqdm>=4.66.0
26
+ psutil>=5.9.0
27
+ numpy>=1.24.0
28
+ tenacity>=8.2.0
29
+ rich>=13.0.0
30
+ jsonlines>=4.0.0
31
+
@@ -0,0 +1,87 @@
1
+ # Babylon RL Training - Atropos + Tinker Framework
2
+ # Supports: Tinker Cloud (RECOMMENDED), CUDA (GPU), MLX (Apple Silicon), CPU
3
+
4
+ # ===========================================
5
+ # Tinker API (RECOMMENDED - Cloud Training)
6
+ # ===========================================
7
+ # Tinker provides cloud-based training without local GPU requirements
8
+ # Get API key from: https://tinker-docs.thinkingmachines.ai/
9
+ tinker>=0.1.0
10
+
11
+ # ===========================================
12
+ # Core Atropos Framework (for environments)
13
+ # ===========================================
14
+ atroposlib>=0.3.0
15
+
16
+ # ===========================================
17
+ # Database
18
+ # ===========================================
19
+ asyncpg>=0.29.0
20
+ psycopg2-binary>=2.9.9
21
+
22
+ # ===========================================
23
+ # HTTP/API
24
+ # ===========================================
25
+ httpx>=0.26.0
26
+ aiohttp>=3.9.0
27
+ requests>=2.31.0
28
+
29
+ # ===========================================
30
+ # OpenAI-compatible client (for RLAIF judge)
31
+ # ===========================================
32
+ openai>=1.0.0
33
+
34
+ # ===========================================
35
+ # Configuration
36
+ # ===========================================
37
+ pyyaml>=6.0.1
38
+ python-dotenv>=1.0.0
39
+ pydantic>=2.5.0
40
+ pydantic-cli>=3.0.0
41
+
42
+ # ===========================================
43
+ # Testing
44
+ # ===========================================
45
+ pytest>=7.4.0
46
+ pytest-asyncio>=0.21.0
47
+
48
+ # ===========================================
49
+ # Experiment Tracking (Optional)
50
+ # ===========================================
51
+ wandb>=0.16.0 # W&B integration, falls back to offline mode if no API key
52
+
53
+ # ===========================================
54
+ # Utilities
55
+ # ===========================================
56
+ tqdm>=4.66.0
57
+ psutil>=5.9.0
58
+ numpy>=1.24.0
59
+ tenacity>=8.2.0
60
+ rich>=13.0.0
61
+ jsonlines>=4.0.0
62
+
63
+ # ============================================
64
+ # OPTIONAL: Local Training (GPU/CPU)
65
+ # ============================================
66
+ # Only needed if NOT using Tinker for training
67
+ # Uncomment if you need local fallback:
68
+ #
69
+ # torch>=2.1.0
70
+ # transformers>=4.36.0
71
+ # peft>=0.8.0
72
+ # vllm>=0.3.0
73
+ # accelerate>=1.12.0
74
+
75
+ # ============================================
76
+ # OPTIONAL: MLX Backend (Apple Silicon only)
77
+ # ============================================
78
+ # Install on Mac with Apple Silicon:
79
+ # pip install mlx mlx-lm
80
+ #
81
+ # For fine-tuning support:
82
+ # pip install mlx-lm[finetuning]
83
+ #
84
+ # Recommended models for MLX:
85
+ # - mlx-community/Qwen2.5-3B-Instruct-4bit
86
+ # - mlx-community/Qwen2.5-7B-Instruct-4bit
87
+ # - mlx-community/Qwen3-4B-4bit
@@ -0,0 +1,4 @@
1
+ """Training scripts for Babylon RL system"""
2
+
3
+
4
+
@@ -0,0 +1,412 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Import JSON Trajectories to PostgreSQL
4
+
5
+ Bridges JSON-mode generated trajectories to the database for integration testing
6
+ and production seeding. This enables testing the database pipeline with data
7
+ generated by the TypeScript simulation engine.
8
+
9
+ Usage:
10
+ # Import from default training-data-output directory
11
+ python scripts/import_json_trajectories.py
12
+
13
+ # Import from custom directory
14
+ python scripts/import_json_trajectories.py --source ./my-trajectories
15
+
16
+ # Dry run (validate without inserting)
17
+ python scripts/import_json_trajectories.py --dry-run
18
+
19
+ # Verbose output
20
+ python scripts/import_json_trajectories.py --verbose
21
+
22
+ Environment:
23
+ DATABASE_URL: PostgreSQL connection string (required)
24
+
25
+ Requirements:
26
+ - psycopg2
27
+ - JSON trajectory files generated by generate-training-data.ts
28
+ """
29
+
30
import argparse
import hashlib
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
40
+
41
+ # Add src to path
42
+ sys.path.insert(0, str(Path(__file__).parent.parent))
43
+
44
+ from src.data_bridge.reader import JsonTrajectoryReader, validate_llm_calls
45
+ from src.training.rubric_loader import normalize_archetype, has_custom_rubric
46
+
47
+ logging.basicConfig(
48
+ level=logging.INFO,
49
+ format="%(asctime)s - %(levelname)s - %(message)s"
50
+ )
51
+ logger = logging.getLogger(__name__)
52
+
53
+
54
+ @dataclass
55
+ class ImportStats:
56
+ """Statistics for import operation."""
57
+ total_files: int = 0
58
+ valid_trajectories: int = 0
59
+ invalid_trajectories: int = 0
60
+ inserted: int = 0
61
+ skipped_existing: int = 0
62
+ failed: int = 0
63
+ archetypes_seen: Dict[str, int] = None
64
+
65
+ def __post_init__(self):
66
+ if self.archetypes_seen is None:
67
+ self.archetypes_seen = {}
68
+
69
+ def record_archetype(self, archetype: str):
70
+ normalized = normalize_archetype(archetype)
71
+ self.archetypes_seen[normalized] = self.archetypes_seen.get(normalized, 0) + 1
72
+
73
+
74
+ def get_db_connection():
75
+ """Get database connection from environment."""
76
+ database_url = os.environ.get("DATABASE_URL")
77
+ if not database_url:
78
+ raise ValueError(
79
+ "DATABASE_URL environment variable not set. "
80
+ "Example: postgresql://babylon:password@localhost:5433/babylon"
81
+ )
82
+
83
+ try:
84
+ import psycopg2
85
+ except ImportError:
86
+ raise ImportError(
87
+ "psycopg2 is required for database import. "
88
+ "Install with: pip install psycopg2-binary"
89
+ )
90
+
91
+ return psycopg2.connect(database_url)
92
+
93
+
94
+ def generate_snowflake_id() -> str:
95
+ """Generate a unique ID similar to TypeScript snowflake."""
96
+ timestamp = int(datetime.now().timestamp() * 1000)
97
+ random_part = hashlib.sha256(os.urandom(8)).hexdigest()[:8]
98
+ return f"{timestamp}{random_part}"
99
+
100
+
101
+ def extract_archetype_from_trajectory(traj_data: Dict) -> str:
102
+ """
103
+ Extract archetype from trajectory data.
104
+
105
+ Priority:
106
+ 1. trajectory.archetype field
107
+ 2. First step's action.parameters.archetype
108
+ 3. Default to 'trader'
109
+ """
110
+ # Check trajectory-level archetype
111
+ archetype = traj_data.get("archetype")
112
+ if archetype and archetype != "default":
113
+ return normalize_archetype(archetype)
114
+
115
+ # Check steps for archetype in action parameters
116
+ steps_json = traj_data.get("stepsJson", "[]")
117
+ steps = json.loads(steps_json) if isinstance(steps_json, str) else steps_json
118
+
119
+ for step in steps:
120
+ action = step.get("action", {})
121
+ params = action.get("parameters", {})
122
+ step_archetype = params.get("archetype")
123
+ if step_archetype:
124
+ return normalize_archetype(step_archetype)
125
+
126
+ return "trader"
127
+
128
+
129
+ def validate_trajectory(traj_data: Dict) -> tuple[bool, List[str]]:
130
+ """
131
+ Validate a trajectory for import.
132
+
133
+ Returns:
134
+ (is_valid, list_of_issues)
135
+ """
136
+ issues = []
137
+
138
+ # Required fields
139
+ required = ["trajectoryId", "agentId", "windowId"]
140
+ for field in required:
141
+ if not traj_data.get(field):
142
+ issues.append(f"Missing required field: {field}")
143
+
144
+ # Steps validation
145
+ steps_json = traj_data.get("stepsJson", "[]")
146
+ try:
147
+ steps = json.loads(steps_json) if isinstance(steps_json, str) else steps_json
148
+ if len(steps) == 0:
149
+ issues.append("No steps in trajectory")
150
+ else:
151
+ is_valid_llm, llm_issues = validate_llm_calls(steps)
152
+ if not is_valid_llm:
153
+ issues.extend(llm_issues)
154
+ except json.JSONDecodeError as e:
155
+ issues.append(f"Invalid stepsJson: {e}")
156
+
157
+ # Archetype validation
158
+ archetype = extract_archetype_from_trajectory(traj_data)
159
+ if archetype != "default" and not has_custom_rubric(archetype):
160
+ issues.append(f"Unknown archetype: {archetype}")
161
+
162
+ return len(issues) == 0, issues
163
+
164
+
165
+ def check_trajectory_exists(conn, trajectory_id: str) -> bool:
166
+ """Check if trajectory already exists in database."""
167
+ cur = conn.cursor()
168
+ cur.execute(
169
+ 'SELECT 1 FROM trajectories WHERE "trajectoryId" = %s LIMIT 1',
170
+ (trajectory_id,)
171
+ )
172
+ exists = cur.fetchone() is not None
173
+ cur.close()
174
+ return exists
175
+
176
+
177
+ def insert_trajectory(conn, traj_data: Dict) -> bool:
178
+ """
179
+ Insert a trajectory into the database.
180
+
181
+ Returns:
182
+ True if inserted, False if failed
183
+ """
184
+ trajectory_id = traj_data.get("trajectoryId")
185
+ archetype = extract_archetype_from_trajectory(traj_data)
186
+
187
+ # Parse timestamps
188
+ start_time = traj_data.get("startTime")
189
+ if isinstance(start_time, str):
190
+ start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
191
+ elif start_time is None:
192
+ start_time = datetime.now()
193
+
194
+ end_time = traj_data.get("endTime")
195
+ if isinstance(end_time, str):
196
+ end_time = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
197
+ elif end_time is None:
198
+ end_time = datetime.now()
199
+
200
+ # Prepare values
201
+ values = (
202
+ traj_data.get("id", generate_snowflake_id()),
203
+ trajectory_id,
204
+ traj_data.get("agentId"),
205
+ archetype,
206
+ traj_data.get("windowId"),
207
+ traj_data.get("scenarioId"),
208
+ traj_data.get("stepsJson", "[]"),
209
+ json.dumps(traj_data.get("rewardComponents", {})),
210
+ json.dumps(traj_data.get("metrics", {})),
211
+ json.dumps(traj_data.get("metadata", {})),
212
+ float(traj_data.get("finalPnL", 0)),
213
+ int(traj_data.get("episodeLength", 0)),
214
+ float(traj_data.get("totalReward", 0)),
215
+ traj_data.get("finalStatus", "completed"),
216
+ float(traj_data.get("finalBalance")) if traj_data.get("finalBalance") else None,
217
+ int(traj_data.get("tradesExecuted", 0)),
218
+ int(traj_data.get("postsCreated", 0)),
219
+ float(traj_data.get("aiJudgeReward")) if traj_data.get("aiJudgeReward") else None,
220
+ traj_data.get("isTrainingData", True),
221
+ traj_data.get("isEvaluation", False),
222
+ traj_data.get("usedInTraining", False),
223
+ start_time,
224
+ end_time,
225
+ int(traj_data.get("durationMs", 0)),
226
+ 1, # windowHours
227
+ datetime.now(),
228
+ datetime.now(),
229
+ )
230
+
231
+ cur = conn.cursor()
232
+ cur.execute(
233
+ '''
234
+ INSERT INTO trajectories (
235
+ "id", "trajectoryId", "agentId", "archetype", "windowId",
236
+ "scenarioId", "stepsJson", "rewardComponentsJson", "metricsJson",
237
+ "metadataJson", "finalPnL", "episodeLength", "totalReward",
238
+ "finalStatus", "finalBalance", "tradesExecuted", "postsCreated",
239
+ "aiJudgeReward", "isTrainingData", "isEvaluation", "usedInTraining",
240
+ "startTime", "endTime", "durationMs", "windowHours",
241
+ "createdAt", "updatedAt"
242
+ ) VALUES (
243
+ %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
244
+ )
245
+ ''',
246
+ values
247
+ )
248
+ conn.commit()
249
+ cur.close()
250
+ return True
251
+
252
+
253
+ def import_trajectories(
254
+ source_dir: Path,
255
+ dry_run: bool = False,
256
+ verbose: bool = False,
257
+ ) -> ImportStats:
258
+ """
259
+ Import all trajectories from source directory to database.
260
+
261
+ Args:
262
+ source_dir: Directory containing trajectory JSON files
263
+ dry_run: If True, validate but don't insert
264
+ verbose: If True, log each trajectory
265
+
266
+ Returns:
267
+ Import statistics
268
+ """
269
+ stats = ImportStats()
270
+
271
+ # Find trajectory files
272
+ traj_dir = source_dir / "trajectories"
273
+ if not traj_dir.exists():
274
+ logger.warning(f"Trajectories directory not found: {traj_dir}")
275
+ traj_dir = source_dir
276
+
277
+ json_files = list(traj_dir.glob("*.json"))
278
+ stats.total_files = len(json_files)
279
+
280
+ if stats.total_files == 0:
281
+ logger.warning(f"No JSON files found in {traj_dir}")
282
+ return stats
283
+
284
+ logger.info(f"Found {stats.total_files} trajectory files in {traj_dir}")
285
+
286
+ # Get database connection (skip if dry run)
287
+ conn = None
288
+ if not dry_run:
289
+ conn = get_db_connection()
290
+ logger.info("Connected to database")
291
+
292
+ # Process each file
293
+ for file_path in json_files:
294
+ try:
295
+ with open(file_path, "r", encoding="utf-8") as f:
296
+ data = json.load(f)
297
+
298
+ # Handle wrapped format (trajectory key) vs direct format
299
+ traj_data = data.get("trajectory", data)
300
+ trajectory_id = traj_data.get("trajectoryId", file_path.stem)
301
+
302
+ # Validate
303
+ is_valid, issues = validate_trajectory(traj_data)
304
+ if not is_valid:
305
+ stats.invalid_trajectories += 1
306
+ if verbose:
307
+ logger.warning(f"Invalid trajectory {trajectory_id}: {issues}")
308
+ continue
309
+
310
+ stats.valid_trajectories += 1
311
+ archetype = extract_archetype_from_trajectory(traj_data)
312
+ stats.record_archetype(archetype)
313
+
314
+ if verbose:
315
+ logger.info(f"Validated: {trajectory_id} (archetype: {archetype})")
316
+
317
+ if dry_run:
318
+ continue
319
+
320
+ # Check if exists
321
+ if check_trajectory_exists(conn, trajectory_id):
322
+ stats.skipped_existing += 1
323
+ if verbose:
324
+ logger.info(f"Skipped existing: {trajectory_id}")
325
+ continue
326
+
327
+ # Insert
328
+ try:
329
+ insert_trajectory(conn, traj_data)
330
+ stats.inserted += 1
331
+ if verbose:
332
+ logger.info(f"Inserted: {trajectory_id}")
333
+ except Exception as e:
334
+ stats.failed += 1
335
+ logger.error(f"Failed to insert {trajectory_id}: {e}")
336
+
337
+ except json.JSONDecodeError as e:
338
+ stats.invalid_trajectories += 1
339
+ logger.warning(f"Invalid JSON in {file_path}: {e}")
340
+ except Exception as e:
341
+ stats.failed += 1
342
+ logger.error(f"Error processing {file_path}: {e}")
343
+
344
+ if conn:
345
+ conn.close()
346
+
347
+ return stats
348
+
349
+
350
+ def main():
351
+ parser = argparse.ArgumentParser(
352
+ description="Import JSON trajectories to PostgreSQL database"
353
+ )
354
+ parser.add_argument(
355
+ "--source",
356
+ type=Path,
357
+ default=Path("./training-data-output"),
358
+ help="Source directory containing trajectory JSON files"
359
+ )
360
+ parser.add_argument(
361
+ "--dry-run",
362
+ action="store_true",
363
+ help="Validate trajectories without inserting to database"
364
+ )
365
+ parser.add_argument(
366
+ "--verbose",
367
+ action="store_true",
368
+ help="Log each trajectory processed"
369
+ )
370
+
371
+ args = parser.parse_args()
372
+
373
+ if not args.source.exists():
374
+ logger.error(f"Source directory not found: {args.source}")
375
+ sys.exit(1)
376
+
377
+ if args.dry_run:
378
+ logger.info("DRY RUN MODE - No database modifications")
379
+
380
+ stats = import_trajectories(
381
+ source_dir=args.source,
382
+ dry_run=args.dry_run,
383
+ verbose=args.verbose,
384
+ )
385
+
386
+ # Print summary
387
+ print("\n" + "=" * 50)
388
+ print("IMPORT SUMMARY")
389
+ print("=" * 50)
390
+ print(f"Total files: {stats.total_files}")
391
+ print(f"Valid trajectories: {stats.valid_trajectories}")
392
+ print(f"Invalid trajectories: {stats.invalid_trajectories}")
393
+
394
+ if not args.dry_run:
395
+ print(f"Inserted: {stats.inserted}")
396
+ print(f"Skipped (existing): {stats.skipped_existing}")
397
+ print(f"Failed: {stats.failed}")
398
+
399
+ if stats.archetypes_seen:
400
+ print("\nArchetypes found:")
401
+ for archetype, count in sorted(stats.archetypes_seen.items()):
402
+ print(f" - {archetype}: {count}")
403
+
404
+ if stats.failed > 0:
405
+ sys.exit(1)
406
+
407
+ sys.exit(0)
408
+
409
+
410
+ if __name__ == "__main__":
411
+ main()
412
+
@@ -0,0 +1,63 @@
1
+ # Local Fine-Tuning Pipeline
2
+
3
+ This directory contains scripts to train RL adapters from Babylon simulation logs.
4
+
5
+ ## Workflow
6
+
7
+ 1. **Generate Data:** Run `bun packages/engine/examples/generate-training-data.ts`
8
+ 2. **Score & Format:** Run `python ingest_and_score.py`
9
+ 3. **Train:** Run `python train_from_csv.py`
10
+ 4. **Test:** Run `python test_adapter.py`
11
+
12
+ ## Quick Start
13
+
14
+ If you do not have a local Postgres database, Atropos server, or vLLM instance running, you can use the **Offline Pipeline**. This generates data to JSON files and uses direct PyTorch/HuggingFace libraries for training.
15
+
16
+ ### Prerequisites
17
+
18
+ 1. `GROQ_API_KEY` or `OPENAI_API_KEY` set in environment.
19
+ 2. Python dependencies: `pip install torch transformers peft pandas datasets trl`
20
+
21
+ ### Step 1: Generate Data (TypeScript)
22
+
23
+ Runs the game simulation in-memory and dumps "Observation -> Action" logs to JSON.
24
+
25
+ ```bash
26
+ # Runs 24 simulated hours
27
+ bun packages/engine/examples/generate-training-data.ts
28
+ ```
29
+
30
+ _Output:_ `training-data-output/trajectories/*.json`
31
+
32
+ ### Step 2: Process & Score (Python)
33
+
34
+ Converts raw JSON logs into a scored CSV dataset (System/User/Assistant format).
35
+
36
+ ```bash
37
+ cd packages/training/python/scripts/local-finetune
38
+ python ingest_and_score.py
39
+ ```
40
+
41
+ _Output:_ `packages/training/data/scored_trajectories.csv`
42
+
43
+ ### Step 3: Train Model (Python)
44
+
45
+ Fine-tunes a base model (Qwen2.5-0.5B by default) on your scored data using LoRA.
46
+
47
+ ```bash
48
+ python train_from_csv.py --output ./my-model-v1
49
+ ```
50
+
51
+ ### Step 4: Test Inference
52
+
53
+ Interactively chat with your new LoRA adapter to verify behavior.
54
+
55
+ ```bash
56
+ python test_adapter.py
57
+ ```
58
+
59
+ ---
60
+
61
+ ## 🏗️ Production Architecture (Tinker/Atropos)
62
+
63
+ _For the full cloud-based pipeline involving Postgres, GRPO, and Tinker compute, refer to `scripts/run_full_pipeline.py`._