@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,590 @@
1
+ """
2
+ Tests for Multi-Turn Episode Manager
3
+
4
+ Covers:
5
+ - TurnData structure
6
+ - EpisodeBuffer management
7
+ - GAE advantage computation
8
+ - Reward shaping
9
+ - Episode collection
10
+ """
11
+
12
+ from datetime import datetime, timezone
13
+ from typing import List
14
+
15
+ import pytest
16
+
17
+ from src.training.multi_turn import (
18
+ TurnData,
19
+ EpisodeBuffer,
20
+ GAEConfig,
21
+ MultiTurnEpisodeManager,
22
+ EpisodeCollector,
23
+ shape_trading_rewards,
24
+ compute_episode_return,
25
+ normalize_episode_rewards,
26
+ )
27
+
28
+
29
+ # =============================================================================
30
+ # Fixtures
31
+ # =============================================================================
32
+
33
+
34
@pytest.fixture
def sample_turn():
    """A single non-terminal ``buy`` turn with explicit reward and quality scores."""
    payload = '{"action": "buy", "market": "BTC", "amount": 100}'
    turn = TurnData(
        turn_number=0,
        episode_id="ep-001",
        action_type="buy",
        action_text=payload,
        reward=0.5,
        format_score=0.8,
        reasoning_score=0.7,
        done=False,
    )
    return turn
47
+
48
+
49
@pytest.fixture
def sample_episode():
    """Four-turn episode (``buy``, ``wait``, ``sell``, ``close_perp``); only the last turn is terminal."""
    episode = EpisodeBuffer(episode_id="ep-001", scenario_id="scenario-1")

    # (reward, action_type) per turn, in turn order.
    schedule = [
        (0.1, "buy"),
        (0.2, "wait"),
        (-0.1, "sell"),
        (0.5, "close_perp"),
    ]
    last_index = len(schedule) - 1

    for index, (reward, action) in enumerate(schedule):
        episode.add_turn(
            TurnData(
                turn_number=index,
                reward=reward,
                action_type=action,
                done=index == last_index,
            )
        )

    return episode
64
+
65
+
66
@pytest.fixture
def manager():
    """A ``MultiTurnEpisodeManager`` with gamma=0.99, lambda=0.95 and a 20-turn cap."""
    settings = {"gamma": 0.99, "gae_lambda": 0.95, "max_turns": 20}
    return MultiTurnEpisodeManager(**settings)
74
+
75
+
76
+ # =============================================================================
77
+ # TurnData Tests
78
+ # =============================================================================
79
+
80
+
81
class TestTurnData:
    """Behaviour of the ``TurnData`` record: construction, defaults and serialisation."""

    def test_creation(self, sample_turn):
        """Constructor arguments are stored on the instance unchanged."""
        turn = sample_turn

        assert turn.turn_number == 0
        assert turn.episode_id == "ep-001"
        assert turn.action_type == "buy"
        assert turn.reward == 0.5

    def test_default_values(self):
        """Only ``turn_number`` is required; the other checked fields default to neutral values."""
        bare = TurnData(turn_number=0)

        assert bare.episode_id == ""
        assert (bare.reward, bare.value, bare.advantage) == (0.0, 0.0, 0.0)
        assert bare.done is False

    def test_to_dict(self, sample_turn):
        """``to_dict`` exposes the core fields with their values."""
        as_dict = sample_turn.to_dict()

        for key in ("turn_number", "episode_id", "action_type", "reward", "advantage"):
            assert key in as_dict
        assert as_dict["turn_number"] == 0
        assert as_dict["reward"] == 0.5

    def test_action_text_truncation(self):
        """A 500-character ``action_text`` is serialised as at most 200 characters."""
        turn = TurnData(turn_number=0, action_text="x" * 500)

        serialised_text = turn.to_dict()["action_text"]
        assert len(serialised_text) <= 200
120
+
121
+
122
+ # =============================================================================
123
+ # EpisodeBuffer Tests
124
+ # =============================================================================
125
+
126
+
127
class TestEpisodeBuffer:
    """Tests for ``EpisodeBuffer``: turn accumulation, finalisation and export."""

    def test_creation(self):
        """A new buffer keeps its identifiers and starts empty and incomplete."""
        episode = EpisodeBuffer(
            episode_id="ep-001",
            scenario_id="scenario-1",
            archetype="trader",
        )

        assert episode.episode_id == "ep-001"
        assert episode.scenario_id == "scenario-1"
        # ``archetype`` is passed to the constructor, so verify it is stored too.
        assert episode.archetype == "trader"
        assert len(episode.turns) == 0
        assert not episode.completed

    def test_add_turn(self):
        """Adding a turn stores it and stamps it with the buffer's episode id."""
        episode = EpisodeBuffer(episode_id="ep-001")
        turn = TurnData(turn_number=0, reward=0.1)

        episode.add_turn(turn)

        assert len(episode.turns) == 1
        # The turn was created without an episode_id; add_turn fills it in.
        assert episode.turns[0].episode_id == "ep-001"

    def test_finalization_on_done(self):
        """Adding a ``done`` turn completes the episode and records length and total reward."""
        episode = EpisodeBuffer(episode_id="ep-001")

        episode.add_turn(TurnData(turn_number=0, reward=0.1, done=False))
        assert not episode.completed

        episode.add_turn(TurnData(turn_number=1, reward=0.2, done=True))
        assert episode.completed
        assert episode.episode_length == 2
        assert episode.total_reward == pytest.approx(0.3, abs=0.01)

    def test_success_determination(self):
        """``success`` follows the sign of the episode's total reward."""
        # Positive total reward -> success.
        episode1 = EpisodeBuffer(episode_id="ep-001")
        episode1.add_turn(TurnData(turn_number=0, reward=0.5, done=True))
        assert episode1.success

        # Negative total reward -> not a success.
        episode2 = EpisodeBuffer(episode_id="ep-002")
        episode2.add_turn(TurnData(turn_number=0, reward=-0.5, done=True))
        assert not episode2.success

    def test_get_messages(self, sample_episode):
        """Messages attached to the final turn are returned by ``get_messages``."""
        # Only the last turn carries messages here. NOTE(review): whether messages
        # on earlier turns would also be included is not exercised by this test.
        sample_episode.turns[-1].messages = [
            {"role": "user", "content": "trade"},
            {"role": "assistant", "content": "done"},
        ]

        messages = sample_episode.get_messages()
        assert len(messages) == 2

    def test_get_trajectory(self, sample_episode):
        """The trajectory has one entry per turn, with the action type as element 0."""
        trajectory = sample_episode.get_trajectory()

        assert len(trajectory) == 4
        assert trajectory[0][0] == "buy"  # First action type
        assert trajectory[3][0] == "close_perp"  # Last action type

    def test_to_dict(self, sample_episode):
        """``to_dict`` includes identifiers and one serialised entry per turn."""
        d = sample_episode.to_dict()

        assert "episode_id" in d
        assert "scenario_id" in d
        assert "turns" in d
        assert len(d["turns"]) == 4
204
+
205
+
206
+ # =============================================================================
207
+ # GAEConfig Tests
208
+ # =============================================================================
209
+
210
+
211
class TestGAEConfig:
    """``GAEConfig`` defaults and keyword overrides."""

    def test_default_values(self):
        """Defaults: gamma 0.99, lambda 0.95, advantage normalisation enabled."""
        defaults = GAEConfig()

        assert (defaults.gamma, defaults.gae_lambda) == (0.99, 0.95)
        assert defaults.normalize_advantages is True

    def test_custom_values(self):
        """Keyword arguments replace the default gamma and lambda."""
        overridden = GAEConfig(gamma=0.95, gae_lambda=0.9)

        assert overridden.gamma == 0.95
        assert overridden.gae_lambda == 0.9
228
+
229
+
230
+ # =============================================================================
231
+ # MultiTurnEpisodeManager Tests
232
+ # =============================================================================
233
+
234
+
235
class TestMultiTurnEpisodeManager:
    """Tests for ``MultiTurnEpisodeManager`` return-to-go / advantage computation."""

    def test_creation(self, manager):
        """Constructor arguments are reflected in ``config`` and ``max_turns``."""
        assert manager.config.gamma == 0.99
        assert manager.config.gae_lambda == 0.95
        assert manager.max_turns == 20

    def test_compute_advantages_single_turn(self, manager):
        """A lone terminal turn gets return-to-go and value equal to its reward."""
        turns = [TurnData(turn_number=0, reward=1.0, done=True)]

        manager.compute_advantages(turns)

        assert turns[0].return_to_go == pytest.approx(1.0)
        assert turns[0].value == pytest.approx(1.0)
        # The exact advantage depends on TD-error / normalisation details, so only
        # require a real number. ``abs(x) < inf`` rejects both inf and NaN, since
        # every comparison involving NaN is False.
        assert abs(turns[0].advantage) < float("inf")

    def test_compute_advantages_multiple_turns(self, manager, sample_episode):
        """Each turn gets a return-to-go and value; returns accumulate backwards."""
        turns = sample_episode.turns

        manager.compute_advantages(turns)

        # All turns should have computed values.
        for turn in turns:
            assert turn.return_to_go != 0 or turn.reward == 0
            assert turn.value != 0 or turn.return_to_go == 0

        # The terminal turn has no successors, so its return-to-go is its own
        # reward, as in the single-turn case.
        assert turns[-1].return_to_go == pytest.approx(turns[-1].reward)
        # Fixture rewards are [0.1, 0.2, -0.1, 0.5]. The first turn accumulates the
        # whole (discounted) tail, so its return must be at least the terminal one.
        # This is asserted unconditionally. For this fixture, an escape clause
        # comparing reward sums (0.3 < 0.4) would always hold and make the check
        # vacuous.
        assert turns[0].return_to_go >= turns[-1].return_to_go

    def test_compute_advantages_empty(self, manager):
        """An empty turn list is accepted without raising."""
        turns = []
        manager.compute_advantages(turns)  # Should not raise

    def test_advantage_clipping(self, manager):
        """With clipping enabled, no advantage exceeds ``advantage_clip`` in magnitude."""
        manager.config.clip_advantages = True
        manager.config.advantage_clip = 5.0

        # Extreme rewards would otherwise produce large advantages.
        turns = [
            TurnData(turn_number=0, reward=100.0, done=False),
            TurnData(turn_number=1, reward=-100.0, done=True),
        ]

        manager.compute_advantages(turns)

        for turn in turns:
            assert abs(turn.advantage) <= 5.0

    def test_compute_batch_advantages(self, manager):
        """Batch computation processes every episode and yields finite advantages."""
        episodes = [
            [TurnData(turn_number=0, reward=0.5, done=True)],
            [TurnData(turn_number=0, reward=-0.3, done=True)],
            [TurnData(turn_number=0, reward=0.2, done=True)],
        ]

        manager.compute_batch_advantages(episodes)

        for episode in episodes:
            for turn in episode:
                # Rejects inf and NaN (e.g. from a zero-variance normalisation).
                assert abs(turn.advantage) < float("inf")
        # Every episode in the batch should be counted in the manager's stats.
        assert manager.get_stats()["episodes_processed"] == len(episodes)

    def test_batch_normalization(self, manager):
        """With normalisation on, batch advantages are centred around zero."""
        manager.config.normalize_advantages = True

        episodes = [
            [TurnData(turn_number=0, reward=1.0, done=True)],
            [TurnData(turn_number=0, reward=0.0, done=True)],
            [TurnData(turn_number=0, reward=-1.0, done=True)],
        ]

        manager.compute_batch_advantages(episodes)

        advantages = [t.advantage for ep in episodes for t in ep]

        # Mean should be approximately 0 after normalisation.
        mean = sum(advantages) / len(advantages)
        assert abs(mean) < 0.1

    def test_get_stats(self, manager, sample_episode):
        """Stats reflect one processed episode of four turns and the configured gamma."""
        manager.compute_advantages(sample_episode.turns)

        stats = manager.get_stats()

        assert stats["episodes_processed"] == 1
        assert stats["total_turns"] == 4
        assert stats["gamma"] == 0.99
336
+
337
+
338
+ # =============================================================================
339
+ # Reward Shaping Tests
340
+ # =============================================================================
341
+
342
+
343
class TestRewardShaping:
    """Tests for the reward-shaping and return utilities."""

    def test_shape_trading_rewards(self):
        """Shaping a mixed trade/wait sequence leaves a finite reward on every turn."""
        turns = [
            TurnData(
                turn_number=0,
                reward=0.1,
                action_type="buy",
                format_score=0.8,
                reasoning_score=0.7,
            ),
            TurnData(
                turn_number=1,
                reward=-0.1,
                action_type="wait",
                format_score=0.9,
                reasoning_score=0.6,
            ),
        ]

        shape_trading_rewards(turns)

        # The shaped values depend on the module's bonus weights (which may be
        # zero), so this pins only that shaping produced real numbers. It does not
        # pin specific magnitudes.
        for turn in turns:
            assert abs(turn.reward) < float("inf")

    def test_compute_episode_return(self, sample_episode):
        """The discounted return matches the gamma-weighted sum of rewards."""
        gamma = 0.99
        returns = compute_episode_return(sample_episode.turns, gamma=gamma)

        expected = sum(
            gamma ** step * turn.reward
            for step, turn in enumerate(sample_episode.turns)
        )
        assert returns != 0
        assert returns == pytest.approx(expected, abs=0.01)

    def test_compute_episode_return_no_discount(self, sample_episode):
        """With gamma=1.0 the return is the plain sum of rewards."""
        returns = compute_episode_return(sample_episode.turns, gamma=1.0)

        expected = sum(t.reward for t in sample_episode.turns)
        assert returns == pytest.approx(expected, abs=0.01)

    def test_normalize_episode_rewards(self):
        """Normalising rewards across episodes centres them around zero."""
        episodes = [
            [TurnData(turn_number=0, reward=10.0)],
            [TurnData(turn_number=0, reward=0.0)],
            [TurnData(turn_number=0, reward=-10.0)],
        ]

        normalize_episode_rewards(episodes)

        rewards = [t.reward for ep in episodes for t in ep]
        mean = sum(rewards) / len(rewards)

        # Mean should be approximately 0.
        assert abs(mean) < 0.1
405
+
406
+
407
+ # =============================================================================
408
+ # EpisodeCollector Tests
409
+ # =============================================================================
410
+
411
+
412
class TestEpisodeCollector:
    """``EpisodeCollector``: episode lifecycle, capacity, filtering and statistics."""

    @staticmethod
    def _record(collector, scenario_id, **turn_kwargs):
        """Start an episode for *scenario_id* and add one turn (``turn_number=0``) built from *turn_kwargs*."""
        collector.start_episode(scenario_id)
        collector.add_turn(TurnData(turn_number=0, **turn_kwargs))

    def test_creation(self):
        """The capacity is stored and no episodes exist initially."""
        coll = EpisodeCollector(max_episodes=100)

        assert coll.max_episodes == 100
        assert len(coll.episodes) == 0

    def test_start_episode(self):
        """``start_episode`` returns the new buffer and makes it the current one."""
        coll = EpisodeCollector()

        started = coll.start_episode("scenario-1", "trader")

        assert started.scenario_id == "scenario-1"
        assert started.archetype == "trader"
        assert coll._current_episode is not None

    def test_add_turn(self):
        """Turns are appended to the in-progress episode."""
        coll = EpisodeCollector()
        coll.start_episode("scenario-1")

        coll.add_turn(TurnData(turn_number=0, reward=0.1))

        assert len(coll._current_episode.turns) == 1

    def test_add_turn_without_episode(self):
        """Adding a turn before any ``start_episode`` call raises ``RuntimeError``."""
        coll = EpisodeCollector()
        orphan = TurnData(turn_number=0)

        with pytest.raises(RuntimeError):
            coll.add_turn(orphan)

    def test_episode_finalization(self):
        """A ``done`` turn moves the current episode into ``episodes`` and clears it."""
        coll = EpisodeCollector()
        coll.start_episode("scenario-1")

        coll.add_turn(TurnData(turn_number=0, reward=0.1, done=False))
        assert len(coll.episodes) == 0

        coll.add_turn(TurnData(turn_number=1, reward=0.2, done=True))
        assert len(coll.episodes) == 1
        assert coll._current_episode is None

    def test_max_episodes_limit(self):
        """Finishing five episodes with ``max_episodes=3`` leaves exactly three stored."""
        coll = EpisodeCollector(max_episodes=3)

        for idx in range(5):
            self._record(coll, f"scenario-{idx}", done=True)

        assert len(coll.episodes) == 3

    def test_get_completed_episodes(self):
        """Only episodes that reached a ``done`` turn are reported as completed."""
        coll = EpisodeCollector()

        self._record(coll, "scenario-1", done=True)   # finished
        self._record(coll, "scenario-2", done=False)  # still open

        assert len(coll.get_completed_episodes()) == 1

    def test_get_successful_episodes(self):
        """Only the positively rewarded episode counts as successful."""
        coll = EpisodeCollector()

        self._record(coll, "scenario-1", reward=1.0, done=True)   # success
        self._record(coll, "scenario-2", reward=-1.0, done=True)  # failure

        assert len(coll.get_successful_episodes()) == 1

    def test_clear(self):
        """``clear`` drops stored episodes and any in-progress episode."""
        coll = EpisodeCollector()
        self._record(coll, "scenario-1", done=True)

        coll.clear()

        assert len(coll.episodes) == 0
        assert coll._current_episode is None

    def test_get_stats(self):
        """Three rewarded, finished episodes give a 100% success rate."""
        coll = EpisodeCollector()
        for idx in range(3):
            self._record(coll, f"scenario-{idx}", reward=0.5, done=True)

        stats = coll.get_stats()

        expected = {
            "total_episodes": 3,
            "completed_episodes": 3,
            "successful_episodes": 3,
            "success_rate": 1.0,
        }
        for key, value in expected.items():
            assert stats[key] == value
527
+
528
+
529
+ # =============================================================================
530
+ # Integration Tests
531
+ # =============================================================================
532
+
533
+
534
class TestMultiTurnIntegration:
    """End-to-end checks combining ``EpisodeCollector`` with ``MultiTurnEpisodeManager``."""

    def test_full_episode_workflow(self, manager):
        """Collect a five-turn episode, then compute advantages over its turns."""
        collector = EpisodeCollector()
        collector.start_episode("scenario-1", "trader")

        final_step = 4
        for step in range(final_step + 1):
            collector.add_turn(
                TurnData(
                    turn_number=step,
                    # 0.1, 0.2, 0.3, 0.4, then 0.5 on the terminal turn
                    # (0.1 * 5 == 0.5 exactly in binary floating point).
                    reward=0.1 * (step + 1),
                    action_type="wait" if step % 2 else "buy",
                    done=step == final_step,
                )
            )

        finished = collector.get_completed_episodes()
        assert len(finished) == 1

        turns = finished[0].turns
        manager.compute_advantages(turns)

        # Every rewarded turn must end up with a non-zero value estimate.
        assert all(t.value != 0 or t.reward == 0 for t in turns)

    def test_multiple_episodes_batch(self, manager):
        """Three four-turn episodes processed as one batch update the manager's stats."""
        collector = EpisodeCollector()

        for ep_idx in range(3):
            # The first episode is rewarded; the other two are penalised.
            sign = 1 if ep_idx == 0 else -1
            collector.start_episode(f"scenario-{ep_idx}")
            for turn_idx in range(4):
                collector.add_turn(
                    TurnData(
                        turn_number=turn_idx,
                        reward=0.1 * (turn_idx + 1) * sign,
                        done=turn_idx == 3,
                    )
                )

        turn_lists = [episode.turns for episode in collector.get_completed_episodes()]
        manager.compute_batch_advantages(turn_lists)

        assert manager.get_stats()["episodes_processed"] == 3
590
+