@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,17 @@
1
"""
Compatibility module exposing neutral RLAIF environment names.
"""

# Re-export the environment classes from babylon_env so callers can import
# them under this neutral module path. Both the RLAIF* and Babylon* names
# come from the same module — presumably aliases of the same classes;
# confirm in babylon_env.
from .babylon_env import (
    RLAIFEnv,
    RLAIFEnvConfig,
    BabylonRLAIFEnv,
    BabylonEnvConfig,
)

# Explicit public API of this compatibility shim.
__all__ = [
    "RLAIFEnv",
    "RLAIFEnvConfig",
    "BabylonRLAIFEnv",
    "BabylonEnvConfig",
]
@@ -0,0 +1,502 @@
1
+ """
2
+ Babylon Fast Rollout Generator
3
+
4
+ Generates high-quality rollouts at maximum speed for RL training.
5
+ Captures the COMPLETE agent tick including all thinking, planning, and execution.
6
+
7
+ A complete agent tick consists of:
8
+ 1. Environment Observation - What the agent sees
9
+ 2. Thinking/Reasoning - Internal deliberation
10
+ 3. Planning - What actions to take
11
+ 4. Action Execution - The actual action
12
+ 5. Feedback - Result and reward
13
+
14
+ We need to capture ALL of this for training.
15
+ """
16
+
17
+ import asyncio
18
+ import json
19
+ import logging
20
+ from dataclasses import dataclass, field
21
+ from typing import Callable, Protocol
22
+ import time
23
+
24
+ from ..models import (
25
+ BabylonTrajectory,
26
+ LLMCall,
27
+ Action,
28
+ EnvironmentState,
29
+ )
30
+ from .quality_utils import (
31
+ calculate_trajectory_quality_score,
32
+ build_trajectory_from_ticks,
33
+ state_to_observation,
34
+ state_to_env_state,
35
+ calculate_detailed_tick_quality
36
+ )
37
+
38
+ from .rewards import TrajectoryRewardInputs, composite_reward, calculate_risk_reward
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
@dataclass
class AgentTickData:
    """
    Complete data for a single agent tick.

    Captures everything the agent did in one tick: the observation it
    received, every LLM call it made (thinking, planning, action), the
    final action it took, and the feedback/reward that came back.
    """
    tick_number: int
    timestamp: int

    # Environment observation
    observation: dict
    environment_state: EnvironmentState

    # All LLM calls during this tick
    llm_calls: list[LLMCall] = field(default_factory=list)

    # The reasoning chain (concatenated thinking)
    reasoning_chain: str = ""

    # Final action
    action: Action | None = None

    # Feedback from environment
    feedback: dict = field(default_factory=dict)
    reward: float = 0.0

    def get_full_context(self) -> str:
        """Render this tick as a single human-readable context string."""
        # Start with the observation header and its pretty-printed payload.
        sections: list[str] = [
            f"=== OBSERVATION (Tick {self.tick_number}) ===",
            json.dumps(self.observation, indent=2),
        ]

        # Replay every LLM call in order (1-indexed for readability).
        for idx, llm_call in enumerate(self.llm_calls, 1):
            sections.extend((
                f"\n=== LLM CALL {idx} ({llm_call.purpose}) ===",
                f"System: {llm_call.system_prompt}",
                f"User: {llm_call.user_prompt}",
                f"Response: {llm_call.response}",
            ))
            if llm_call.reasoning:
                sections.append(f"Reasoning: {llm_call.reasoning}")

        # The final action, when one was taken this tick.
        if self.action:
            sections.append("\n=== ACTION ===")
            sections.append(f"Type: {self.action.action_type}")
            sections.append(f"Parameters: {json.dumps(self.action.parameters)}")
            if self.action.reasoning:
                sections.append(f"Reasoning: {self.action.reasoning}")

        # Environment feedback plus the reward earned this tick.
        if self.feedback:
            sections.append("\n=== FEEDBACK ===")
            sections.append(json.dumps(self.feedback, indent=2))
            sections.append(f"Reward: {self.reward}")

        return "\n".join(sections)
106
+
107
+
108
@dataclass
class RolloutConfig:
    """Configuration for rollout generation"""

    # Speed settings
    fast_forward: bool = True        # skip waiting between ticks (no artificial delays)
    parallel_agents: int = 4         # number of agents run concurrently
    max_ticks_per_agent: int = 100   # hard cap on ticks per rollout

    # Quality settings — enforced per tick by FastRolloutGenerator._validate_tick_quality
    min_llm_calls_per_tick: int = 1  # ticks with fewer LLM calls fail validation (0 disables)
    require_action: bool = True      # ticks without an action fail validation

    # Database settings
    database_url: str = ""           # empty string means no database configured
123
+
124
+
125
@dataclass
class RolloutResult:
    """Result of a rollout generation run"""
    agent_id: str                 # agent that produced this rollout
    trajectory_id: str            # id of the built trajectory ("rollout-<agent>-<ms>")
    ticks_completed: int          # number of ticks actually executed
    total_duration_ms: int        # wall-clock duration of the whole rollout
    avg_tick_duration_ms: float   # mean wall-clock time per tick
    total_llm_calls: int          # sum of LLM calls across all ticks
    total_reward: float           # sum of per-tick composite rewards
    final_pnl: float              # taken from the built trajectory (0.0 if none)
    quality_score: float  # 0-1 based on completeness
    # None when trajectory construction failed or produced too few steps
    trajectory: BabylonTrajectory | None = None
138
+
139
+
140
class AgentRunner(Protocol):
    """Protocol for agent implementations - consistent with FastSimulator

    Structural interface (PEP 544): any object exposing a matching async
    ``run_tick`` is accepted; no inheritance from this class is required.
    """

    async def run_tick(
        self,
        agent_id: str,
        observation: dict,
        env_state: EnvironmentState,
    ) -> AgentTickData:
        """Run a single tick and return complete tick data

        Args:
            agent_id: Unique identifier of the agent being stepped.
            observation: Observation dict derived from the simulation state.
            env_state: Structured environment state for the same tick.

        Returns:
            AgentTickData capturing the tick's LLM calls and chosen action.
        """
        ...
151
+
152
+
153
class FastRolloutGenerator:
    """
    Generates rollouts at maximum speed while maintaining quality.

    Key features:
    - Fast-forward mode skips all waiting
    - Parallel agent execution
    - Complete tick capture (observation → thinking → action → feedback)
    - Quality validation

    NOTE(review): the ``simulation`` object is driven through a camelCase API
    (isComplete / getGameState / performAction / advanceTick) — presumably a
    bridge to the TypeScript SimulationEngine; confirm against the bridge
    module before changing call sites.
    """

    def __init__(self, config: RolloutConfig):
        self.config = config
        # Lifetime counters across all rollouts from this instance; used by
        # the throughput summary in generate_parallel_rollouts().
        self.rollouts_generated = 0
        self.total_ticks = 0
        self.start_time: float | None = None

    async def generate_rollout(
        self,
        agent: AgentRunner,
        agent_id: str,
        simulation,  # SimulationEngine instance
    ) -> RolloutResult:
        """
        Generate a single rollout from an agent running through simulation.

        Args:
            agent: Agent implementation
            agent_id: Unique agent identifier
            simulation: Simulation engine providing environment

        Returns:
            RolloutResult with complete trajectory data
        """
        start_time = time.time()
        tick_durations: list[float] = []
        all_ticks: list[AgentTickData] = []
        total_llm_calls = 0
        total_reward = 0.0

        # Millisecond timestamp makes the id unique per run for the same agent.
        trajectory_id = f"rollout-{agent_id}-{int(start_time * 1000)}"

        logger.info(f"Starting rollout generation for agent {agent_id}")

        # Run through simulation until it completes or the tick cap is hit.
        tick_number = 0
        while not simulation.isComplete() and tick_number < self.config.max_ticks_per_agent:
            tick_start = time.time()

            # Get observation from simulation and convert to the agent's view.
            game_state = simulation.getGameState()
            observation = state_to_observation(game_state)
            env_state = state_to_env_state(game_state, agent_id)

            # Agent processes tick (captures all LLM calls); tick_number and
            # timestamp are overwritten here so the agent need not set them.
            tick_data = await agent.run_tick(agent_id, observation, env_state)
            tick_data.tick_number = tick_number
            tick_data.timestamp = int(time.time() * 1000)

            # Execute action in simulation if provided. "wait" actions are
            # never sent to the engine, so a waiting tick keeps empty feedback.
            if tick_data.action and tick_data.action.action_type != "wait":
                result = await simulation.performAction(
                    tick_data.action.action_type,
                    tick_data.action.parameters,
                )
                tick_data.feedback = result
                tick_data.action.success = result.get("success", False)
                tick_data.action.result = result.get("result")
                tick_data.action.error = result.get("error")

            # Calculate reward for this tick using The Judge logic.
            tick_data.reward = self._calculate_tick_reward(
                tick_data, env_state)
            total_reward += tick_data.reward

            # Validate tick quality — failures are logged but the tick is
            # still stored; filtering happens later via quality scores.
            if not self._validate_tick_quality(tick_data):
                logger.warning(f"Tick {tick_number} failed quality check")

            # Store tick
            all_ticks.append(tick_data)
            total_llm_calls += len(tick_data.llm_calls)

            # Track timing
            tick_duration = time.time() - tick_start
            tick_durations.append(tick_duration)

            # Advance simulation (no artificial delay in fast-forward mode)
            simulation.advanceTick()
            tick_number += 1

            # Log progress every 50 ticks with the rolling average duration.
            if tick_number % 50 == 0:
                avg_tick = sum(tick_durations[-50:]) / \
                    min(50, len(tick_durations))
                logger.info(
                    f"Tick {tick_number}: avg {avg_tick*1000:.1f}ms/tick")

        total_duration_ms = int((time.time() - start_time) * 1000)
        avg_tick_duration = sum(tick_durations) / \
            len(tick_durations) if tick_durations else 0

        # Build trajectory from ticks (may return None below the min-steps bar).
        trajectory = build_trajectory_from_ticks(
            trajectory_id=trajectory_id,
            agent_id=agent_id,
            ticks=all_ticks,
            min_steps=1,
        )

        # Calculate quality score (0-1 completeness measure over all ticks).
        quality_score = calculate_trajectory_quality_score(all_ticks)

        result = RolloutResult(
            agent_id=agent_id,
            trajectory_id=trajectory_id,
            ticks_completed=tick_number,
            total_duration_ms=total_duration_ms,
            avg_tick_duration_ms=avg_tick_duration * 1000,
            total_llm_calls=total_llm_calls,
            total_reward=total_reward,
            final_pnl=trajectory.final_pnl if trajectory else 0.0,
            quality_score=quality_score,
            trajectory=trajectory,
        )

        self.rollouts_generated += 1
        self.total_ticks += tick_number

        logger.info(
            f"Rollout complete: {tick_number} ticks in {total_duration_ms}ms "
            f"({avg_tick_duration*1000:.1f}ms/tick), quality={quality_score:.2f}"
        )

        return result

    async def generate_parallel_rollouts(
        self,
        agents: list[tuple[AgentRunner, str]],  # (agent, agent_id)
        simulation_factory: Callable,
    ) -> list[RolloutResult]:
        """
        Generate multiple rollouts in parallel.

        Args:
            agents: List of (agent, agent_id) tuples
            simulation_factory: Factory to create simulation instances
                (each agent gets its own freshly-initialized simulation)

        Returns:
            List of RolloutResults (failed rollouts are logged and dropped)
        """
        self.start_time = time.time()

        logger.info(
            f"Starting parallel rollout generation for {len(agents)} agents")

        # Create tasks for each agent — one isolated simulation per agent.
        tasks = []
        for agent, agent_id in agents:
            simulation = simulation_factory()
            simulation.initialize()
            tasks.append(self.generate_rollout(agent, agent_id, simulation))

        # Run all in parallel; return_exceptions keeps one failed rollout
        # from cancelling the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out errors
        valid_results = []
        for r in results:
            if isinstance(r, Exception):
                logger.error(f"Rollout failed: {r}")
            else:
                valid_results.append(r)

        total_time = time.time() - self.start_time
        logger.info(
            f"Parallel rollout generation complete: "
            f"{len(valid_results)}/{len(agents)} succeeded, "
            f"{self.total_ticks} total ticks in {total_time:.1f}s "
            f"({self.total_ticks/total_time:.1f} ticks/s)"
        )

        return valid_results

    def _calculate_tick_reward(
        self,
        tick_data: AgentTickData,
        env_state: EnvironmentState,
    ) -> float:
        """
        Calculate reward for a single tick.

        Combines:
        1. Financial Performance (PnL)
        2. Format Compliance (XML validation)
        3. Reasoning Alignment (Financial Literacy)
        4. Risk Management (Exposure penalties)

        NOTE(review): reads ``env_state.open_positions`` and
        ``env_state.agent_balance`` — confirm EnvironmentState guarantees
        these fields are populated for every tick.
        """
        # 1. Quality Scores (Format & Reasoning)
        fmt_score, rsn_score = calculate_detailed_tick_quality(
            tick_data.llm_calls,
            tick_data.action,
            tick_data.feedback
        )

        # 2. Risk Calculation
        # Exposure proxy: active positions / max reasonable positions (e.g. 10)
        # Or ideally use a dedicated exposure field if available in env_state
        exposure = min(1.0, env_state.open_positions / 10.0)

        action_type = "wait"
        if tick_data.action:
            action_type = tick_data.action.action_type

        # A negative risk reward marks this tick as one risky action.
        risk_penalty_count = 0
        if calculate_risk_reward(exposure, action_type) < 0:
            risk_penalty_count = 1

        # 3. Financials (PnL for this tick). Missing feedback (e.g. "wait"
        # ticks) contributes zero PnL.
        pnl = tick_data.feedback.get("pnl_delta", 0.0)

        # Build Inputs for Composite Reward — a single-tick view, so action
        # counts are at most 1.
        inputs = TrajectoryRewardInputs(
            final_pnl=pnl,
            starting_balance=env_state.agent_balance,
            end_balance=env_state.agent_balance + pnl,
            format_score=fmt_score,
            reasoning_score=rsn_score,
            risky_actions_count=risk_penalty_count,
            total_actions=1 if tick_data.action else 0,
            successful_actions=1 if tick_data.action and tick_data.action.success else 0
        )

        return composite_reward(inputs)

    def _validate_tick_quality(self, tick_data: AgentTickData) -> bool:
        """Validate that tick data meets quality requirements

        Checks, in order: minimum LLM-call count, action presence (when
        required by config), and non-empty responses on every LLM call.
        Returns False on the first violation.
        """
        # Must have LLM calls if configured
        if self.config.min_llm_calls_per_tick > 0:
            if len(tick_data.llm_calls) < self.config.min_llm_calls_per_tick:
                return False

        # Must have action if configured
        if self.config.require_action and tick_data.action is None:
            return False

        # LLM calls must have non-empty responses
        for call in tick_data.llm_calls:
            if not call.response or len(call.response.strip()) == 0:
                return False

        return True
405
+
406
+
407
class RolloutQualityValidator:
    """
    Validates that rollouts meet quality standards for training.
    """

    @staticmethod
    def validate_rollout(result: RolloutResult) -> tuple[bool, list[str]]:
        """
        Validate a rollout result.

        Args:
            result: The rollout to check.

        Returns:
            (is_valid, list of issues) — is_valid is True only when no
            issues were found.
        """
        issues = []

        # Must have trajectory — without one, nothing else can be checked.
        if result.trajectory is None:
            issues.append("No trajectory data")
            return False, issues

        # Minimum ticks
        if result.ticks_completed < 5:
            issues.append(f"Too few ticks: {result.ticks_completed} < 5")

        # Must have LLM calls (at least one per completed tick on average)
        if result.total_llm_calls < result.ticks_completed:
            issues.append(
                f"Low LLM call rate: {result.total_llm_calls} calls for {result.ticks_completed} ticks")

        # Quality score threshold
        if result.quality_score < 0.5:
            issues.append(
                f"Quality score too low: {result.quality_score:.2f} < 0.5")

        # Check trajectory steps
        traj = result.trajectory
        for i, step in enumerate(traj.steps):
            # Each step should have LLM calls
            if not step.llm_calls:
                issues.append(f"Step {i} has no LLM calls")

            # Each LLM call should have content
            for j, call in enumerate(step.llm_calls):
                if not call.user_prompt or not call.response:
                    issues.append(
                        f"Step {i}, call {j}: missing prompt or response")
                if not call.system_prompt:
                    issues.append(f"Step {i}, call {j}: missing system prompt")

        is_valid = len(issues) == 0
        return is_valid, issues

    @staticmethod
    def print_quality_report(results: list[RolloutResult]) -> None:
        """Print a quality report for a batch of rollouts.

        BUG FIX: the original raised ZeroDivisionError for an empty batch
        (valid_count/total, total_quality/total) and when no ticks were
        completed (total_llm_calls/total_ticks). Both cases are now guarded.
        """
        print("\n" + "=" * 60)
        print(" ROLLOUT QUALITY REPORT")
        print("=" * 60)

        total = len(results)
        if total == 0:
            # Nothing to aggregate — report and bail out before any division.
            print("\nNo rollouts to report.")
            print("=" * 60 + "\n")
            return

        valid_count = 0
        total_ticks = 0
        total_llm_calls = 0
        total_quality = 0.0
        all_issues: list[str] = []

        for result in results:
            is_valid, issues = RolloutQualityValidator.validate_rollout(result)
            if is_valid:
                valid_count += 1
            all_issues.extend(issues)

            total_ticks += result.ticks_completed
            total_llm_calls += result.total_llm_calls
            total_quality += result.quality_score

        print(
            f"\nValid rollouts: {valid_count}/{total} ({valid_count/total*100:.1f}%)")
        print(f"Total ticks: {total_ticks}")
        print(f"Total LLM calls: {total_llm_calls}")
        print(f"Average quality score: {total_quality/total:.2f}")
        if total_ticks > 0:
            print(f"LLM calls per tick: {total_llm_calls/total_ticks:.1f}")
        else:
            # Avoid division by zero when every rollout completed 0 ticks.
            print("LLM calls per tick: n/a")

        if all_issues:
            print(f"\nIssues found ({len(all_issues)} total):")
            # Group and count issues by their prefix (text before the first
            # colon) so repeated issue kinds collapse into one line.
            issue_counts: dict[str, int] = {}
            for issue in all_issues:
                # Normalize issue text
                key = issue.split(":")[0] if ":" in issue else issue
                issue_counts[key] = issue_counts.get(key, 0) + 1

            # Show the ten most frequent issue kinds, most common first.
            for issue, count in sorted(issue_counts.items(), key=lambda x: -x[1])[:10]:
                print(f" - {issue}: {count} occurrences")

        print("=" * 60 + "\n")