@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,522 @@
1
+ """
2
+ Babylon Hybrid Environment for GRPO Training
3
+
4
+ Combines offline (database) and online (simulation bridge) rollouts.
5
+ This provides the best of both worlds:
6
+ - Offline: Large, diverse dataset from historical trajectories
7
+ - Online: Fresh rollouts from current policy interacting with simulation
8
+
9
+ Usage:
10
+ make train-hybrid # 80% offline, 20% online by default
11
+
12
+ # Or with custom ratio
13
+ python scripts/run_training.py --mode hybrid --hybrid-online-ratio 0.3
14
+
15
+ The online ratio determines what fraction of rollouts come from the
16
+ simulation bridge vs the database.
17
+ """
18
+
19
+ import asyncio
20
+ import copy
21
+ import logging
22
+ import os
23
+ import random
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ from pydantic import Field
27
+
28
+ from atroposlib.envs.base import APIServerConfig, BaseEnv, ScoredDataGroup
29
+
30
+ from .babylon_env import BabylonEnvConfig, BabylonRLAIFEnv
31
+ from .online_env import BabylonOnlineEnv, BabylonOnlineEnvConfig, Scenario
32
+ from .simulation_bridge import SimulationBridge
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class BabylonHybridEnvConfig(BabylonOnlineEnvConfig):
    """
    Configuration for hybrid environment.

    Inherits from BabylonOnlineEnvConfig and adds offline ratio control.
    """

    # Fraction of get_next_item() calls served by the live simulation bridge.
    # 0.0 = every rollout from the database, 1.0 = every rollout online.
    # setup() forces this to 1.0 or 0.0 when one source is unavailable.
    online_ratio: float = Field(
        default=0.2,
        description="Ratio of rollouts from online simulation (0.0 = all offline, 1.0 = all online)"
    )

    # Database settings for offline mode (same as BabylonEnvConfig)
    db_url: Optional[str] = Field(
        default=None,
        description="PostgreSQL connection URL for offline trajectories"
    )
    # Size of the in-memory window filled by _load_trajectory_window().
    trajectory_window_size: int = Field(
        default=1000,
        description="Number of trajectories to cache in memory"
    )
    # Below this count _setup_offline() only warns; training still proceeds.
    min_trajectories: int = Field(
        default=10,
        description="Minimum trajectories required to start offline training"
    )
62
+
63
+
64
class BabylonHybridEnv(BaseEnv):
    """
    Hybrid environment that mixes offline and online rollouts.

    Architecture:
    - Maintains both an offline trajectory cache and online bridge connection
    - For each get_next_item() call, randomly selects offline vs online
    - Collects trajectories using the appropriate mode
    - Scores and returns consistent ScoredDataGroup format

    Benefits:
    - Stability from large offline dataset
    - Adaptability from on-policy online rollouts
    - Smooth transition from offline to online training
    """

    # Environment identifier used by the atropos framework / CLI.
    name = "babylon_hybrid_env"

    def __init__(
        self,
        config: BabylonHybridEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        """Store config and initialize bookkeeping.

        Heavy initialization (DB pool, bridge connection, tokenizer load)
        is deferred to setup().
        """
        super().__init__(config, server_configs, slurm, testing)
        self.config: BabylonHybridEnvConfig = config
        # Kept so _collect_offline can read the vLLM base_url directly.
        self._server_configs = server_configs

        # Offline components (from BabylonRLAIFEnv)
        self.db_pool = None  # asyncpg pool, created in _setup_offline()
        self.trajectory_cache: List[Dict] = []  # window of DB trajectory rows
        self.current_cache_idx: int = 0  # round-robin cursor into the cache

        # Online components (from BabylonOnlineEnv)
        self.simulation_bridge: Optional[SimulationBridge] = None
        self.scenario_pool = None
        self._bridge_npc_index: int = 0  # round-robin cursor over bridge NPCs

        # Hybrid control
        self.online_ratio = config.online_ratio  # may be overridden in setup()
        self.iter = 0
        self.online_count = 0
        self.offline_count = 0

        # Tokenizer (set in setup)
        self.tokenizer = None

        logger.info(f"HybridEnv initialized with online_ratio={self.online_ratio:.0%}")
113
+
114
+ @classmethod
115
+ def config_init(cls) -> Tuple[BabylonHybridEnvConfig, List[APIServerConfig]]:
116
+ """Create default config"""
117
+ env_config = BabylonHybridEnvConfig(
118
+ tokenizer_name="Qwen/Qwen2.5-3B-Instruct",
119
+ rollout_server_url="http://localhost:8000",
120
+ total_steps=1000,
121
+ batch_size=16,
122
+ online_ratio=float(os.getenv("HYBRID_ONLINE_RATIO", "0.2")),
123
+ use_simulation_bridge=True,
124
+ simulation_bridge_url=os.getenv("SIMULATION_BRIDGE_URL", "http://localhost:3001"),
125
+ db_url=os.getenv("DATABASE_URL"),
126
+ )
127
+
128
+ server_configs = [
129
+ APIServerConfig(
130
+ model_name="Qwen/Qwen2.5-3B-Instruct",
131
+ base_url="http://localhost:9001/v1",
132
+ )
133
+ ]
134
+
135
+ return env_config, server_configs
136
+
137
    async def setup(self):
        """Initialize both offline and online components"""
        from transformers import AutoTokenizer

        logger.info("Setting up hybrid environment...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)

        # Setup offline component (database)
        if self.config.db_url:
            await self._setup_offline()
        else:
            # No database means the offline path can never serve items,
            # so force every rollout through the simulation bridge.
            logger.warning("No DATABASE_URL set, hybrid will only use online rollouts")
            self.online_ratio = 1.0

        # Setup online component (simulation bridge)
        if self.config.use_simulation_bridge:
            await self._setup_online()
        else:
            # Mirror of the branch above: no bridge means all-offline.
            # NOTE(review): if neither db_url nor the bridge is configured,
            # this leaves online_ratio at 0.0 with an empty cache and
            # get_next_item() will fail — confirm callers always set one.
            logger.warning("Simulation bridge disabled, hybrid will only use offline rollouts")
            self.online_ratio = 0.0

        logger.info(f"Hybrid setup complete: online_ratio={self.online_ratio:.0%}, "
                    f"offline_trajectories={len(self.trajectory_cache)}, "
                    f"bridge_npcs={len(self.simulation_bridge.npc_ids) if self.simulation_bridge else 0}")
163
+
164
+ async def _setup_offline(self):
165
+ """Setup database connection and load trajectories"""
166
+ import asyncpg
167
+
168
+ logger.info("Connecting to database for offline trajectories...")
169
+
170
+ self.db_pool = await asyncpg.create_pool(
171
+ self.config.db_url,
172
+ min_size=2,
173
+ max_size=10,
174
+ )
175
+
176
+ # Load initial trajectory window
177
+ await self._load_trajectory_window()
178
+
179
+ if len(self.trajectory_cache) < self.config.min_trajectories:
180
+ logger.warning(f"Only {len(self.trajectory_cache)} trajectories in DB, "
181
+ f"need {self.config.min_trajectories}")
182
+
183
    async def _load_trajectory_window(self):
        """Load a window of trajectories from database"""
        # No-op when the offline path was never initialized.
        if not self.db_pool:
            return

        async with self.db_pool.acquire() as conn:
            # Load trajectories with reasoning: newest first, capped at the
            # configured window size, skipping rows with no model response.
            rows = await conn.fetch("""
                SELECT
                    id, archetype, scenario_context, model_response,
                    reasoning, metrics, created_at
                FROM trajectories
                WHERE model_response IS NOT NULL
                ORDER BY created_at DESC
                LIMIT $1
            """, self.config.trajectory_window_size)

            # Replace the whole cache atomically and rewind the cursor.
            self.trajectory_cache = [dict(row) for row in rows]
            self.current_cache_idx = 0

        logger.info(f"Loaded {len(self.trajectory_cache)} trajectories from database")
204
+
205
+ async def _setup_online(self):
206
+ """Setup simulation bridge connection"""
207
+ logger.info(f"Connecting to simulation bridge at {self.config.simulation_bridge_url}...")
208
+
209
+ self.simulation_bridge = SimulationBridge(
210
+ base_url=self.config.simulation_bridge_url,
211
+ )
212
+ await self.simulation_bridge.__aenter__()
213
+
214
+ # Initialize with archetypes
215
+ archetypes = list(self.config.archetype_distribution.keys())
216
+ await self.simulation_bridge.initialize(
217
+ num_npcs=self.config.bridge_num_npcs,
218
+ archetypes=archetypes,
219
+ )
220
+
221
+ logger.info(f"Simulation bridge connected with {len(self.simulation_bridge.npc_ids)} NPCs")
222
+
223
+ async def get_next_item(self) -> Tuple[Any, str]:
224
+ """
225
+ Get next item for training.
226
+
227
+ Randomly decides between offline and online based on online_ratio.
228
+ """
229
+ self.iter += 1
230
+
231
+ # Decide online vs offline based on ratio
232
+ use_online = random.random() < self.online_ratio
233
+
234
+ # If online selected but not available, fall back to offline
235
+ if use_online and (not self.simulation_bridge or not self.simulation_bridge.is_initialized):
236
+ use_online = False
237
+
238
+ # If offline selected but no trajectories, use online
239
+ if not use_online and len(self.trajectory_cache) == 0:
240
+ use_online = True
241
+
242
+ if use_online:
243
+ self.online_count += 1
244
+ return await self._get_online_item()
245
+ else:
246
+ self.offline_count += 1
247
+ return self._get_offline_item()
248
+
249
+ async def _get_online_item(self) -> Tuple["PoolScenario", str]:
250
+ """Get a scenario from simulation bridge"""
251
+ from .scenario_pool import Scenario as PoolScenario, PortfolioState
252
+
253
+ npc_ids = self.simulation_bridge.npc_ids
254
+ npc_id = npc_ids[self._bridge_npc_index % len(npc_ids)]
255
+ self._bridge_npc_index += 1
256
+
257
+ bridge_scenario = await self.simulation_bridge.get_scenario(npc_id)
258
+ archetype = bridge_scenario.archetype
259
+
260
+ # Convert to Scenario format used by scoring
261
+ scenario = PoolScenario(
262
+ id=f"bridge-{npc_id}-{self.iter}",
263
+ source="production",
264
+ archetype_focus=archetype,
265
+ difficulty="medium",
266
+ portfolio=PortfolioState(
267
+ balance=bridge_scenario.balance,
268
+ positions=[],
269
+ ),
270
+ )
271
+
272
+ # Add market data from bridge
273
+ for m in bridge_scenario.market_state.prediction_markets:
274
+ scenario.add_market({
275
+ "id": m.id,
276
+ "question": m.question,
277
+ "yesPrice": m.yes_price,
278
+ "noPrice": m.no_price,
279
+ })
280
+
281
+ for m in bridge_scenario.market_state.perp_markets:
282
+ scenario.add_perpetual({
283
+ "ticker": m.ticker,
284
+ "markPrice": m.current_price,
285
+ "change24h": m.change_percent_24h,
286
+ })
287
+
288
+ # Store bridge scenario for action execution
289
+ scenario.metadata["bridge_scenario"] = bridge_scenario
290
+ scenario.metadata["npc_id"] = npc_id
291
+ scenario.metadata["mode"] = "online"
292
+
293
+ return (scenario, archetype)
294
+
295
+ def _get_offline_item(self) -> Tuple[Dict, str]:
296
+ """Get a trajectory from cached database trajectories"""
297
+ if not self.trajectory_cache:
298
+ raise RuntimeError("No trajectories in cache")
299
+
300
+ # Round-robin through cache
301
+ traj = self.trajectory_cache[self.current_cache_idx]
302
+ self.current_cache_idx = (self.current_cache_idx + 1) % len(self.trajectory_cache)
303
+
304
+ archetype = traj.get("archetype", "trader")
305
+
306
+ # Add source metadata
307
+ traj_copy = copy.deepcopy(traj)
308
+ traj_copy["source"] = "offline"
309
+
310
+ return (traj_copy, archetype)
311
+
312
+ async def collect_trajectories(self, item: Tuple[Any, str]) -> Tuple[Optional[ScoredDataGroup], List]:
313
+ """
314
+ Collect and score trajectories.
315
+
316
+ Delegates to appropriate handler based on item source.
317
+ """
318
+ data, archetype = item
319
+
320
+ # Check if it's a Scenario (online) or Dict (offline)
321
+ if hasattr(data, "metadata") and data.metadata.get("source") == "online":
322
+ return await self._collect_online(data, archetype)
323
+ else:
324
+ return await self._collect_offline(data, archetype)
325
+
326
+ async def _collect_online(self, scenario: "Scenario", archetype: str) -> Tuple[Optional[ScoredDataGroup], List]:
327
+ """Collect online rollouts via simulation bridge"""
328
+ from .online_env import build_trading_system_prompt, build_observation_prompt
329
+ from .quality_scorer import score_response
330
+ from .format_validator import validate_response_format
331
+
332
+ # Build messages
333
+ system_prompt = build_trading_system_prompt(archetype)
334
+ user_prompt = build_observation_prompt(scenario)
335
+
336
+ messages = [
337
+ {"role": "system", "content": system_prompt},
338
+ {"role": "user", "content": user_prompt},
339
+ ]
340
+
341
+ # Generate completions using managed_server
342
+ async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
343
+ chat_completions = await managed.chat_completion(
344
+ messages=messages,
345
+ n=self.config.group_size,
346
+ max_tokens=self.config.max_response_tokens,
347
+ temperature=self.config.temperature,
348
+ )
349
+
350
+ state = managed.get_state()
351
+ nodes = state["nodes"]
352
+
353
+ if not nodes or len(nodes) < 2:
354
+ logger.warning("Insufficient nodes from managed_server")
355
+ return None, []
356
+
357
+ # Process and score completions
358
+ rollout_data = []
359
+ for i, choice in enumerate(chat_completions.choices):
360
+ if i >= len(nodes):
361
+ break
362
+
363
+ node = nodes[i]
364
+ response_content = choice.message.content or ""
365
+
366
+ # Score the response
367
+ quality = score_response(
368
+ response=response_content,
369
+ archetype=archetype,
370
+ execute_action=False,
371
+ )
372
+
373
+ format_result = validate_response_format(response_content)
374
+
375
+ # Calculate final score
376
+ base_score = quality.combined_format_score * 0.4 + quality.reasoning_score * 0.3
377
+ action_bonus = 0.3 if format_result.is_valid else 0.0
378
+ final_score = base_score + action_bonus
379
+
380
+ rollout_data.append({
381
+ "tokens": node.tokens,
382
+ "masks": node.masked_tokens,
383
+ "score": final_score,
384
+ })
385
+
386
+ # Center scores
387
+ scores = [r["score"] for r in rollout_data]
388
+ mean_score = sum(scores) / len(scores)
389
+
390
+ # Build ScoredDataGroup
391
+ scored_group = ScoredDataGroup(
392
+ tokens=[r["tokens"] for r in rollout_data],
393
+ masks=[r["masks"] for r in rollout_data],
394
+ scores=[s - mean_score for s in scores],
395
+ )
396
+
397
+ return scored_group, []
398
+
399
+ async def _collect_offline(self, traj: Dict, archetype: str) -> Tuple[Optional[ScoredDataGroup], List]:
400
+ """Collect offline rollouts from database trajectory"""
401
+ from .rewards import archetype_composite_reward, BehaviorMetrics
402
+ from .quality_scorer import score_response
403
+ from .format_validator import validate_response_format
404
+ from .tokenization_utils import tokenize_for_trainer
405
+
406
+ # Build messages from trajectory
407
+ scenario_context = traj.get("scenario_context", {})
408
+ model_response = traj.get("model_response", "")
409
+
410
+ if not model_response:
411
+ return None, []
412
+
413
+ # Build chat messages
414
+ messages = [
415
+ {"role": "system", "content": f"You are a {archetype} trading agent."},
416
+ {"role": "user", "content": str(scenario_context)},
417
+ {"role": "assistant", "content": model_response},
418
+ ]
419
+
420
+ # Get vLLM URL for generation
421
+ vllm_base_url = self._server_configs[0].base_url if self._server_configs else "http://localhost:9001/v1"
422
+ model_name = self.config.tokenizer_name
423
+
424
+ # Generate N completions for the same prompt
425
+ import aiohttp
426
+
427
+ prompt_messages = messages[:-1] # Exclude assistant response
428
+
429
+ async with aiohttp.ClientSession() as session:
430
+ async with session.post(
431
+ f"{vllm_base_url}/chat/completions",
432
+ json={
433
+ "model": model_name,
434
+ "messages": prompt_messages,
435
+ "max_tokens": 512,
436
+ "n": self.config.group_size,
437
+ "temperature": 0.7,
438
+ },
439
+ ) as resp:
440
+ if resp.status != 200:
441
+ logger.warning(f"vLLM request failed: {resp.status}")
442
+ return None, []
443
+ result = await resp.json()
444
+
445
+ choices = result.get("choices", [])
446
+ if len(choices) < 2:
447
+ return None, []
448
+
449
+ # Score each completion
450
+ rollout_data = []
451
+ for choice in choices:
452
+ response_content = choice.get("message", {}).get("content", "")
453
+
454
+ # Build full messages
455
+ full_messages = copy.deepcopy(prompt_messages)
456
+ full_messages.append({"role": "assistant", "content": response_content})
457
+
458
+ # Tokenize with proper masking
459
+ token_result = tokenize_for_trainer(
460
+ self.tokenizer,
461
+ full_messages,
462
+ train_on_all_assistant_turns=True,
463
+ )
464
+
465
+ # Score
466
+ quality = score_response(
467
+ response=response_content,
468
+ archetype=archetype,
469
+ execute_action=False,
470
+ )
471
+
472
+ format_result = validate_response_format(response_content)
473
+
474
+ base_score = quality.combined_format_score * 0.4 + quality.reasoning_score * 0.3
475
+ action_bonus = 0.3 if format_result.is_valid else 0.0
476
+ final_score = base_score + action_bonus
477
+
478
+ rollout_data.append({
479
+ "tokens": token_result["input_ids"],
480
+ "masks": token_result["masks"],
481
+ "score": final_score,
482
+ })
483
+
484
+ # Center scores and add small noise to prevent identical scores
485
+ scores = [r["score"] + random.uniform(-0.01, 0.01) for r in rollout_data]
486
+ mean_score = sum(scores) / len(scores)
487
+
488
+ scored_group = ScoredDataGroup(
489
+ tokens=[r["tokens"] for r in rollout_data],
490
+ masks=[r["masks"] for r in rollout_data],
491
+ scores=[s - mean_score for s in scores],
492
+ )
493
+
494
+ return scored_group, []
495
+
496
+ async def cleanup(self):
497
+ """Clean up resources"""
498
+ if self.simulation_bridge:
499
+ logger.info("Cleaning up simulation bridge...")
500
+ await self.simulation_bridge.reset()
501
+ await self.simulation_bridge.__aexit__(None, None, None)
502
+ self.simulation_bridge = None
503
+
504
+ if self.db_pool:
505
+ logger.info("Closing database pool...")
506
+ await self.db_pool.close()
507
+ self.db_pool = None
508
+
509
+ logger.info(f"Hybrid stats: online={self.online_count}, offline={self.offline_count}")
510
+
511
+ async def evaluate(self):
512
+ """Periodic evaluation logging"""
513
+ total = self.online_count + self.offline_count
514
+ if total > 0:
515
+ actual_online_ratio = self.online_count / total
516
+ logger.info(f"Hybrid stats: total={total}, online={self.online_count} ({actual_online_ratio:.1%}), "
517
+ f"offline={self.offline_count} ({1-actual_online_ratio:.1%})")
518
+
519
+
520
if __name__ == "__main__":
    # Launch the atropos CLI entry point for this environment.
    BabylonHybridEnv.cli()
522
+