@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,735 @@
1
+ """
2
+ Tests for ScenarioPool and CurriculumManager
3
+
4
+ Tests cover:
5
+ - Scenario generation (synthetic and structure)
6
+ - Curriculum learning (tracking, priorities, reset)
7
+ - Pool management (sampling, refresh, persistence)
8
+ """
9
+
10
+ import json
11
+ import tempfile
12
+ from pathlib import Path
13
+ from unittest.mock import AsyncMock, patch
14
+
15
+ import pytest
16
+
17
+ from src.training.scenario_pool import (
18
+ CurriculumManager,
19
+ CurriculumState,
20
+ MarketState,
21
+ NewsItem,
22
+ PerpetualState,
23
+ PortfolioState,
24
+ Scenario,
25
+ ScenarioPool,
26
+ ScenarioPoolConfig,
27
+ SocialPost,
28
+ )
29
+
30
+
31
+ # =============================================================================
32
+ # Data Structure Tests
33
+ # =============================================================================
34
+
35
+
36
class TestMarketState:
    """Unit tests for constructing and serialising ``MarketState``."""

    @staticmethod
    def _btc_market(**extra):
        """Build the BTC $100K market shared by both tests, plus any extra kwargs."""
        return MarketState(
            market_id="test-1",
            question="Will BTC exceed $100K?",
            yes_price=0.65,
            no_price=0.35,
            volume_24h=100_000.0,
            liquidity=500_000.0,
            expires_at=1_735_689_600_000,
            **extra,
        )

    def test_creation(self):
        state = self._btc_market(category="crypto")

        assert state.market_id == "test-1"
        assert state.yes_price == 0.65
        assert state.no_price == 0.35

    def test_to_dict(self):
        payload = self._btc_market().to_dict()

        # The serialised form uses camelCase keys and exposes market_id as "id".
        expected = {
            "id": "test-1",
            "yesPrice": 0.65,
            "noPrice": 0.35,
            "volume24h": 100_000.0,
        }
        for key, value in expected.items():
            assert payload[key] == value
        assert "question" in payload
73
+
74
+
75
class TestPerpetualState:
    """Unit tests for constructing and serialising ``PerpetualState``."""

    def test_creation(self):
        btc = PerpetualState(
            ticker="BTC",
            mark_price=100_000.0,
            index_price=99_990.0,
            funding_rate=0.0001,
            open_interest=50_000_000.0,
            volume_24h=100_000_000.0,
            change_24h=0.02,
            high_24h=102_000.0,
            low_24h=98_000.0,
        )

        assert btc.ticker == "BTC"
        assert btc.mark_price == 100_000.0

    def test_to_dict(self):
        eth = PerpetualState(
            ticker="ETH",
            mark_price=3_500.0,
            index_price=3_495.0,
            funding_rate=-0.0002,
            open_interest=25_000_000.0,
            volume_24h=50_000_000.0,
            change_24h=-0.01,
            high_24h=3_600.0,
            low_24h=3_400.0,
        )

        serialized = eth.to_dict()

        # Snake-case attributes come back under camelCase keys.
        expected = {"ticker": "ETH", "markPrice": 3_500.0, "fundingRate": -0.0002}
        for key, value in expected.items():
            assert serialized[key] == value
112
+
113
+
114
class TestScenario:
    """Unit tests for the ``Scenario`` container and its derived dict/observation views."""

    @staticmethod
    def _market(market_id, question, yes_price, no_price, volume_24h, liquidity):
        """Build a MarketState; every market in this class shares one expiry timestamp."""
        return MarketState(
            market_id=market_id,
            question=question,
            yes_price=yes_price,
            no_price=no_price,
            volume_24h=volume_24h,
            liquidity=liquidity,
            expires_at=1_735_689_600_000,
        )

    @staticmethod
    def _headline(headline, sentiment, impact, source="X", timestamp=0):
        """Build a NewsItem; source/timestamp default to placeholder values."""
        return NewsItem(
            headline=headline,
            sentiment=sentiment,
            impact=impact,
            source=source,
            timestamp=timestamp,
        )

    def test_creation_minimal(self):
        bare = Scenario(id="test-scenario", source="synthetic")

        assert bare.id == "test-scenario"
        assert bare.source == "synthetic"
        assert len(bare.markets) == 0
        # An unspecified portfolio starts with a 10k balance.
        assert bare.portfolio.balance == 10_000.0

    def test_creation_full(self):
        flat_btc = PerpetualState(
            ticker="BTC",
            mark_price=100_000.0,
            index_price=100_000.0,
            funding_rate=0.0,
            open_interest=1_000_000.0,
            volume_24h=5_000_000.0,
            change_24h=0.0,
            high_24h=100_000.0,
            low_24h=100_000.0,
        )
        neutral_post = SocialPost(
            author="test_user",
            content="Test post",
            sentiment="neutral",
            likes=10,
            replies=2,
            timestamp=1_735_689_600_000,
        )

        full = Scenario(
            id="full-scenario",
            source="production",
            markets=[self._market("m1", "Test?", 0.5, 0.5, 1_000.0, 5_000.0)],
            perpetuals=[flat_btc],
            news=[
                self._headline(
                    "Test headline",
                    "bullish",
                    "high",
                    source="Test",
                    timestamp=1_735_689_600_000,
                )
            ],
            social_posts=[neutral_post],
            portfolio=PortfolioState(balance=15_000.0),
            archetype_focus="trader",
            difficulty="hard",
        )

        assert full.id == "full-scenario"
        for collection in (full.markets, full.perpetuals, full.news, full.social_posts):
            assert len(collection) == 1
        assert full.portfolio.balance == 15_000.0
        assert full.archetype_focus == "trader"
        assert full.difficulty == "hard"

    def test_to_dict(self):
        as_dict = Scenario(
            id="dict-test",
            source="synthetic",
            markets=[self._market("m1", "Test?", 0.6, 0.4, 1_000.0, 5_000.0)],
            difficulty="easy",
        ).to_dict()

        assert as_dict["id"] == "dict-test"
        assert as_dict["source"] == "synthetic"
        assert len(as_dict["markets"]) == 1
        assert as_dict["difficulty"] == "easy"
        assert "portfolio" in as_dict

    def test_to_observation(self):
        observation = Scenario(
            id="obs-test",
            source="synthetic",
            markets=[self._market("m1", "Will BTC moon?", 0.7, 0.3, 100_000.0, 500_000.0)],
            news=[
                self._headline(
                    "Bullish news",
                    "bullish",
                    "high",
                    source="Test",
                    timestamp=1_735_689_600_000,
                )
            ],
        ).to_observation()

        for section in ("markets", "perpetuals", "news", "portfolio", "marketSummary"):
            assert section in observation
        assert observation["marketSummary"]["totalMarkets"] == 1

    def test_sentiment_calculation(self):
        # (scenario id, [(headline, sentiment, impact), ...], expected avgSentiment)
        cases = [
            (
                "bullish-test",
                [
                    ("Bull1", "bullish", "high"),
                    ("Bull2", "bullish", "high"),
                    ("Neutral", "neutral", "low"),
                ],
                "bullish",
            ),
            (
                "bearish-test",
                [
                    ("Bear1", "bearish", "high"),
                    ("Bear2", "bearish", "high"),
                ],
                "bearish",
            ),
        ]

        for scenario_id, headline_specs, expected in cases:
            scenario = Scenario(
                id=scenario_id,
                source="synthetic",
                news=[self._headline(*spec) for spec in headline_specs],
            )
            summary = scenario.to_observation()["marketSummary"]
            assert summary["avgSentiment"] == expected
271
+
272
+
273
+ # =============================================================================
274
+ # CurriculumManager Tests
275
+ # =============================================================================
276
+
277
+
278
class TestCurriculumManager:
    """Unit tests for CurriculumManager attempt tracking, solving, skipping and persistence."""

    def test_creation(self):
        fresh = CurriculumManager()

        for store in (fresh.attempts, fresh.scores, fresh.solved):
            assert len(store) == 0

    def test_record_attempt(self):
        tracker = CurriculumManager()

        tracker.record_attempt("scenario-1", 0.5)

        assert tracker.attempts["scenario-1"] == 1
        history = tracker.scores["scenario-1"]
        assert len(history) == 1
        assert history[0] == 0.5

    def test_multiple_attempts(self):
        tracker = CurriculumManager()

        for score in (0.3, 0.5, 0.7):
            tracker.record_attempt("scenario-1", score)

        assert tracker.attempts["scenario-1"] == 3
        assert len(tracker.scores["scenario-1"]) == 3

    def test_solved_detection(self):
        tracker = CurriculumManager(solve_threshold=0.8, min_attempts_for_solved=3)

        # Two qualifying scores are below the minimum attempt count.
        for _ in range(2):
            tracker.record_attempt("scenario-1", 0.9)
        assert "scenario-1" not in tracker.solved

        # The third qualifying score reaches min_attempts_for_solved.
        tracker.record_attempt("scenario-1", 0.9)
        assert "scenario-1" in tracker.solved

    def test_solved_requires_high_scores(self):
        tracker = CurriculumManager(solve_threshold=0.8, min_attempts_for_solved=3)

        # Enough attempts, but every score is under solve_threshold.
        for score in (0.5, 0.6, 0.7):
            tracker.record_attempt("scenario-1", score)

        assert "scenario-1" not in tracker.solved

    def test_should_skip(self):
        tracker = CurriculumManager(max_avg_for_skip=0.85)

        # Anything already in the solved set is skipped.
        tracker.solved.add("solved-scenario")
        assert tracker.should_skip("solved-scenario") is True

        # A history averaging above max_avg_for_skip is skipped.
        tracker.scores["easy-scenario"] = [0.9, 0.9, 0.9]
        assert tracker.should_skip("easy-scenario") is True

        # A low-average history is kept in rotation.
        tracker.scores["hard-scenario"] = [0.3, 0.4, 0.5]
        assert tracker.should_skip("hard-scenario") is False

        # A scenario with no history is never skipped.
        assert tracker.should_skip("new-scenario") is False

    def test_get_priority(self):
        tracker = CurriculumManager()

        tracker.solved.add("solved")
        assert tracker.get_priority("solved") == 0.0

        fresh_priority = tracker.get_priority("new-scenario")
        assert fresh_priority > 0.5

        histories = {
            "hard": [0.2, 0.3, 0.2],
            "easy": [0.8, 0.9, 0.85],
        }
        for scenario_id, history in histories.items():
            tracker.scores[scenario_id] = history
            tracker.attempts[scenario_id] = 3

        # Lower historical scores should translate into a higher sampling priority.
        assert tracker.get_priority("hard") > tracker.get_priority("easy")

    def test_reset(self):
        tracker = CurriculumManager()

        for scenario_id in ("scenario-1", "scenario-2"):
            tracker.solved.add(scenario_id)

        tracker.reset()

        assert len(tracker.solved) == 0

    def test_get_stats(self):
        tracker = CurriculumManager()
        recorded = [("s1", 0.5), ("s1", 0.6), ("s2", 0.8)]

        for scenario_id, score in recorded:
            tracker.record_attempt(scenario_id, score)

        stats = tracker.get_stats()

        assert stats["total_scenarios"] == 2
        assert stats["total_attempts"] == 3
        mean_score = sum(score for _, score in recorded) / len(recorded)
        assert stats["avg_score"] == pytest.approx(mean_score, rel=1e-3)

    def test_checkpoint_save_load(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = str(Path(tmpdir) / "curriculum.json")

            # Populate one manager and persist it explicitly.
            writer = CurriculumManager(checkpoint_path=checkpoint)
            for score in (0.5, 0.6):
                writer.record_attempt("s1", score)
            writer.solved.add("s2")
            writer._save_checkpoint()

            # A second manager pointed at the same path picks up that state.
            reader = CurriculumManager(checkpoint_path=checkpoint)

            assert reader.attempts["s1"] == 2
            assert reader.scores["s1"] == [0.5, 0.6]
            assert "s2" in reader.solved

    def test_history_trimming(self):
        tracker = CurriculumManager(max_history_per_scenario=5)

        for step in range(10):
            tracker.record_attempt("scenario", step / 10)

        kept = tracker.scores["scenario"]
        assert len(kept) == 5
        # Trimming drops the oldest entries, so the newest score survives.
        assert kept[-1] == 0.9
425
+
426
+
427
+ # =============================================================================
428
+ # ScenarioPool Tests
429
+ # =============================================================================
430
+
431
+
432
class TestScenarioPool:
    """Tests for ScenarioPool generation, sampling, curriculum hooks and persistence."""

    def test_creation(self):
        """A new pool starts empty with its sample counter at zero."""
        config = ScenarioPoolConfig()
        pool = ScenarioPool(config)

        assert len(pool.scenarios) == 0
        assert pool._sample_counter == 0

    def test_generate_synthetic_batch(self):
        """Synthetic batches honour the count and populate markets/perpetuals."""
        config = ScenarioPoolConfig()
        pool = ScenarioPool(config)

        scenarios = pool.generate_synthetic_batch(count=10)

        assert len(scenarios) == 10
        for scenario in scenarios:
            assert scenario.source == "synthetic"
            assert len(scenario.markets) > 0
            assert len(scenario.perpetuals) > 0
            assert scenario.id.startswith("synth-")

    def test_generate_with_archetype_focus(self):
        """The archetype_focus argument is stamped onto every generated scenario."""
        config = ScenarioPoolConfig()
        pool = ScenarioPool(config)

        scenarios = pool.generate_synthetic_batch(count=5, archetype_focus="degen")

        for scenario in scenarios:
            assert scenario.archetype_focus == "degen"

    def test_difficulty_distribution(self):
        """Generated difficulties track the configured distribution.

        Difficulty assignment is sampled, so the check is statistical. The
        previous version (100 samples, +/-10 counts per bucket) failed
        spuriously a few percent of the time even for a correct
        implementation. The same +/-0.1 proportion tolerance over 500 samples
        is more than 4.4 standard deviations for every bucket, so a spurious
        failure becomes vanishingly rare.
        """
        distribution = {"easy": 0.5, "medium": 0.3, "hard": 0.2}
        config = ScenarioPoolConfig(synthetic_difficulty_distribution=distribution)
        pool = ScenarioPool(config)

        sample_size = 500
        scenarios = pool.generate_synthetic_batch(count=sample_size)
        assert len(scenarios) == sample_size

        for difficulty, expected_share in distribution.items():
            observed = sum(1 for s in scenarios if s.difficulty == difficulty)
            observed_share = observed / sample_size
            assert abs(observed_share - expected_share) <= 0.1, (
                f"{difficulty}: observed share {observed_share:.3f}, "
                f"expected {expected_share:.2f}"
            )

    def test_sample_without_curriculum(self):
        """Plain sampling returns the requested count, drawn from the pool."""
        config = ScenarioPoolConfig(use_curriculum=False)
        pool = ScenarioPool(config)
        pool.scenarios = pool.generate_synthetic_batch(count=20)

        sampled = pool.sample(count=5)

        assert len(sampled) == 5
        for s in sampled:
            assert s in pool.scenarios

    def test_sample_with_curriculum(self):
        """Curriculum sampling keeps working after scenarios are mastered.

        The original version computed sets of ids in a loop but asserted
        nothing, so it could never fail. The checks below pin the invariants
        that do not depend on the curriculum's probabilistic weighting.
        """
        config = ScenarioPoolConfig(use_curriculum=True)
        pool = ScenarioPool(config)
        pool.scenarios = pool.generate_synthetic_batch(count=10)

        # Initially all scenarios should be available
        sampled = pool.sample(count=3)
        assert len(sampled) == 3

        # Record consistently high scores for two of the sampled scenarios.
        mastered_ids = [s.id for s in sampled[:2]]
        for scenario_id in mastered_ids:
            for _ in range(3):
                pool.curriculum.record_attempt(scenario_id, 0.9)

        # The pool's curriculum must have registered every attempt.
        for scenario_id in mastered_ids:
            assert pool.curriculum.attempts[scenario_id] == 3

        # Mastered scenarios should become less likely to be drawn. The exact
        # rate depends on the curriculum thresholds and sampling weights, so
        # only the invariant is asserted: every draw still fills the request
        # (8 unmastered scenarios remain, which is enough for count=5).
        for _ in range(10):
            assert len(pool.sample(count=5)) == 5

    def test_record_results(self):
        """record_results forwards each (id, score) pair into the curriculum."""
        config = ScenarioPoolConfig(use_curriculum=True)
        pool = ScenarioPool(config)
        pool.scenarios = pool.generate_synthetic_batch(count=5)

        ids = [s.id for s in pool.scenarios[:3]]
        scores = [0.5, 0.7, 0.9]

        pool.record_results(ids, scores)

        assert pool.curriculum.attempts[ids[0]] == 1
        assert pool.curriculum.scores[ids[0]][0] == 0.5

    def test_get_stats(self):
        """Pool stats break scenarios down by source and embed curriculum stats."""
        config = ScenarioPoolConfig(use_curriculum=True)
        pool = ScenarioPool(config)
        pool.scenarios = pool.generate_synthetic_batch(count=10)

        stats = pool.get_stats()

        assert stats["total_scenarios"] == 10
        assert stats["synthetic_scenarios"] == 10
        assert stats["production_scenarios"] == 0
        assert "curriculum" in stats

    def test_save_and_load_scenarios(self):
        """Saved scenarios round-trip through load_scenarios in order."""
        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = Path(tmpdir) / "scenarios.json"

            config = ScenarioPoolConfig()
            pool1 = ScenarioPool(config)
            pool1.scenarios = pool1.generate_synthetic_batch(count=5)
            original_ids = [s.id for s in pool1.scenarios]

            pool1.save_scenarios(str(save_path))

            pool2 = ScenarioPool(config)
            pool2.load_scenarios(str(save_path))

            loaded_ids = [s.id for s in pool2.scenarios]

            assert original_ids == loaded_ids
            assert len(pool2.scenarios) == 5

    def test_refresh_mechanism(self):
        """Reaching refresh_interval samples regenerates synthetic scenarios."""
        config = ScenarioPoolConfig(refresh_interval=5)
        pool = ScenarioPool(config)
        pool.scenarios = pool.generate_synthetic_batch(count=10)

        original_ids = {s.id for s in pool.scenarios}

        # Sample 5 times (reaches refresh interval)
        for _ in range(5):
            pool.sample(count=1)

        # Synthetic scenarios should be regenerated
        new_ids = {s.id for s in pool.scenarios}

        # At least some should be different
        assert original_ids != new_ids

    @pytest.mark.asyncio
    async def test_initialize_without_database(self):
        """Without a database, initialize fills the pool with synthetic scenarios."""
        config = ScenarioPoolConfig(max_scenarios=10)
        pool = ScenarioPool(config)

        await pool.initialize()

        # Should fill with synthetic scenarios
        assert len(pool.scenarios) == 10
        for s in pool.scenarios:
            assert s.source == "synthetic"
587
+
588
+
589
class TestScenarioGeneration:
    """Unit tests for the pool's private market/news/post content generators."""

    @staticmethod
    def _pool():
        """Return a ScenarioPool built from a default configuration."""
        return ScenarioPool(ScenarioPoolConfig())

    def test_random_market_generation(self):
        market = self._pool()._generate_random_market(0)

        # Index 0 maps to the id "market-1".
        assert market.market_id == "market-1"
        assert 0.2 <= market.yes_price <= 0.8
        # The two sides of a binary market should sum to one.
        assert market.yes_price + market.no_price == pytest.approx(1.0, rel=1e-3)
        assert market.volume_24h > 0
        assert market.liquidity > 0

    def test_default_perpetuals_generation(self):
        perps = self._pool()._generate_default_perpetuals()

        assert len(perps) == 5  # BTC, ETH, SOL, DOGE, AVAX
        tickers = {perp.ticker for perp in perps}
        assert {"BTC", "ETH"} <= tickers

    def test_news_generation(self):
        headlines = self._pool()._generate_random_news(5, "medium")

        assert len(headlines) == 5
        for item in headlines:
            assert item.headline
            assert item.sentiment in ("bullish", "bearish", "neutral")
            assert item.impact in ("high", "medium", "low")
            assert item.source

    def test_posts_generation(self):
        feed = self._pool()._generate_random_posts(6)

        assert len(feed) == 6
        for post in feed:
            assert post.author
            assert post.content
            assert post.sentiment in ("bullish", "bearish", "neutral")
            assert post.likes >= 0
            assert post.replies >= 0

    def test_contextual_news_generation(self):
        generator = self._pool()
        btc_market = MarketState(
            market_id="m1",
            question="Will BTC exceed $100K?",
            yes_price=0.5,
            no_price=0.5,
            volume_24h=100_000.0,
            liquidity=500_000.0,
            expires_at=1_735_689_600_000,
        )

        news = generator._generate_contextual_news([btc_market])

        assert len(news) > 0
        # A BTC market should produce at least one headline that mentions it.
        btc_mentions = [
            item
            for item in news
            if any(tag in item.headline.lower() for tag in ("bitcoin", "btc"))
        ]
        assert len(btc_mentions) >= 1
664
+
665
+
666
+ # =============================================================================
667
+ # Integration Tests
668
+ # =============================================================================
669
+
670
+
671
class TestScenarioPoolIntegration:
    """Integration tests for ScenarioPool"""

    @pytest.mark.asyncio
    async def test_full_workflow(self):
        """Test the end-to-end workflow: init, sample, record, save and reload.

        Only five samples are drawn against refresh_interval=10, so the refresh
        path is not exercised here. test_refresh_mechanism covers it.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ScenarioPoolConfig(
                max_scenarios=20,
                refresh_interval=10,
                use_curriculum=True,
                curriculum_checkpoint_path=str(Path(tmpdir) / "curriculum.json"),
            )
            pool = ScenarioPool(config)

            await pool.initialize()
            assert len(pool.scenarios) == 20

            # Sample and record results
            for _ in range(5):
                scenarios = pool.sample(count=2)
                ids = [s.id for s in scenarios]
                scores = [0.6, 0.8]
                pool.record_results(ids, scores)

            stats = pool.get_stats()
            assert stats["curriculum"]["total_attempts"] == 10

            # Save and reload
            save_path = str(Path(tmpdir) / "scenarios.json")
            pool.save_scenarios(save_path)

            pool2 = ScenarioPool(config)
            pool2.load_scenarios(save_path)
            assert len(pool2.scenarios) == 20

    def test_scenario_observation_format(self):
        """Test that observations have the correct format for agent consumption."""
        config = ScenarioPoolConfig()
        pool = ScenarioPool(config)
        scenario = pool.generate_synthetic_batch(count=1)[0]

        obs = scenario.to_observation()

        # Verify required top-level fields
        required_fields = [
            "timestamp",
            "markets",
            "perpetuals",
            "news",
            "socialFeed",
            "portfolio",
            "marketSummary",
        ]
        for field_name in required_fields:
            assert field_name in obs, f"Missing required field: {field_name}"

        # Synthetic scenarios always contain markets and perpetuals (see
        # test_generate_synthetic_batch). Requiring them here ensures the nested
        # checks actually run instead of being silently skipped behind an
        # emptiness guard.
        assert obs["markets"], "Synthetic observation contains no markets"
        market = obs["markets"][0]
        assert "id" in market
        assert "yesPrice" in market
        assert "question" in market

        assert obs["perpetuals"], "Synthetic observation contains no perpetuals"
        perp = obs["perpetuals"][0]
        assert "ticker" in perp
        assert "markPrice" in perp

        assert "balance" in obs["portfolio"]
734
+
735
+