@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,528 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Babylon Local Training Script - Unified Mac (MLX) + GTX (CUDA) Support
4
+
5
+ This script provides training using REAL data from the database OR local JSON files.
6
+ Only trajectories with actual LLM calls are used.
7
+
8
+ Supports:
9
+ - Apple Silicon (MLX) - LoRA fine-tuning
10
+ - NVIDIA GPU (PyTorch/CUDA) - Full or LoRA fine-tuning
11
+ - CPU fallback (slow but works)
12
+
13
+ Usage:
14
+ # Mac with MLX from Postgres Database
15
+ python scripts/train_local.py --backend mlx --model mlx-community/Qwen2.5-1.5B-Instruct-4bit
16
+
17
+ # Mac with MLX from local JSON files
18
+ python scripts/train_local.py --backend mlx --model mlx-community/Qwen2.5-1.5B-Instruct-4bit --source-dir ../engine/training-data-output/trajectories
19
+
20
+ # GTX/CUDA machine from Postgres Database
21
+ python scripts/train_local.py --backend cuda --model Qwen/Qwen2.5-1.5B-Instruct
22
+
23
+ # GTX/CUDA machine from local JSON files
24
+ python scripts/train_local.py --backend cuda --model Qwen/Qwen2.5-1.5B-Instruct --source-dir ../engine/training-data-output/trajectories
25
+
26
+ Small model recommendations for consumer hardware:
27
+ Mac M1/M2 (8GB): mlx-community/Qwen2.5-0.5B-Instruct-4bit
28
+ Mac M1/M2 (16GB): mlx-community/Qwen2.5-1.5B-Instruct-4bit
29
+ GTX 3060 (12GB): Qwen/Qwen2.5-1.5B-Instruct
30
+ GTX 3080 (10GB): Qwen/Qwen2.5-1.5B-Instruct
31
+ GTX 4090 (24GB): Qwen/Qwen2.5-3B-Instruct
32
+ """
33
+
34
+ import os
35
+ import sys
36
+ from pathlib import Path
37
+
38
+ # Add src to path
39
+ sys.path.insert(0, str(Path(__file__).parent.parent))
40
+
41
+
42
+
43
+ import argparse
44
+ import asyncio
45
+ import json
46
+ import logging
47
+ from datetime import datetime, timezone
48
+ from typing import Literal, List
49
+ from dotenv import load_dotenv
50
+
51
+ from src.models import BabylonTrajectory
52
+ from src.data_bridge.reader import JsonTrajectoryReader, PostgresTrajectoryReader, validate_llm_calls
53
+
54
+ # Load environment
55
+ env_path = Path(__file__).parent.parent.parent.parent.parent / ".env"
56
+ if env_path.exists():
57
+ load_dotenv(env_path)
58
+
59
+ logging.basicConfig(
60
+ level=logging.INFO,
61
+ format='%(asctime)s [%(levelname)s] %(message)s'
62
+ )
63
+ logger = logging.getLogger(__name__)
64
+
65
+
66
+ # =============================================================================
67
+ # Backend Detection
68
+ # =============================================================================
69
+
70
def detect_backend() -> Literal["mlx", "cuda", "cpu"]:
    """Pick the most capable training backend available on this machine.

    Preference order: MLX (Apple Silicon) > CUDA (NVIDIA GPU) > CPU.
    """
    # MLX first: a successful import implies Apple Silicon hardware.
    try:
        import mlx.core  # type: ignore  # noqa: F401
    except ImportError:
        pass
    else:
        logger.info("MLX backend available (Apple Silicon)")
        return "mlx"

    # Next best: a PyTorch install that can actually see a CUDA device.
    try:
        import torch
    except ImportError:
        pass
    else:
        if torch.cuda.is_available():
            logger.info(
                f"CUDA backend available: {torch.cuda.get_device_name(0)}")
            return "cuda"

    # Last resort: plain CPU training.
    logger.warning("No GPU backend available, falling back to CPU (slow)")
    return "cpu"
92
+
93
+
94
+ # =============================================================================
95
+ # Data Loading
96
+ # =============================================================================
97
+
98
async def load_postgres_training_data(
    database_url: str,
    min_actions: int,
    lookback_hours: int,
    max_trajectories: int,
) -> List[BabylonTrajectory]:
    """Load REAL training data from the database and parse into Pydantic models.

    Args:
        database_url: Postgres connection string.
        min_actions: Minimum number of actions a trajectory must contain.
        lookback_hours: Only consider trajectory windows within this window.
        max_trajectories: Hard cap on the number of trajectories returned.

    Returns:
        Validated ``BabylonTrajectory`` models.

    Raises:
        ValueError: If no windows exist, the database is unreachable, or
            fewer than 10 valid trajectories are found. The caller
            (``main_async``) treats this as a fatal, user-facing error.
    """
    logger.info("Loading real training data from database...")

    trajectories: List[BabylonTrajectory] = []

    try:
        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(lookback_hours=lookback_hours)
            if not windows:
                raise ValueError(
                    "No trajectory windows found in database. Generate data first.")

            logger.info(f"Found {len(windows)} trajectory windows")

            for window_id in windows:
                if len(trajectories) >= max_trajectories:
                    break

                window_trajectories = await reader.get_trajectories_by_window(
                    window_id, min_actions=min_actions, validate=True
                )
                for traj_row in window_trajectories:
                    try:
                        steps = json.loads(traj_row.steps_json)
                        # Convert TrajectoryRow object to a dict for Pydantic validation
                        traj_data = {
                            "id": traj_row.trajectory_id,
                            "trajectory_id": traj_row.trajectory_id,
                            "agent_id": traj_row.agent_id,
                            "window_id": traj_row.window_id,
                            "steps": steps,
                            "total_reward": traj_row.total_reward,
                            "episode_length": traj_row.episode_length,
                            "final_status": traj_row.final_status,
                            "final_pnl": traj_row.final_pnl,
                            "trades_executed": traj_row.trades_executed,
                            "archetype": traj_row.archetype,
                        }
                        trajectories.append(
                            BabylonTrajectory.model_validate(traj_data))
                    except Exception as e:
                        # One corrupt row should not abort the whole load.
                        logger.warning(
                            f"Skipping DB trajectory {traj_row.trajectory_id} due to parsing error: {e}")

    except ValueError:
        # Deliberate validation errors (e.g. empty database) propagate to the
        # caller; previously they were swallowed by the broad handler below
        # and converted into a bare sys.exit(1).
        raise
    except Exception as e:
        logger.error(f"Failed to load from database: {e}")
        logger.error(
            "Please ensure the database is running and DATABASE_URL is correct.")
        # Raise instead of sys.exit(1): main_async() catches ValueError and
        # owns process termination, producing the same exit code (1).
        raise ValueError(f"Database load failed: {e}") from e

    if len(trajectories) < 10:
        raise ValueError(
            f"Insufficient training data: only {len(trajectories)} valid trajectories found.")

    logger.info(f"Loaded {len(trajectories)} real trajectories from DB")
    return trajectories
161
+
162
+
163
def load_json_training_data(source_dir: str, max_trajectories: int) -> List[BabylonTrajectory]:
    """Load training data from a directory of JSON trajectory files.

    Handles both the Python-native snake_case format and the TypeScript
    simulation engine's export format (nested ``trajectory`` key, a
    ``stepsJson`` string field, and camelCase identifiers).

    Args:
        source_dir: Directory containing JSON trajectory files.
        max_trajectories: Hard cap on the number of trajectories loaded.

    Returns:
        Validated ``BabylonTrajectory`` models.

    Raises:
        ValueError: If zero valid trajectories are loaded.
        FileNotFoundError: If ``source_dir`` does not exist (propagated from
            the reader; handled by ``main_async``).
    """
    logger.info(f"Loading training data from local directory: {source_dir}")
    reader = JsonTrajectoryReader(source_dir)
    all_trajectories: List[BabylonTrajectory] = []
    for window_id in reader.get_window_ids():
        if len(all_trajectories) >= max_trajectories:
            break
        for traj_data in reader.get_trajectories_by_window(window_id):
            try:
                # Handle the nested `trajectory` key and `stepsJson` string
                # format from the TypeScript simulation engine.
                if 'trajectory' in traj_data:
                    traj_data = traj_data['trajectory']
                if 'stepsJson' in traj_data and isinstance(traj_data['stepsJson'], str):
                    traj_data['steps'] = json.loads(traj_data['stepsJson'])

                is_valid, issues = validate_llm_calls(
                    traj_data.get('steps', []))
                if not is_valid:
                    logger.debug(
                        f"Skipping invalid JSON trajectory {traj_data.get('trajectoryId')}: {issues}")
                    continue

                # Ensure 'id' is present for Pydantic model validation. Check
                # both snake_case and the TS engine's camelCase key before
                # falling back (the old code only looked at snake_case, so
                # TS-exported files always got 'id_missing').
                if 'id' not in traj_data:
                    traj_data['id'] = (
                        traj_data.get('trajectory_id')
                        or traj_data.get('trajectoryId')
                        or 'id_missing'
                    )

                all_trajectories.append(
                    BabylonTrajectory.model_validate(traj_data))
            except Exception as e:
                logger.warning(
                    f"Skipping invalid JSON trajectory {traj_data.get('trajectoryId')}: {e}")

    if not all_trajectories:
        # Raise instead of sys.exit(1): main_async() catches ValueError and
        # owns process termination, producing the same exit code (1).
        raise ValueError(
            "Insufficient training data: 0 valid trajectories were loaded. Check validation logs with DEBUG level.")
    if len(all_trajectories) < 10:
        logger.warning(
            f"Low training data: only {len(all_trajectories)} valid trajectories found.")

    logger.info(
        f"Loaded {len(all_trajectories)} valid trajectories from JSON files.")
    return all_trajectories
212
+
213
+
214
def trajectories_to_training_samples(trajectories: List[BabylonTrajectory]) -> list[dict]:
    """
    Flatten trajectories into chat-format training samples.

    Every LLM call found in any trajectory step becomes one sample of the
    form ``{"messages": [...]}`` with optional system/user turns followed by
    the assistant response. Calls with a missing or very short response
    (< 20 chars) are dropped, as are samples with fewer than two turns.
    """
    samples: list[dict] = []
    for trajectory in trajectories:
        for step in trajectory.steps:
            for call in (step.llm_calls or []):
                response = call.response
                # Quality gate: require a non-trivial assistant response.
                if not response or len(response) < 20:
                    continue

                turns = [
                    {"role": role, "content": content}
                    for role, content in (
                        ("system", call.system_prompt),
                        ("user", call.user_prompt),
                    )
                    if content
                ]
                turns.append({"role": "assistant", "content": response})

                # Need at least one prompt turn plus the assistant reply.
                if len(turns) >= 2:
                    samples.append({"messages": turns})

    logger.info(
        f"Converted {len(trajectories)} trajectories to {len(samples)} training samples")
    return samples
247
+
248
+
249
+ # =============================================================================
250
+ # Training Backends
251
+ # =============================================================================
252
+
253
def train_mlx(
    samples: list[dict], model_name: str, output_dir: str,
    num_iters: int, batch_size: int, learning_rate: float
) -> str:
    """Fine-tune with MLX LoRA on Apple Silicon.

    Writes a shuffled 90/10 train/validation split as JSONL under
    ``output_dir`` and shells out to ``python -m mlx_lm lora``.

    Note: ``samples`` is shuffled in place.

    Returns:
        Path to the directory holding the trained LoRA adapters.
    """
    import subprocess
    import random

    banner = "=" * 60
    logger.info(banner + "\nMLX LORA TRAINING\n" + banner)

    data_dir = os.path.join(output_dir, "training_data")
    os.makedirs(data_dir, exist_ok=True)

    # Shuffle, then hold out the final 10% for validation.
    random.shuffle(samples)
    cut = int(len(samples) * 0.9)
    splits = {"train.jsonl": samples[:cut], "valid.jsonl": samples[cut:]}
    for filename, subset in splits.items():
        with open(os.path.join(data_dir, filename), 'w') as fh:
            fh.writelines(json.dumps(item) + "\n" for item in subset)

    adapter_path = os.path.join(output_dir, "adapters")

    # Imported purely as an availability check before shelling out.
    import mlx_lm  # type: ignore  # noqa: F401

    cmd = [
        sys.executable, "-m", "mlx_lm", "lora",
        "--model", model_name,
        "--train",
        "--data", data_dir,
        "--adapter-path", adapter_path,
        "--batch-size", str(batch_size),
        "--iters", str(num_iters),
        "--learning-rate", str(learning_rate),
        "--steps-per-report", "10",
        "--steps-per-eval", "25",
        "--val-batches", "5",
        "--max-seq-length", "1024",
        "--num-layers", "8",
        "--mask-prompt",
    ]
    logger.info(f"Command: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)
    return adapter_path
289
+
290
+
291
def train_cuda(
    samples: list[dict], model_name: str, output_dir: str,
    epochs: int, batch_size: int, learning_rate: float, use_lora: bool
) -> str:
    """Train with PyTorch on an NVIDIA GPU (or on CPU as a fallback).

    Args:
        samples: Chat-format samples ``{"messages": [...]}``.
        model_name: Hugging Face model id or local path.
        output_dir: Where checkpoints and the final model are written.
        epochs: Number of training epochs.
        batch_size: Accepted for interface symmetry; the effective per-device
            batch size is fixed at 1 (with gradient accumulation) to fit
            consumer GPUs — this is documented in the CLI help.
        learning_rate: Optimizer learning rate.
        use_lora: Train LoRA adapters instead of the full model.

    Returns:
        ``output_dir`` containing the saved model/adapters.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
    from datasets import Dataset

    logger.info("=" * 60 + "\nCUDA/PYTORCH TRAINING\n" + "=" * 60)

    # This function is also the CPU fallback path (train_cpu delegates here),
    # so only touch CUDA APIs / fp16 when a GPU is actually present. The
    # previous unconditional get_device_properties(0) call crashed on
    # CPU-only hosts, as did TrainingArguments(fp16=True).
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        logger.info(
            f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        logger.warning("CUDA not available - training on CPU (very slow)")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    formatted = [{"text": tokenizer.apply_chat_template(
        s['messages'], tokenize=False, add_generation_prompt=False)} for s in samples if s.get("messages")]
    dataset = Dataset.from_list(formatted)

    def tokenize_fn(examples):
        # Shorter sequence length to prevent CUDA out-of-memory errors on
        # consumer GPUs; attention memory scales quadratically with it.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=1024,  # Reduced from 2048 to fit in ~12GB VRAM
            padding="max_length",
        )

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # fp16 weights only make sense on GPU; fall back to fp32 on CPU.
        torch_dtype=torch.float16 if cuda_available else torch.float32,
        trust_remote_code=True,
        device_map="auto",
    )

    if use_lora:
        from peft import LoraConfig, get_peft_model, TaskType
        lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
                                 lora_dropout=0.1, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"])
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Optimized training arguments for consumer GPUs (~12GB VRAM)
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        # Smallest possible batch size to save memory
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # Compensate for small batch size
        learning_rate=learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        # fp16=True raises ValueError on CPU-only machines, so gate it.
        fp16=cuda_available,
        report_to="none",
        remove_unused_columns=False
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized,
                      data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))

    trainer.train()
    trainer.save_model(output_dir)
    return output_dir
358
+
359
+
360
def train_cpu(samples: list[dict], model_name: str, output_dir: str, epochs: int, batch_size: int, learning_rate: float) -> str:
    """CPU fallback training (very slow).

    Delegates to the PyTorch path, which transformers runs on CPU when no
    GPU is found. The requested ``model_name`` is deliberately replaced with
    a tiny 0.5B model to keep CPU training feasible, and LoRA is disabled.
    """
    banner = "=" * 60
    logger.warning(banner + "\nCPU TRAINING (VERY SLOW)\n" + banner)
    cpu_model = "Qwen/Qwen2.5-0.5B-Instruct"  # forced small model for CPU feasibility
    return train_cuda(samples, cpu_model, output_dir, epochs, batch_size, learning_rate, use_lora=False)
366
+
367
+ # =============================================================================
368
+ # Validation
369
+ # =============================================================================
370
+
371
+
372
def validate_trained_model(model_path: str, backend: Literal["mlx", "cuda", "cpu"], base_model: str | None = None) -> bool:
    """Smoke-test the trained model by generating one response to a fixed prompt.

    Args:
        model_path: For MLX, the LoRA adapter directory; for CUDA/CPU, the
            saved model directory loaded directly.
        backend: Which inference stack to use.
        base_model: Base model id the MLX adapters were trained on. Required
            for the MLX branch (passed to ``load``); unused otherwise.

    Returns:
        True if generation succeeded and produced >= 50 characters, else False.
        Never raises: all failures are logged and reported as False.
    """
    logger.info("=" * 60 + "\nVALIDATING TRAINED MODEL\n" + "=" * 60)
    test_prompt = """You are a trading agent in Babylon prediction markets.

Current State:
- Balance: $10,000
- P&L: $250
- Positions: 2 open

Market Update:
- BTC prediction market at 68% probability
- Recent news: Fed announces rate cut consideration

Analyze this market update and explain your trading decision."""

    try:
        if backend == "mlx":
            # MLX path: load the base model and apply the trained adapters.
            from mlx_lm import load, generate  # type: ignore
            model, tokenizer = load(base_model, adapter_path=model_path)
            messages = [{"role": "user", "content": test_prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            response = generate(model, tokenizer, prompt=prompt,
                                max_tokens=200, verbose=False)
        else:
            # CUDA/CPU path: model_path is a full saved model, loaded directly.
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                # fp16 + auto device placement on GPU; fp32 on CPU.
                torch_dtype=torch.float16 if backend == "cuda" else torch.float32,
                device_map="auto" if backend == "cuda" else None,
                trust_remote_code=True,
            )
            messages = [{"role": "user", "content": test_prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(prompt, return_tensors="pt")
            if backend == "cuda":
                # Move input tensors to the GPU to match the model's device.
                inputs = {k: v.cuda() for k, v in inputs.items()}
            # Sampling with temperature 0.7, so output is non-deterministic.
            outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7,
                                     do_sample=True, pad_token_id=tokenizer.eos_token_id)
            # Strip the prompt tokens; decode only the newly generated tail.
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

        logger.info("Test Response:\n" + "-" * 40 +
                    f"\n{response[:500]}...\n" + "-" * 40)

        # Heuristic sanity check: a working model should say something
        # substantive about the prompt.
        if len(response) < 50:
            logger.error("Response too short - model may not be working")
            return False

        logger.info("✅ Model validation passed!")
        return True

    except Exception as e:
        # Broad catch is deliberate: validation is best-effort and must not
        # abort the training pipeline.
        logger.error(f"Model validation failed: {e}", exc_info=True)
        return False
432
+
433
+ # =============================================================================
434
+ # Main
435
+ # =============================================================================
436
+
437
+
438
async def main_async(args: argparse.Namespace) -> int:
    """Run the full pipeline: load data, convert to samples, train, validate.

    Args:
        args: Parsed CLI arguments (see ``main``).

    Returns:
        Process exit code: 0 on success, 1 on any failure. Note that a
        failed post-training validation is logged but does NOT change the
        exit code — only data-loading and training errors do.
    """
    # Resolve backend and a backend-appropriate default model.
    backend = args.backend or detect_backend()
    model_name = args.model or (
        "mlx-community/Qwen2.5-1.5B-Instruct-4bit" if backend == "mlx" else "Qwen/Qwen2.5-1.5B-Instruct")
    logger.info(f"Using backend: {backend}, Model: {model_name}")
    os.makedirs(args.output, exist_ok=True)

    try:
        # Data source selection: a local JSON directory takes precedence over
        # the database; the DB path requires DATABASE_URL (flag or env var).
        if args.source_dir:
            trajectories = load_json_training_data(
                args.source_dir, args.max_trajectories)
        else:
            database_url = args.database_url or os.getenv("DATABASE_URL")
            if not database_url:
                logger.error(
                    "DATABASE_URL not set and --source-dir not provided. Exiting.")
                return 1
            trajectories = await load_postgres_training_data(database_url, args.min_actions, args.lookback_hours, args.max_trajectories)
    except (ValueError, FileNotFoundError) as e:
        logger.error(f"Failed to load data: {e}")
        return 1

    samples = trajectories_to_training_samples(trajectories)
    if len(samples) < 10:
        logger.error(
            f"Not enough valid training samples found: {len(samples)}")
        return 1

    # base_model is only needed for MLX validation (adapters are applied on
    # top of the base model); CUDA/CPU save a directly loadable model.
    model_path, base_model = "", None
    try:
        if backend == "mlx":
            model_path, base_model = train_mlx(
                samples, model_name, args.output, args.iters, args.batch_size, args.lr), model_name
        elif backend == "cuda":
            model_path = train_cuda(
                samples, model_name, args.output, args.epochs, args.batch_size, args.lr, args.lora)
        else:  # cpu
            model_path = train_cpu(
                samples, model_name, args.output, args.epochs, args.batch_size, args.lr)
    except Exception as e:
        logger.error(f"Training process failed: {e}", exc_info=True)
        return 1

    # Best-effort smoke test; result is informational only.
    if args.validate and model_path:
        validate_trained_model(model_path, backend, base_model)

    logger.info("\n" + "="*60 + "\nTRAINING COMPLETE\n" +
                f" Model/adapter saved to: {model_path}\n" + "="*60)
    return 0
489
+
490
+
491
def main():
    """Build the CLI and hand off to the async training pipeline."""
    parser = argparse.ArgumentParser(
        description="Babylon Local Training",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data source (a JSON directory takes precedence over the database).
    parser.add_argument(
        "--source-dir",
        help="Directory with local JSON trajectory files for offline training.")
    parser.add_argument(
        "--database-url",
        help="Database URL (used if --source-dir is not provided).")

    # Backend / model selection.
    parser.add_argument(
        "--backend", choices=["mlx", "cuda", "cpu"],
        help="Training backend (auto-detected if not specified)")
    parser.add_argument(
        "--model",
        help="Model to train (default depends on backend)")

    # Data filtering.
    parser.add_argument(
        "--min-actions", type=int, default=3,
        help="Minimum actions per trajectory (DB source)")
    parser.add_argument(
        "--lookback-hours", type=int, default=168,
        help="Hours to look back for trajectories (DB source)")
    parser.add_argument(
        "--max-trajectories", type=int, default=500,
        help="Maximum trajectories to load")

    # Training hyper-parameters.
    parser.add_argument(
        "--output", default="./trained_models/local",
        help="Output directory")
    parser.add_argument(
        "--iters", type=int, default=100,
        help="Training iterations (MLX)")
    parser.add_argument(
        "--epochs", type=int, default=3,
        help="Training epochs (CUDA/CPU)")
    parser.add_argument(
        "--batch-size", type=int, default=2,
        help="Batch size (Note: CUDA uses a fixed batch size of 1 for memory optimization)")
    parser.add_argument(
        "--lr", type=float, default=1e-5,
        help="Learning rate")

    # Boolean toggles (support --no-lora / --no-validate).
    parser.add_argument(
        "--lora", action=argparse.BooleanOptionalAction, default=True,
        help="Use LoRA (CUDA only)")
    parser.add_argument(
        "--validate", action=argparse.BooleanOptionalAction, default=True,
        help="Validate trained model")

    return asyncio.run(main_async(parser.parse_args()))


if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,20 @@
1
"""Setup file for ElizaOS RL training with Atropos."""

from setuptools import setup, find_packages

setup(
    # NOTE(review): version lags the npm package (2.0.0-alpha.x) — confirm
    # whether the two are meant to stay in sync.
    name="elizaos-training",
    version="1.0.0",
    # src-layout: importable packages live under python/src/.
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    python_requires=">=3.10",
    # Core runtime dependencies; the fuller dev/CI sets live in
    # requirements.txt / requirements-ci.txt.
    install_requires=[
        "atroposlib>=0.3.0",
        "asyncpg>=0.29.0",
        "python-dotenv>=1.0.0",
        "pydantic>=2.0.0",
        "openai>=1.0.0",
        "torch>=2.1.0",
        "transformers>=4.36.0",
    ],
)