@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,656 @@
1
+ """
2
+ Multi-Turn Episode Manager for GRPO Training
3
+
4
+ ⚠️ STATUS: NOT YET INTEGRATED
5
+ This module is ready to use but not currently called by babylon_env.py or online_env.py.
6
+ To integrate, see TRAINING_ROADMAP.md Phase 4.
7
+
8
+ Handles multi-turn trading episodes with proper credit assignment
9
+ using Generalized Advantage Estimation (GAE).
10
+
11
+ For trading scenarios:
12
+ - A single action's value depends on future market movements
13
+ - Subsequent actions affect overall trajectory value
14
+ - Final episode outcome determines success
15
+
16
+ This module provides:
17
+ 1. TurnData - Structure for individual turn data
18
+ 2. EpisodeBuffer - Buffer for collecting episode turns
19
+ 3. MultiTurnEpisodeManager - GAE-based advantage computation
20
+ 4. Trajectory utilities for GRPO training
21
+
22
+ Usage:
23
+ manager = MultiTurnEpisodeManager(gamma=0.99, gae_lambda=0.95)
24
+
25
+ # Collect episode turns
26
+ episode = EpisodeBuffer(scenario_id="test-1")
27
+ for turn in range(max_turns):
28
+ turn_data = TurnData(...)
29
+ episode.add_turn(turn_data)
30
+ if done:
31
+ break
32
+
33
+ # Compute advantages
34
+ manager.compute_advantages(episode.turns)
35
+
36
+ # Create training items
37
+ training_items = manager.create_training_items(episode.turns, tokenizer)
38
+ """
39
+
40
+ import logging
41
+ from dataclasses import dataclass, field
42
+ from datetime import datetime, timezone
43
+ from typing import Any, Dict, List, Optional, Tuple
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ # =============================================================================
49
+ # Turn and Episode Data Structures
50
+ # =============================================================================
51
+
52
+
53
@dataclass
class TurnData:
    """
    Snapshot of a single turn within a multi-turn episode.

    Carries everything needed to train on this turn, including the
    credit-assignment fields (value / advantage / return_to_go) that are
    populated later by the GAE computation.
    """
    # --- Identification ---
    turn_number: int
    episode_id: str = ""

    # --- Environment state observed at this turn ---
    state: Dict[str, Any] = field(default_factory=dict)
    observation: str = ""

    # --- Action the agent took ---
    action: Dict[str, Any] = field(default_factory=dict)
    action_text: str = ""
    action_type: str = ""

    # --- Chat history up to this point ---
    messages: List[Dict[str, str]] = field(default_factory=list)

    # --- Immediate reward received after the action ---
    reward: float = 0.0

    # --- Quality signals ---
    format_score: float = 0.0
    reasoning_score: float = 0.0

    # --- Episode termination ---
    done: bool = False
    termination_reason: str = ""

    # --- Credit-assignment values (filled in by GAE) ---
    value: float = 0.0
    advantage: float = 0.0
    return_to_go: float = 0.0

    # --- Token-level training data ---
    tokens: List[int] = field(default_factory=list)
    masks: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)

    # --- Bookkeeping ---
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict:
        """Return a compact, log-friendly summary of this turn."""
        # Free text is truncated and floats rounded so serialized logs
        # stay small and readable.
        return {
            "turn_number": self.turn_number,
            "episode_id": self.episode_id,
            "action_type": self.action_type,
            "action_text": self.action_text[:200],
            "reward": round(self.reward, 4),
            "format_score": round(self.format_score, 3),
            "reasoning_score": round(self.reasoning_score, 3),
            "done": self.done,
            "value": round(self.value, 4),
            "advantage": round(self.advantage, 4),
            "return_to_go": round(self.return_to_go, 4),
        }
116
+
117
+
118
@dataclass
class EpisodeBuffer:
    """
    Buffer for collecting turns during an episode.

    Accumulates TurnData as the episode unfolds and, once a terminal turn
    arrives, finalizes episode-level metrics (length, total reward,
    success flag, end time).
    """
    # Episode identification
    episode_id: str
    scenario_id: str = ""
    archetype: str = "trader"

    # Turns collected so far, in chronological order
    turns: List[TurnData] = field(default_factory=list)

    # Episode-level metrics (populated by _finalize after completion)
    total_reward: float = 0.0
    total_pnl: float = 0.0
    episode_length: int = 0
    completed: bool = False
    success: bool = False

    # Timing
    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    end_time: Optional[datetime] = None

    def add_turn(self, turn: TurnData) -> None:
        """Append a turn, stamping it with this episode's id.

        If the turn is terminal (turn.done), the episode is finalized
        immediately.
        """
        turn.episode_id = self.episode_id
        self.turns.append(turn)

        if turn.done:
            self._finalize()

    def _finalize(self) -> None:
        """Compute episode-level metrics once the episode has ended."""
        self.completed = True
        self.end_time = datetime.now(timezone.utc)
        self.episode_length = len(self.turns)
        self.total_reward = sum(t.reward for t in self.turns)
        # An episode counts as successful when its cumulative reward is
        # positive. (Fix: the previous version bound the final turn to an
        # unused local behind a redundant guard; that dead code is removed —
        # for empty `turns` the sum is 0, leaving success False either way.)
        self.success = self.total_reward > 0

    def get_messages(self) -> List[Dict[str, str]]:
        """Return a copy of the full message history (taken from the last turn)."""
        if self.turns:
            return self.turns[-1].messages.copy()
        return []

    def get_trajectory(self) -> List[Tuple[str, float, str]]:
        """Return (action_type, reward, truncated action_text) per turn."""
        return [
            (t.action_type, t.reward, t.action_text[:100])
            for t in self.turns
        ]

    def to_dict(self) -> Dict:
        """Serialize the episode (including per-turn summaries) for logging."""
        return {
            "episode_id": self.episode_id,
            "scenario_id": self.scenario_id,
            "archetype": self.archetype,
            "episode_length": self.episode_length,
            "total_reward": round(self.total_reward, 4),
            "total_pnl": round(self.total_pnl, 2),
            "completed": self.completed,
            "success": self.success,
            "turns": [t.to_dict() for t in self.turns],
        }
192
+
193
+
194
+ # =============================================================================
195
+ # Multi-Turn Episode Manager
196
+ # =============================================================================
197
+
198
+
199
@dataclass
class GAEConfig:
    """Hyperparameters controlling GAE advantage computation.

    Attributes:
        gamma: Discount factor applied to future rewards.
        gae_lambda: GAE lambda (bias/variance trade-off).
        normalize_advantages: Whether to z-normalize advantages batch-wide.
        clip_advantages: Whether to clamp extreme advantage values.
        advantage_clip: Symmetric clamp threshold (advantages kept in +/- this).
        use_value_normalization: Whether value estimates are normalized.
    """
    gamma: float = 0.99
    gae_lambda: float = 0.95
    normalize_advantages: bool = True
    clip_advantages: bool = True
    advantage_clip: float = 10.0
    use_value_normalization: bool = True
208
+
209
+
210
class MultiTurnEpisodeManager:
    """
    Manages multi-turn trading episodes with GAE-based credit assignment.

    For trading scenarios, a single action's value depends on:
    - Future market movements
    - Subsequent actions
    - Final episode outcome

    This manager computes per-turn advantages so credit can be assigned
    across turns in multi-turn settings.
    """

    def __init__(
        self,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        max_turns: int = 20,
        normalize_advantages: bool = True,
    ):
        """
        Args:
            gamma: Discount factor for future rewards.
            gae_lambda: GAE lambda controlling the bias/variance trade-off.
            max_turns: Maximum turns per episode (informational; surfaced
                through get_stats).
            normalize_advantages: Z-normalize advantages across a batch.
        """
        self.config = GAEConfig(
            gamma=gamma,
            gae_lambda=gae_lambda,
            normalize_advantages=normalize_advantages,
        )
        self.max_turns = max_turns

        # Statistics
        self._episodes_processed = 0
        self._total_turns = 0

    def compute_advantages(self, turns: List[TurnData]) -> None:
        """
        Compute GAE advantages for each turn of one episode.

        Modifies turns in-place, filling:
        - value: Monte Carlo estimate of the state value (= return-to-go)
        - return_to_go: cumulative discounted future reward
        - advantage: GAE advantage estimate

        Args:
            turns: List of turns in chronological order
        """
        if not turns:
            return

        # Step 1: return-to-go (cumulative discounted reward), computed
        # backwards from the final turn.
        cumulative = 0.0
        for turn in reversed(turns):
            cumulative = turn.reward + self.config.gamma * cumulative
            turn.return_to_go = cumulative

        # Step 2: record a Monte Carlo value estimate for logging/metadata.
        for turn in turns:
            turn.value = turn.return_to_go

        # Step 3: GAE with a zero critic baseline.
        #
        # BUG FIX: the previous implementation plugged value = return_to_go
        # into the TD error, which makes every delta — and therefore every
        # advantage — identically zero (r_t + gamma*RTG_{t+1} - RTG_t == 0),
        # so no turn carried any learning signal. Without a learned critic,
        # the correct degenerate baseline is V(s) = 0, for which GAE reduces
        # to the (gamma * lambda)-discounted sum of future rewards.
        gae = 0.0
        for turn in reversed(turns):
            if turn.done:
                # Terminal turn: no future value flows past the episode end.
                gae = 0.0
            delta = turn.reward  # r_t + gamma * V(t+1) - V(t) with V == 0
            gae = delta + self.config.gamma * self.config.gae_lambda * gae
            turn.advantage = gae

        # Step 4: clip extreme advantages to stabilize training.
        if self.config.clip_advantages:
            for turn in turns:
                turn.advantage = max(
                    -self.config.advantage_clip,
                    min(self.config.advantage_clip, turn.advantage),
                )

        self._episodes_processed += 1
        self._total_turns += len(turns)

    def compute_batch_advantages(
        self,
        episodes: List[List[TurnData]],
    ) -> None:
        """
        Compute advantages for a batch of episodes, then (optionally)
        z-normalize advantages across the whole batch.

        Args:
            episodes: List of episodes, each containing its turns
        """
        # Per-episode advantages first.
        for episode in episodes:
            self.compute_advantages(episode)

        # Then normalize across the batch if configured.
        if self.config.normalize_advantages and episodes:
            all_advantages = [t.advantage for ep in episodes for t in ep]

            if all_advantages:
                mean = sum(all_advantages) / len(all_advantages)

                if len(all_advantages) > 1:
                    variance = sum(
                        (a - mean) ** 2 for a in all_advantages
                    ) / len(all_advantages)
                    # Floor the std so (near-)constant advantages don't
                    # divide by ~0.
                    std = max(variance ** 0.5, 1e-8)
                else:
                    std = 1.0

                for episode in episodes:
                    for turn in episode:
                        turn.advantage = (turn.advantage - mean) / std

    def create_training_items(
        self,
        turns: List[TurnData],
        tokenizer,
        train_on_all_assistant_turns: bool = True,
    ) -> List[Dict]:
        """
        Create GRPO training items from episode turns.

        Args:
            turns: Turns with advantages already computed
            tokenizer: Tokenizer exposing apply_chat_template
            train_on_all_assistant_turns: Reserved flag; currently every
                turn with a non-trivial signal is included.

        Returns:
            List of dicts with tokens, masks, score (= advantage), and
            per-turn metadata, suitable for GRPO.
        """
        items = []

        for turn in turns:
            # Skip turns that carry no learning signal at all.
            if turn.advantage == 0.0 and turn.reward == 0.0:
                continue

            # Lazily tokenize from messages when tokens were not pre-computed.
            if not turn.tokens and turn.messages:
                tokens = self._tokenize_messages(tokenizer, turn.messages)
                turn.tokens = tokens
                turn.masks = self._create_masks(tokenizer, turn.messages, tokens)

            item = {
                "tokens": turn.tokens,
                "masks": turn.masks,
                "score": turn.advantage,  # Advantage acts as the GRPO score
                "metadata": {
                    "turn": turn.turn_number,
                    "episode_id": turn.episode_id,
                    "reward": turn.reward,
                    "value": turn.value,
                    "return_to_go": turn.return_to_go,
                    "action_type": turn.action_type,
                },
            }

            items.append(item)

        return items

    def _tokenize_messages(
        self,
        tokenizer,
        messages: List[Dict[str, str]],
    ) -> List[int]:
        """Tokenize chat messages via the tokenizer's chat template.

        Returns an empty list (and logs a warning) on failure.
        """
        try:
            tokens = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=False,
                return_tensors=None,
            )
            return tokens
        except Exception as e:
            logger.warning(f"Tokenization failed: {e}")
            return []

    def _create_masks(
        self,
        tokenizer,
        messages: List[Dict[str, str]],
        tokens: List[int],
    ) -> List[int]:
        """
        Create per-token training masks (1 = train on this token).

        Approximation: every token after the tokenized prompt (all messages
        before the last assistant message) is marked trainable. A precise
        implementation would walk the tokenizer's chat-template structure.
        """
        if not tokens:
            return []

        masks = [0] * len(tokens)

        if not messages:
            return masks

        # Locate the last assistant message.
        last_assistant_idx = -1
        for i, msg in enumerate(messages):
            if msg.get("role") == "assistant":
                last_assistant_idx = i

        if last_assistant_idx == -1:
            # No assistant turn present: nothing to train on.
            return masks

        # Tokenize everything before the last assistant message to estimate
        # the prompt length; the remainder is the trainable span.
        prompt_messages = messages[:last_assistant_idx]
        if prompt_messages:
            prompt_tokens = self._tokenize_messages(tokenizer, prompt_messages)
            prompt_len = len(prompt_tokens)
        else:
            prompt_len = 0

        for i in range(prompt_len, len(tokens)):
            masks[i] = 1

        return masks

    def get_stats(self) -> Dict:
        """Return manager-level processing statistics and configuration."""
        return {
            "episodes_processed": self._episodes_processed,
            "total_turns": self._total_turns,
            "avg_turns_per_episode": (
                self._total_turns / max(1, self._episodes_processed)
            ),
            "gamma": self.config.gamma,
            "gae_lambda": self.config.gae_lambda,
            "max_turns": self.max_turns,
        }
457
+
458
+
459
+ # =============================================================================
460
+ # Reward Shaping Utilities
461
+ # =============================================================================
462
+
463
+
464
def shape_trading_rewards(
    turns: List[TurnData],
    format_weight: float = 0.2,
    reasoning_weight: float = 0.1,
    pnl_weight: float = 0.5,
    action_weight: float = 0.2,
) -> None:
    """
    Shape rewards for trading episodes, in-place.

    Each turn's reward becomes a weighted combination of:
    - the raw (usually PnL-based) reward,
    - format quality (think tags, action JSON),
    - reasoning quality,
    - an action-quality term (active trades rewarded most, waiting a
      little, anything else penalized).

    Args:
        turns: Episode turns to reshape
        format_weight: Weight for the format score
        reasoning_weight: Weight for the reasoning score
        pnl_weight: Weight for the raw PnL-based reward
        action_weight: Weight for the action-quality term
    """
    trading_actions = {"buy", "sell", "open_perp", "close_perp"}

    for turn in turns:
        # Base activity score: trades > waiting > invalid/unknown actions.
        if turn.action_type in trading_actions:
            activity = 0.1
        elif turn.action_type == "wait":
            activity = 0.05
        else:
            activity = -0.1

        turn.reward = (
            turn.reward * pnl_weight
            + turn.format_score * format_weight
            + turn.reasoning_score * reasoning_weight
            + activity * action_weight
        )
513
+
514
+
515
def compute_episode_return(
    turns: List[TurnData],
    gamma: float = 0.99,
) -> float:
    """
    Compute the total discounted return of an episode.

    Args:
        turns: Episode turns in chronological order
        gamma: Discount factor

    Returns:
        Sum of gamma^i * reward_i over all turns (0.0 for an empty episode).
    """
    total = 0.0
    weight = 1.0  # gamma^i, built multiplicatively turn by turn

    for turn in turns:
        total += weight * turn.reward
        weight *= gamma

    return total
540
+
541
+
542
def normalize_episode_rewards(
    episodes: List[List[TurnData]],
) -> None:
    """
    Z-normalize rewards across all turns of all episodes, in-place.

    Useful for reducing variance during training. A no-op when fewer than
    two turns exist in total.

    Args:
        episodes: Episodes whose turn rewards should be normalized
    """
    flat = [turn.reward for episode in episodes for turn in episode]

    if len(flat) < 2:
        return

    count = len(flat)
    mean = sum(flat) / count
    variance = sum((r - mean) ** 2 for r in flat) / count
    # Floor the std so (near-)constant rewards don't divide by ~0.
    std = max(variance ** 0.5, 1e-8)

    for episode in episodes:
        for turn in episode:
            turn.reward = (turn.reward - mean) / std
568
+
569
+
570
+ # =============================================================================
571
+ # Episode Collectors
572
+ # =============================================================================
573
+
574
+
575
class EpisodeCollector:
    """
    Collects episodes with a consistent structure during rollouts.

    Tracks one in-flight episode at a time and archives completed ones,
    keeping at most ``max_episodes`` in memory.
    """

    def __init__(self, max_episodes: int = 1000):
        self.max_episodes = max_episodes
        self.episodes: List[EpisodeBuffer] = []
        self._current_episode: Optional[EpisodeBuffer] = None

    def start_episode(
        self,
        scenario_id: str,
        archetype: str = "trader",
    ) -> EpisodeBuffer:
        """Begin a fresh episode and make it the active one."""
        import uuid

        fresh = EpisodeBuffer(
            episode_id=f"ep-{uuid.uuid4().hex[:8]}",
            scenario_id=scenario_id,
            archetype=archetype,
        )
        self._current_episode = fresh
        return fresh

    def add_turn(self, turn: TurnData) -> None:
        """Record a turn on the active episode; archive it if terminal."""
        active = self._current_episode
        if active is None:
            raise RuntimeError("No active episode. Call start_episode first.")

        active.add_turn(turn)

        if turn.done:
            self._finalize_current()

    def _finalize_current(self) -> None:
        """Archive the active episode, trimming the history if oversized."""
        if self._current_episode is not None:
            self.episodes.append(self._current_episode)

            # Keep only the most recent max_episodes entries.
            if len(self.episodes) > self.max_episodes:
                self.episodes = self.episodes[-self.max_episodes:]

        self._current_episode = None

    def get_completed_episodes(self) -> List[EpisodeBuffer]:
        """Return every archived episode that reached completion."""
        return [ep for ep in self.episodes if ep.completed]

    def get_successful_episodes(self) -> List[EpisodeBuffer]:
        """Return archived episodes that completed successfully."""
        return [ep for ep in self.episodes if ep.completed and ep.success]

    def clear(self) -> None:
        """Drop all archived episodes and any in-flight episode."""
        self.episodes = []
        self._current_episode = None

    def get_stats(self) -> Dict:
        """Summarize collection statistics (counts, success rate, averages)."""
        completed = self.get_completed_episodes()
        successful = self.get_successful_episodes()
        # Clamp the denominator so empty collections don't divide by zero.
        denom = max(1, len(completed))

        return {
            "total_episodes": len(self.episodes),
            "completed_episodes": len(completed),
            "successful_episodes": len(successful),
            "success_rate": len(successful) / denom,
            "avg_episode_length": (
                sum(e.episode_length for e in completed) / denom
            ),
            "avg_reward": (
                sum(e.total_reward for e in completed) / denom
            ),
        }
656
+