@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,581 @@
1
+ """
2
+ Comprehensive tests for simulation, rollout generation, and dataset preparation.
3
+
4
+ Tests:
5
+ 1. FastSimulator - benchmark and data generation modes
6
+ 2. FastRolloutGenerator - rollout creation and quality validation
7
+ 3. MultiPromptDatasetBuilder - dataset creation from trajectories
8
+ 4. RolloutQualityValidator - validation logic
9
+ 5. PromptTypeAnalyzer - correlation analysis
10
+ """
11
+
12
+ import pytest
13
+ from datetime import datetime
14
+
15
+ import sys
16
+ sys.path.insert(0, '.')
17
+
18
+ from src.models import (
19
+ TrainingTrajectory, TrajectoryStep, EnvironmentState,
20
+ Action, LLMCall, AtroposScoredGroup
21
+ )
22
+ from src.training import (
23
+ FastSimulator, SimulatorConfig, GameState,
24
+ RolloutResult, AgentTickData,
25
+ RolloutQualityValidator,
26
+ MultiPromptDatasetBuilder, PromptSample,
27
+ prepare_multi_prompt_training_data, PromptTypeAnalyzer,
28
+ )
29
+
30
+
31
+ # ============================================================
32
+ # Fixtures
33
+ # ============================================================
34
+
35
@pytest.fixture
def sample_env_state():
    """Environment-state fixture: small account with a couple of open positions."""
    state = EnvironmentState(
        agent_balance=10000.0,
        agent_pnl=100.0,
        open_positions=2,
        active_markets=5,
    )
    return state
44
+
45
+
46
@pytest.fixture
def sample_action():
    """Action fixture: a successful small BTC buy order."""
    params = {'ticker': 'BTC', 'amount': 0.1}
    return Action(action_type='buy', parameters=params, success=True)
54
+
55
+
56
@pytest.fixture
def sample_llm_call():
    """LLM-call fixture: a single action-purpose completion from gpt-4."""
    call = LLMCall(
        model='gpt-4',
        system_prompt='You are a trading agent.',
        user_prompt='Market update: BTC is up 5%',
        response='I will buy 0.1 BTC',
        reasoning='Price momentum is positive',
        temperature=0.7,
        max_tokens=100,
        purpose='action',
    )
    return call
69
+
70
+
71
@pytest.fixture
def sample_tick_data(sample_env_state, sample_action, sample_llm_call):
    """Tick-data fixture assembled from the state/action/LLM-call fixtures."""
    observation = {'markets': [{'id': 'm1', 'price': 50000}]}
    return AgentTickData(
        tick_number=0,
        timestamp=1000000,
        observation=observation,
        environment_state=sample_env_state,
        llm_calls=[sample_llm_call],
        action=sample_action,
        reward=0.1,
    )
83
+
84
+
85
@pytest.fixture
def sample_trajectory(sample_env_state, sample_action, sample_llm_call):
    """Trajectory fixture: five steps alternating trade/wait actions."""

    def _build_step(i):
        # Balance, pnl, and open positions all grow linearly with the step index.
        env = EnvironmentState(
            agent_balance=10000.0 + i * 100,
            agent_pnl=i * 100.0,
            open_positions=i,
        )
        call = LLMCall(
            model='gpt-4',
            system_prompt='You are a trading agent.',
            user_prompt=f'Market update for step {i}: price is moving',
            response=f'I will execute action {i}: buying at current price level',
            reasoning=f'Reasoning for step {i}',
            temperature=0.7,
            max_tokens=100,
            purpose='action',
        )
        act = Action(
            action_type='trade' if i % 2 == 0 else 'wait',
            parameters={'amount': 100},
            success=True,
        )
        return TrajectoryStep(
            step_number=i,
            timestamp=1000000 + i * 1000,
            environment_state=env,
            provider_accesses=[],
            llm_calls=[call],
            action=act,
            reward=0.1,
        )

    return TrainingTrajectory(
        id='traj-1',
        trajectory_id='traj-1',
        agent_id='agent-1',
        window_id='2024-01-01T00:00',
        start_time=datetime.now(),
        end_time=datetime.now(),
        duration_ms=5000,
        steps=[_build_step(i) for i in range(5)],
        total_reward=0.5,
        final_pnl=400.0,
        episode_length=5,
        final_status='completed',
    )
132
+
133
+
134
+ # ============================================================
135
+ # GameState Tests
136
+ # ============================================================
137
+
138
class TestGameState:
    """Behavioral checks for the GameState dataclass."""

    def test_default_initialization(self):
        """A fresh GameState starts empty at tick/time zero."""
        fresh = GameState()
        assert fresh.tick == 0
        assert fresh.time == 0
        assert fresh.markets == []
        assert fresh.portfolios == {}

    def test_to_observation(self):
        """to_observation mirrors tick/time/markets and carries the news feed."""
        headlines = [{'headline': 'News 1'}, {'headline': 'News 2'}]
        obs = GameState(
            tick=5,
            time=1000000,
            markets=[{'id': 'm1'}],
            news=headlines,
        ).to_observation()

        assert obs['tick'] == 5
        assert obs['time'] == 1000000
        assert len(obs['markets']) == 1
        assert 'news' in obs

    def test_get_env_state(self):
        """get_env_state reads balance/pnl/positions out of the portfolio map."""
        portfolio = {'balance': 15000.0, 'pnl': 500.0, 'positions': 3}
        env = GameState(portfolios={'agent-1': portfolio}).get_env_state('agent-1')

        assert env.agent_balance == 15000.0
        assert env.agent_pnl == 500.0
        assert env.open_positions == 3

    def test_get_env_state_unknown_agent(self):
        """An agent with no portfolio falls back to default environment values."""
        env = GameState().get_env_state('unknown-agent')

        assert env.agent_balance == 10000.0
        assert env.agent_pnl == 0.0
        assert env.open_positions == 0
186
+
187
+
188
+ # ============================================================
189
+ # SimulatorConfig Tests
190
+ # ============================================================
191
+
192
class TestSimulatorConfig:
    """Defaults and benchmark-mode wiring for SimulatorConfig."""

    def test_default_config(self):
        """An unconfigured SimulatorConfig runs in data-generation mode."""
        cfg = SimulatorConfig()
        assert cfg.mode == 'data_generation'
        assert cfg.max_concurrent_agents == 8
        assert cfg.batch_size == 4
        assert cfg.ticks_per_window == 60
        assert cfg.min_actions_per_trajectory == 5

    def test_benchmark_mode_config(self):
        """Benchmark mode accepts a snapshot and ground-truth payload."""
        cfg = SimulatorConfig(
            mode='benchmark',
            benchmark_snapshot={'ticks': []},
            ground_truth={'marketOutcomes': {}},
        )
        assert cfg.mode == 'benchmark'
        assert cfg.benchmark_snapshot is not None
        assert cfg.ground_truth is not None
214
+
215
+
216
+ # ============================================================
217
+ # FastSimulator Tests
218
+ # ============================================================
219
+
220
class TestFastSimulator:
    """Construction, completion, and tick-advance behavior of FastSimulator."""

    def test_for_benchmark(self):
        """for_benchmark loads recorded ticks and seeds markets from the snapshot."""
        snapshot = {
            'ticks': [
                {'state': {'currentTime': 1000}},
                {'state': {'currentTime': 2000}},
            ],
            'groundTruth': {'marketOutcomes': {}},
            'initialState': {
                'predictionMarkets': [{'id': 'm1'}],
                'currentTime': 1000,
            },
        }

        sim = FastSimulator.for_benchmark(snapshot)

        assert sim.config.mode == 'benchmark'
        assert len(sim.benchmark_ticks) == 2
        assert sim.game_state.markets == [{'id': 'm1'}]

    def test_is_complete_benchmark(self):
        """A benchmark run finishes once every recorded tick is consumed."""
        sim = FastSimulator.for_benchmark(
            {'ticks': [{}] * 5, 'groundTruth': {}, 'initialState': {}})

        assert not sim.is_complete()
        sim.current_tick = 5
        assert sim.is_complete()

    def test_is_complete_data_generation(self):
        """A data-generation run finishes when max_ticks is reached."""
        sim = FastSimulator(SimulatorConfig(max_ticks=100))

        assert not sim.is_complete()
        sim.current_tick = 100
        assert sim.is_complete()

    def test_advance_tick(self):
        """_advance_tick bumps the tick counter and moves time forward by 1000."""
        sim = FastSimulator(SimulatorConfig())
        tick_before = sim.current_tick
        time_before = sim.game_state.time

        sim._advance_tick()

        assert sim.current_tick == tick_before + 1
        assert sim.game_state.time == time_before + 1000
273
+
274
+
275
+ # ============================================================
276
+ # AgentTickData Tests
277
+ # ============================================================
278
+
279
class TestAgentTickData:
    """Context-string rendering for AgentTickData."""

    def test_get_full_context(self, sample_tick_data):
        """Full context includes observation, LLM-call, and action sections."""
        rendered = sample_tick_data.get_full_context()

        assert '=== OBSERVATION' in rendered
        assert '=== LLM CALL 1' in rendered
        assert '=== ACTION ===' in rendered
        assert 'buy' in rendered.lower()

    def test_get_full_context_no_action(self, sample_env_state, sample_llm_call):
        """The action section is omitted when the tick recorded no action."""
        tick = AgentTickData(
            tick_number=0,
            timestamp=1000000,
            observation={},
            environment_state=sample_env_state,
            llm_calls=[sample_llm_call],
            action=None,
            reward=0.0,
        )

        rendered = tick.get_full_context()
        assert '=== OBSERVATION' in rendered
        assert '=== ACTION ===' not in rendered
306
+
307
+
308
+ # ============================================================
309
+ # RolloutQualityValidator Tests
310
+ # ============================================================
311
+
312
class TestRolloutQualityValidator:
    """Validation-rule checks for RolloutQualityValidator."""

    @staticmethod
    def _result(trajectory, *, ticks=10, duration=5000, llm_calls=15,
                reward=5.0, pnl=500.0, quality=0.7):
        # Factory for RolloutResult with overridable knobs per test case.
        return RolloutResult(
            agent_id='test',
            trajectory_id='test-traj',
            ticks_completed=ticks,
            total_duration_ms=duration,
            avg_tick_duration_ms=500.0,
            total_llm_calls=llm_calls,
            total_reward=reward,
            final_pnl=pnl,
            quality_score=quality,
            trajectory=trajectory,
        )

    def test_validate_valid_rollout(self, sample_trajectory):
        """A reasonable rollout returns a bool verdict and an issue list."""
        verdict, issues = RolloutQualityValidator.validate_rollout(
            self._result(sample_trajectory))

        # Per-step LLM-call requirements may still surface issues here.
        assert isinstance(verdict, bool)
        assert isinstance(issues, list)

    def test_validate_no_trajectory(self):
        """A rollout without trajectory data is rejected outright."""
        verdict, issues = RolloutQualityValidator.validate_rollout(
            self._result(None))

        assert not verdict
        assert 'No trajectory data' in issues

    def test_validate_too_few_ticks(self, sample_trajectory):
        """Rollouts shorter than the tick minimum are flagged."""
        short = self._result(sample_trajectory, ticks=3, duration=1500,
                             llm_calls=3, reward=0.3, pnl=100.0, quality=0.3)

        _, issues = RolloutQualityValidator.validate_rollout(short)

        assert any('Too few ticks' in issue for issue in issues)

    def test_validate_low_quality_score(self, sample_trajectory):
        """Quality scores below the 0.5 threshold are flagged."""
        weak = self._result(sample_trajectory, llm_calls=10, reward=1.0,
                            pnl=100.0, quality=0.3)

        _, issues = RolloutQualityValidator.validate_rollout(weak)

        assert any('Quality score too low' in issue for issue in issues)
395
+
396
+
397
+ # ============================================================
398
+ # MultiPromptDatasetBuilder Tests
399
+ # ============================================================
400
+
401
class TestMultiPromptDatasetBuilder:
    """Sample accumulation, statistics, and group building."""

    def test_initialization(self):
        """A new builder exposes the four purpose buckets, all empty."""
        builder = MultiPromptDatasetBuilder()

        assert len(builder.datasets) == 4
        for purpose in ('action', 'reasoning', 'evaluation', 'response'):
            assert purpose in builder.datasets
        assert builder.total_trajectories == 0

    def test_add_trajectory(self, sample_trajectory):
        """Each trajectory step contributes exactly one sample."""
        builder = MultiPromptDatasetBuilder()

        added = builder.add_trajectory(sample_trajectory, trajectory_score=0.8)

        assert added == 5  # One per step
        assert builder.total_trajectories == 1
        assert builder.total_steps == 5
        assert builder.total_samples == 5

    def test_get_statistics(self, sample_trajectory):
        """Statistics report overall totals plus a per-purpose breakdown."""
        builder = MultiPromptDatasetBuilder()
        builder.add_trajectory(sample_trajectory, trajectory_score=0.8)

        stats = builder.get_statistics()

        assert stats['total_trajectories'] == 1
        assert stats['total_samples'] == 5
        assert 'by_purpose' in stats
        assert 'action' in stats['by_purpose']

    def test_build_training_data(self, sample_trajectory):
        """build_training_data yields a (possibly empty) list of scored groups."""
        builder = MultiPromptDatasetBuilder()
        for offset in range(4):
            builder.add_trajectory(
                sample_trajectory, trajectory_score=0.5 + offset * 0.1)

        groups = builder.build_training_data(purpose='action', group_size=4)

        assert isinstance(groups, list)
        if groups:
            assert isinstance(groups[0], AtroposScoredGroup)
452
+
453
+
454
+ # ============================================================
455
+ # PromptSample Tests
456
+ # ============================================================
457
+
458
class TestPromptSample:
    """Message conversion and score weighting for PromptSample."""

    @staticmethod
    def _sample(**overrides):
        # Factory with minimal defaults; tests override only what they assert on.
        fields = dict(
            trajectory_id='t1',
            step_number=0,
            call_index=0,
            system_prompt='sys',
            user_prompt='user',
            response='resp',
            purpose='action',
            action_type='buy',
            model='gpt-4',
            temperature=0.7,
            trajectory_score=0.8,
            step_reward=0.1,
            action_success=True,
            environment_context={},
            previous_actions=[],
        )
        fields.update(overrides)
        return PromptSample(**fields)

    def test_to_messages(self):
        """to_messages emits a system/user/assistant triple."""
        sample = self._sample(
            system_prompt='You are a trading agent.',
            user_prompt='What should I do?',
            response='Buy BTC',
            environment_context={'balance': 10000},
            previous_actions=['wait'],
        )

        messages = sample.to_messages()

        assert len(messages) == 3
        assert messages[0]['role'] == 'system'
        assert messages[1]['role'] == 'user'
        assert messages[2]['role'] == 'assistant'

    def test_get_weighted_score(self):
        """Success bonus and step reward lift the score above the base 0.8."""
        score = self._sample().get_weighted_score()

        assert score > 0.8
        assert score <= 1.0
513
+
514
+
515
+ # ============================================================
516
+ # PromptTypeAnalyzer Tests
517
+ # ============================================================
518
+
519
class TestPromptTypeAnalyzer:
    """Correlation analysis over trajectories and their scores."""

    def test_analyze_correlation(self, sample_trajectory):
        """The analysis dict exposes all four expected report sections."""
        report = PromptTypeAnalyzer.analyze_correlation([sample_trajectory], [0.8])

        for section in (
            'prompt_count_by_purpose',
            'avg_length_by_purpose',
            'high_score_characteristics',
            'low_score_characteristics',
        ):
            assert section in report

    def test_analyze_high_low_scores(self, sample_trajectory):
        """High and low scoring trajectories each populate their bucket."""
        report = PromptTypeAnalyzer.analyze_correlation(
            [sample_trajectory, sample_trajectory], [0.9, 0.2])

        assert len(report['high_score_characteristics']) > 0
        assert len(report['low_score_characteristics']) > 0
545
+
546
+
547
+ # ============================================================
548
+ # Integration Tests
549
+ # ============================================================
550
+
551
class TestIntegration:
    """End-to-end checks for the dataset-preparation helpers."""

    def test_prepare_multi_prompt_training_data(self, sample_trajectory):
        """The convenience wrapper returns a dict keyed by prompt purpose."""
        result = prepare_multi_prompt_training_data(
            trajectories=[sample_trajectory] * 4,
            scores=[0.8, 0.6, 0.4, 0.9],
            group_size=4,
        )

        assert isinstance(result, dict)
        # Groups may be absent when score variance is too low to form them.

    def test_trajectory_count_score_mismatch(self, sample_trajectory):
        """Mismatched trajectory/score counts raise a ValueError."""
        with pytest.raises(ValueError, match='Trajectory count'):
            prepare_multi_prompt_training_data(
                trajectories=[sample_trajectory],
                scores=[0.8, 0.6],
            )
576
+
577
+
578
# Allow running this test module directly with verbose pytest output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
581
+