@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,1072 @@
1
+ """
2
+ Scenario Pool for Online GRPO Training
3
+
4
+ Manages scenario sampling for online rollout generation.
5
+
6
+ Sources:
7
+ 1. Production snapshots - Real market states from database
8
+ 2. Synthetic scenarios - Generated for curriculum learning
9
+ 3. Edge cases - Hand-crafted for robustness testing
10
+
11
+ Features:
12
+ - Curriculum learning with difficulty tracking
13
+ - Archetype-specific scenario generation
14
+ - Periodic refresh from production data
15
+ - Serializable state for checkpointing
16
+ """
17
+
18
+ import json
19
+ import logging
20
+ import random
21
+ from dataclasses import dataclass, field
22
+ from datetime import datetime, timezone
23
+ from pathlib import Path
24
+ from typing import Dict, List, Literal, Optional, Set
25
+ from uuid import uuid4
26
+
27
+ import numpy as np
28
+ from pydantic import BaseModel, Field
29
+
30
# Module-level logger; used by CurriculumManager and ScenarioPool for progress and best-effort failure warnings.
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ # =============================================================================
34
+ # Scenario Data Structures
35
+ # =============================================================================
36
+
37
+
38
@dataclass
class MarketState:
    """Snapshot of one prediction market as presented to an agent."""

    market_id: str
    question: str
    yes_price: float
    no_price: float
    volume_24h: float
    liquidity: float
    expires_at: int  # Unix timestamp ms
    category: str = "general"
    status: str = "active"

    def to_dict(self) -> Dict:
        """Return the camelCase dict form used in scenario observations."""
        # Keyword construction keeps the same key order as the field listing above.
        return dict(
            id=self.market_id,
            question=self.question,
            yesPrice=self.yes_price,
            noPrice=self.no_price,
            volume24h=self.volume_24h,
            liquidity=self.liquidity,
            expiresAt=self.expires_at,
            category=self.category,
            status=self.status,
        )
63
+
64
+
65
@dataclass
class PerpetualState:
    """Snapshot of one perpetual market (prices, funding and 24h stats)."""

    ticker: str
    mark_price: float
    index_price: float
    funding_rate: float
    open_interest: float
    volume_24h: float
    change_24h: float
    high_24h: float
    low_24h: float

    def to_dict(self) -> Dict:
        """Return the camelCase dict form used in scenario observations."""
        return dict(
            ticker=self.ticker,
            markPrice=self.mark_price,
            indexPrice=self.index_price,
            fundingRate=self.funding_rate,
            openInterest=self.open_interest,
            volume24h=self.volume_24h,
            change24h=self.change_24h,
            high24h=self.high_24h,
            low24h=self.low_24h,
        )
90
+
91
+
92
@dataclass
class NewsItem:
    """One news headline included in a scenario's information feed."""

    headline: str
    sentiment: Literal["bullish", "bearish", "neutral"]
    impact: Literal["high", "medium", "low"]
    source: str
    timestamp: int
    relevance_score: float = 1.0

    def to_dict(self) -> Dict:
        """Return the camelCase dict form used in scenario observations."""
        return dict(
            headline=self.headline,
            sentiment=self.sentiment,
            impact=self.impact,
            source=self.source,
            timestamp=self.timestamp,
            relevanceScore=self.relevance_score,
        )
111
+
112
+
113
@dataclass
class SocialPost:
    """One social-feed post included in a scenario."""

    author: str
    content: str
    sentiment: Literal["bullish", "bearish", "neutral"]
    likes: int
    replies: int
    timestamp: int
    verified: bool = False

    def to_dict(self) -> Dict:
        """Return the camelCase dict form used in scenario observations."""
        return dict(
            author=self.author,
            content=self.content,
            sentiment=self.sentiment,
            likes=self.likes,
            replies=self.replies,
            timestamp=self.timestamp,
            verified=self.verified,
        )
134
+
135
+
136
@dataclass
class PortfolioState:
    """The agent's starting portfolio for a scenario."""

    balance: float
    # Each instance gets its own list (default_factory), never a shared default.
    positions: List[Dict] = field(default_factory=list)
    total_pnl: float = 0.0

    def to_dict(self) -> Dict:
        """Return the camelCase dict form; ``positions`` is the live list, not a copy."""
        return dict(
            balance=self.balance,
            positions=self.positions,
            totalPnL=self.total_pnl,
        )
149
+
150
+
151
@dataclass
class Scenario:
    """
    Complete scenario for agent rollout.

    Contains all information an agent needs to make decisions:
    - Market state (prediction markets, perpetuals)
    - Information sources (news, social)
    - Agent's portfolio
    - Metadata for curriculum
    """
    id: str
    source: Literal["production", "synthetic", "edge_case"]

    # Market data
    markets: List[MarketState] = field(default_factory=list)
    perpetuals: List[PerpetualState] = field(default_factory=list)

    # Information sources
    news: List[NewsItem] = field(default_factory=list)
    social_posts: List[SocialPost] = field(default_factory=list)

    # Agent state
    portfolio: PortfolioState = field(default_factory=lambda: PortfolioState(balance=10000.0))

    # Metadata
    archetype_focus: Optional[str] = None
    difficulty: Literal["easy", "medium", "hard"] = "medium"
    # Creation time in Unix milliseconds (UTC).
    timestamp: int = field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))

    # Ground truth for evaluation (optional); never included in to_observation().
    ground_truth: Optional[Dict] = None

    # Extensible metadata for runtime data (e.g., bridge scenario reference).
    # Not serialized by to_dict() or to_observation().
    metadata: Dict = field(default_factory=dict)

    def add_market(self, market_dict: Dict) -> None:
        """
        Append a prediction market built from camelCase dict data.

        Missing keys fall back to neutral defaults. ``noPrice`` defaults to
        ``1 - yesPrice`` (keeping the pair complementary, as production
        snapshots do) and ``status`` defaults to ``"active"``, so a dict from
        ``MarketState.to_dict()`` round-trips.
        """
        yes_price = market_dict.get("yesPrice", 0.5)
        no_price = market_dict.get("noPrice")
        if no_price is None:
            # Only derive when yes_price is numeric; otherwise keep the old neutral default.
            no_price = 1.0 - yes_price if isinstance(yes_price, (int, float)) else 0.5

        self.markets.append(MarketState(
            market_id=market_dict.get("id", f"market-{len(self.markets)}"),
            question=market_dict.get("question", "Unknown"),
            yes_price=yes_price,
            no_price=no_price,
            volume_24h=market_dict.get("volume24h", 0),
            liquidity=market_dict.get("liquidity", 0),
            expires_at=market_dict.get("expiresAt", 0),
            category=market_dict.get("category", "general"),
            status=market_dict.get("status", "active"),
        ))

    def add_perpetual(self, perp_dict: Dict) -> None:
        """Append a perpetual market from camelCase dict data (indexPrice defaults to markPrice)."""
        self.perpetuals.append(PerpetualState(
            ticker=perp_dict.get("ticker", "UNKNOWN"),
            mark_price=perp_dict.get("markPrice", 0),
            index_price=perp_dict.get("indexPrice", perp_dict.get("markPrice", 0)),
            funding_rate=perp_dict.get("fundingRate", 0),
            open_interest=perp_dict.get("openInterest", 0),
            volume_24h=perp_dict.get("volume24h", 0),
            change_24h=perp_dict.get("change24h", 0),
            high_24h=perp_dict.get("high24h", 0),
            low_24h=perp_dict.get("low24h", 0),
        ))

    def add_news(self, news_dict: Dict) -> None:
        """
        Append a news item from dict data, coercing fields to allowed values.

        - ``sentiment`` may be numeric (sign decides bullish/bearish/neutral)
          or a label; unknown labels become ``"neutral"``.
        - ``impact`` outside ``high``/``medium``/``low`` becomes ``"medium"``.
        - Without a ``headline``, the first 100 chars of ``content`` are used
          (a missing or null ``content`` yields an empty headline).
        """
        sentiment_raw = news_dict.get("sentiment", "neutral")
        if isinstance(sentiment_raw, (int, float)):
            sentiment = "bullish" if sentiment_raw > 0 else "bearish" if sentiment_raw < 0 else "neutral"
        else:
            sentiment = sentiment_raw if sentiment_raw in ("bullish", "bearish", "neutral") else "neutral"

        impact_raw = news_dict.get("impact", "medium")
        impact = impact_raw if impact_raw in ("high", "medium", "low") else "medium"

        headline = news_dict.get("headline")
        if headline is None:
            # Evaluated lazily so a null "content" cannot crash when a headline exists.
            headline = (news_dict.get("content") or "")[:100]

        self.news.append(NewsItem(
            headline=headline,
            sentiment=sentiment,
            impact=impact,
            source=news_dict.get("source", "Unknown"),
            timestamp=news_dict.get("timestamp", int(datetime.now(timezone.utc).timestamp() * 1000)),
        ))

    def to_dict(self) -> Dict:
        """Serialize every field except ``metadata`` to a camelCase dict."""
        return {
            "id": self.id,
            "source": self.source,
            "markets": [m.to_dict() for m in self.markets],
            "perpetuals": [p.to_dict() for p in self.perpetuals],
            "news": [n.to_dict() for n in self.news],
            "socialPosts": [s.to_dict() for s in self.social_posts],
            "portfolio": self.portfolio.to_dict(),
            "archetypeFocus": self.archetype_focus,
            "difficulty": self.difficulty,
            "timestamp": self.timestamp,
            "groundTruth": self.ground_truth,
        }

    def to_observation(self) -> Dict:
        """
        Convert to agent observation format.

        This is what the agent sees as context. It deliberately omits the
        scenario id, source, difficulty, archetype focus and ground truth.
        """
        return {
            "timestamp": self.timestamp,
            "markets": [m.to_dict() for m in self.markets],
            "perpetuals": [p.to_dict() for p in self.perpetuals],
            "news": [n.to_dict() for n in self.news],
            "socialFeed": [s.to_dict() for s in self.social_posts],
            "portfolio": self.portfolio.to_dict(),
            "marketSummary": {
                "totalMarkets": len(self.markets),
                "totalPerpetuals": len(self.perpetuals),
                "avgSentiment": self._calculate_avg_sentiment(),
            },
        }

    def _calculate_avg_sentiment(self) -> str:
        """
        Summarize news and social sentiment as a single label.

        Scores bullish=1, neutral=0, bearish=-1 across all news items and
        posts; a mean above 0.3 is "bullish", below -0.3 is "bearish",
        otherwise (including no items) "neutral".
        """
        sentiment_scores = {"bullish": 1, "neutral": 0, "bearish": -1}
        scores = []

        for item in self.news:
            scores.append(sentiment_scores.get(item.sentiment, 0))
        for post in self.social_posts:
            scores.append(sentiment_scores.get(post.sentiment, 0))

        if not scores:
            return "neutral"

        avg = sum(scores) / len(scores)
        if avg > 0.3:
            return "bullish"
        elif avg < -0.3:
            return "bearish"
        return "neutral"
285
+
286
+
287
+ # =============================================================================
288
+ # Curriculum Manager
289
+ # =============================================================================
290
+
291
+
292
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class CurriculumState(BaseModel):
    """Serializable snapshot of CurriculumManager state, used as its JSON checkpoint."""

    # scenario id -> number of recorded attempts
    attempts: Dict[str, int] = Field(default_factory=dict)
    # scenario id -> recent score history, oldest first
    scores: Dict[str, List[float]] = Field(default_factory=dict)
    # ids of scenarios marked solved (a list so it serializes to JSON)
    solved: List[str] = Field(default_factory=list)
    # ISO-8601 UTC time at which this snapshot object was created
    last_updated: str = Field(default_factory=_utc_now_iso)
298
+
299
+
300
class CurriculumManager:
    """
    Adaptive curriculum for scenario selection.

    Tracks, per scenario id:
    - attempt counts
    - a bounded score history (the ``max_history_per_scenario`` most recent)
    - solved status (the mean of the last ``min_attempts_for_solved`` scores
      reaches ``solve_threshold``; once solved, a scenario stays solved until
      ``reset()``)

    Prioritizes (see ``get_priority``):
    - unsolved scenarios (solved ones get priority 0)
    - difficult scenarios (low average score)
    - underexplored scenarios (few attempts)

    When ``checkpoint_path`` is given, state is loaded from it on
    construction and re-written after every recorded attempt and reset.
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        solve_threshold: float = 0.8,
        min_attempts_for_solved: int = 3,
        max_avg_for_skip: float = 0.85,
        max_history_per_scenario: int = 10,
    ):
        self.checkpoint_path = Path(checkpoint_path) if checkpoint_path else None
        self.solve_threshold = solve_threshold
        self.min_attempts_for_solved = min_attempts_for_solved
        self.max_avg_for_skip = max_avg_for_skip
        self.max_history_per_scenario = max_history_per_scenario

        # State
        self.attempts: Dict[str, int] = {}
        self.scores: Dict[str, List[float]] = {}
        self.solved: Set[str] = set()

        self._load_checkpoint()

    def _load_checkpoint(self) -> None:
        """
        Restore state from ``checkpoint_path`` if the file exists.

        Best-effort: an unreadable or invalid file is logged and ignored,
        leaving the manager with empty state.
        """
        if not self.checkpoint_path or not self.checkpoint_path.exists():
            return

        try:
            state = CurriculumState.model_validate_json(
                self.checkpoint_path.read_text(encoding="utf-8")
            )
        except Exception as e:
            logger.warning(f"Failed to load curriculum checkpoint: {e}")
            return

        self.attempts = state.attempts
        self.scores = state.scores
        self.solved = set(state.solved)
        logger.info(f"Loaded curriculum state: {len(self.solved)} solved scenarios")

    def _save_checkpoint(self) -> None:
        """
        Persist state to ``checkpoint_path``; a no-op when no path is set.

        The JSON is written to a sibling ``<name>.tmp`` file and then moved
        over the checkpoint with ``Path.replace``. An interrupted write
        therefore cannot leave a truncated checkpoint that ``_load_checkpoint``
        would reject (silently discarding the curriculum). The rename is
        atomic on POSIX. Failures are logged, not raised.
        """
        if not self.checkpoint_path:
            return

        tmp_path = self.checkpoint_path.with_name(self.checkpoint_path.name + ".tmp")
        try:
            self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            state = CurriculumState(
                attempts=self.attempts,
                scores=self.scores,
                # Sorted so successive checkpoints are deterministic and diff-friendly.
                solved=sorted(self.solved),
            )
            tmp_path.write_text(state.model_dump_json(indent=2), encoding="utf-8")
            tmp_path.replace(self.checkpoint_path)
        except Exception as e:
            logger.warning(f"Failed to save curriculum checkpoint: {e}")

    def record_attempt(self, scenario_id: str, score: float) -> None:
        """
        Record one scored attempt, update solved status and checkpoint.

        Args:
            scenario_id: Scenario identifier.
            score: Attempt score, compared against ``solve_threshold``.
        """
        self.attempts[scenario_id] = self.attempts.get(scenario_id, 0) + 1

        history = self.scores.setdefault(scenario_id, [])
        history.append(score)

        # Keep only the most recent scores.
        if len(history) > self.max_history_per_scenario:
            self.scores[scenario_id] = history[-self.max_history_per_scenario:]

        # Solved once the last `min_attempts_for_solved` scores average above threshold.
        recent = self.scores[scenario_id][-self.min_attempts_for_solved:]
        if len(recent) >= self.min_attempts_for_solved:
            avg = sum(recent) / len(recent)
            if avg >= self.solve_threshold:
                self.solved.add(scenario_id)
                logger.debug(f"Scenario {scenario_id} marked as solved (avg: {avg:.2f})")

        self._save_checkpoint()

    def should_skip(self, scenario_id: str) -> bool:
        """
        Return True if the scenario is solved or looks too easy.

        "Too easy" means at least 2 recorded scores whose last-3 mean exceeds
        ``max_avg_for_skip``.
        """
        if scenario_id in self.solved:
            return True

        scores = self.scores.get(scenario_id, [])
        if len(scores) < 2:
            return False

        recent = scores[-3:]
        avg = sum(recent) / len(recent)
        return avg > self.max_avg_for_skip

    def get_priority(self, scenario_id: str) -> float:
        """
        Get priority score for scenario (higher = more priority).

        Combines:
        - Difficulty: ``1 - mean(score history)`` (1.0 when never scored)
        - Exploration bonus: ``0.5 / (1 + attempts)``

        Solved scenarios always return 0.0.
        """
        if scenario_id in self.solved:
            return 0.0

        scores = self.scores.get(scenario_id, [])
        attempts = self.attempts.get(scenario_id, 0)

        # Difficulty priority: lower scores = higher priority
        if scores:
            avg = sum(scores) / len(scores)
            difficulty_priority = 1.0 - avg
        else:
            difficulty_priority = 1.0  # Unexplored = high priority

        # Exploration bonus: fewer attempts = higher priority
        exploration_bonus = 1.0 / (1.0 + attempts)

        return difficulty_priority + exploration_bonus * 0.5

    def reset(self) -> None:
        """Mark all scenarios unsolved (attempts and score history are kept) and checkpoint."""
        self.solved.clear()
        logger.info("Curriculum reset: all scenarios marked unsolved")
        self._save_checkpoint()

    def get_stats(self) -> Dict:
        """
        Return aggregate curriculum statistics.

        ``total_scenarios`` counts scenarios with at least one recorded attempt;
        ``avg_score`` averages the retained score histories.
        """
        total_scenarios = len(self.attempts)
        total_attempts = sum(self.attempts.values())

        all_scores = [s for scores in self.scores.values() for s in scores]
        avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0

        return {
            "total_scenarios": total_scenarios,
            "solved_scenarios": len(self.solved),
            "total_attempts": total_attempts,
            "avg_score": avg_score,
            "solve_rate": len(self.solved) / total_scenarios if total_scenarios > 0 else 0.0,
        }
451
+
452
+
453
+ # =============================================================================
454
+ # Scenario Pool
455
+ # =============================================================================
456
+
457
+
458
def _default_difficulty_distribution() -> Dict[str, float]:
    """Default share of each difficulty in a synthetic batch."""
    return {"easy": 0.3, "medium": 0.5, "hard": 0.2}


class ScenarioPoolConfig(BaseModel):
    """Tunable settings for ScenarioPool: pool sizing, refresh, curriculum and synthetic mix."""

    # --- Pool size ---
    max_scenarios: int = Field(
        default=500,
        description="Maximum scenarios to keep in pool",
    )
    min_scenarios: int = Field(
        default=50,
        description="Minimum scenarios before refresh",
    )

    # --- Refresh settings ---
    refresh_interval: int = Field(
        default=1000,
        description="Refresh every N samples",
    )
    # ScenarioPool.initialize() requests int(max_scenarios * production_ratio)
    # production snapshots when a database URL is configured.
    production_ratio: float = Field(
        default=0.6,
        description="Ratio of production vs synthetic",
    )

    # --- Curriculum settings ---
    use_curriculum: bool = Field(
        default=True,
        description="Enable curriculum learning",
    )
    curriculum_checkpoint_path: str = Field(
        default="./curriculum_state.json",
        description="Path to curriculum state checkpoint",
    )

    # --- Generation settings ---
    # Each value is a fraction of a batch; rounding shortfall is filled with "medium".
    synthetic_difficulty_distribution: Dict[str, float] = Field(
        default_factory=_default_difficulty_distribution,
        description="Distribution of synthetic scenario difficulties",
    )
481
+
482
+
483
+ class ScenarioPool:
484
+ """
485
+ Manages scenario sampling for online rollouts.
486
+
487
+ Features:
488
+ - Load production snapshots from database
489
+ - Generate synthetic scenarios
490
+ - Curriculum-aware sampling
491
+ - Automatic refresh
492
+ """
493
+
494
def __init__(
    self,
    config: ScenarioPoolConfig,
    database_url: Optional[str] = None,
):
    """
    Create an empty pool; call ``initialize()`` to populate it.

    Args:
        config: Pool, refresh, curriculum and generation settings.
        database_url: Postgres URL for production snapshots; when None,
            only synthetic scenarios are generated.
    """
    self.config = config
    self.database_url = database_url

    # Scenario storage plus a running count of samples drawn from the pool.
    self.scenarios: List[Scenario] = []
    self._sample_counter = 0

    # The curriculum is optional; when disabled no manager exists at all.
    if config.use_curriculum:
        self.curriculum = CurriculumManager(
            checkpoint_path=config.curriculum_checkpoint_path,
        )
    else:
        self.curriculum = None
509
+
510
async def initialize(self) -> None:
    """
    Populate the pool.

    Production snapshots come first (only when a database URL is set);
    synthetic scenarios then fill any remaining room up to ``max_scenarios``.
    """
    logger.info("Initializing scenario pool...")

    cfg = self.config

    if self.database_url:
        # Reserve the production share of the pool for real snapshots.
        await self.load_production_snapshots(
            limit=int(cfg.max_scenarios * cfg.production_ratio)
        )

    shortfall = cfg.max_scenarios - len(self.scenarios)
    if shortfall > 0:
        self.scenarios.extend(self.generate_synthetic_batch(count=shortfall))

    logger.info(f"Scenario pool initialized with {len(self.scenarios)} scenarios")
526
+
527
async def load_production_snapshots(
    self,
    limit: int = 200,
    min_quality: float = 0.5,
) -> None:
    """
    Load scenarios from recent production game windows.

    Queries active questions from game windows created in the last 7 days
    and groups the rows by window, newest first. For at most ``limit``
    windows it appends one ``production`` scenario to ``self.scenarios``,
    using at most 10 markets per window. Perpetuals, news and social posts
    for each scenario are synthesized by the pool's generators.

    Best-effort: a missing database URL, a missing ``asyncpg`` install, a
    connection failure or a query error is logged as a warning and nothing
    is raised. Scenarios appended before a mid-loop failure are kept.

    Args:
        limit: Maximum number of windows (scenarios) to load.
        min_quality: Currently unused. NOTE(review): no quality filter is
            applied to the query yet; kept for interface compatibility.
    """
    if not self.database_url:
        logger.warning("No database URL configured, skipping production snapshots")
        return

    if limit <= 0:
        # Nothing requested; avoid opening a connection pool (and a negative SQL LIMIT).
        return

    try:
        import asyncpg
    except ImportError:
        logger.warning("asyncpg not installed, skipping production snapshots")
        return

    try:
        pool = await asyncpg.create_pool(
            self.database_url,
            min_size=1,
            max_size=5,
            command_timeout=30,
        )
    except Exception as e:
        logger.warning(f"Failed to connect to database: {e}")
        return

    def _as_float(value, default: float) -> float:
        # NULL columns arrive as None (the key exists), so dict.get's default never applies.
        return default if value is None else float(value)

    windows: Dict[str, List[Dict]] = {}
    loaded = 0
    try:
        async with pool.acquire() as conn:
            # Query recent game states with market data
            rows = await conn.fetch("""
                SELECT
                    w.id as window_id,
                    w."startTime" as start_time,
                    w."endTime" as end_time,
                    m.id as market_id,
                    m.question,
                    m."yesPrice",
                    m."noPrice",
                    m."totalVolume" as volume,
                    m.category
                FROM "GameWindow" w
                JOIN "Question" m ON m."gameWindowId" = w.id
                WHERE w."createdAt" > NOW() - INTERVAL '7 days'
                  AND m.status = 'active'
                ORDER BY w."createdAt" DESC
                LIMIT $1
            """, limit * 5)  # Over-fetch rows so they can be grouped into `limit` windows

        # Group by window; dicts preserve insertion (newest-window-first) order.
        for row in rows:
            windows.setdefault(str(row["window_id"]), []).append(dict(row))

        # The query has no expiry column, so every market expires one day from now.
        expires_at = int(datetime.now(timezone.utc).timestamp() * 1000) + 86400000

        for window_id, market_rows in list(windows.items())[:limit]:
            markets = []
            for row in market_rows[:10]:  # Max 10 markets per scenario
                yes_price = _as_float(row.get("yesPrice"), 0.5)
                volume = _as_float(row.get("volume"), 0.0)
                raw_id = row.get("market_id")
                markets.append(MarketState(
                    market_id=str(raw_id) if raw_id is not None else str(uuid4()),
                    question=row.get("question") or "Unknown question",
                    yes_price=yes_price,
                    # Complementary price; the fetched "noPrice" column is not used.
                    no_price=1.0 - yes_price,
                    # NOTE(review): totalVolume is used as a 24h-volume proxy.
                    volume_24h=volume,
                    # Heuristic placeholder: liquidity approximated as 10x volume.
                    liquidity=volume * 10,
                    expires_at=expires_at,
                    category=row.get("category") or "general",
                ))

            if markets:
                self.scenarios.append(Scenario(
                    id=f"prod-{window_id}",
                    source="production",
                    markets=markets,
                    perpetuals=self._generate_default_perpetuals(),
                    news=self._generate_contextual_news(markets),
                    social_posts=self._generate_contextual_posts(markets),
                    difficulty="medium",
                ))
                loaded += 1

        logger.info(f"Loaded {loaded} production scenarios from {len(windows)} recent windows")

    except Exception as e:
        logger.warning(f"Error loading production snapshots: {e}")
    finally:
        await pool.close()
621
+
622
def generate_synthetic_batch(
    self,
    count: int,
    archetype_focus: Optional[str] = None,
) -> List[Scenario]:
    """Build ``count`` synthetic scenarios spread across difficulty tiers.

    Each tier in ``config.synthetic_difficulty_distribution`` receives
    ``int(count * ratio)`` slots; slots lost to integer truncation are filled
    with ``"medium"``. The slot list is shuffled before generation so tiers
    are interleaved.

    Args:
        count: Number of scenarios to produce.
        archetype_focus: Optional archetype forwarded to every scenario.

    Returns:
        A list of at most ``count`` freshly generated scenarios.
    """
    tiers = []
    for tier, share in self.config.synthetic_difficulty_distribution.items():
        tiers += [tier] * int(count * share)

    # Integer truncation can leave us short; top up with the default tier.
    shortfall = count - len(tiers)
    if shortfall > 0:
        tiers += ["medium"] * shortfall

    random.shuffle(tiers)

    return [
        self._generate_synthetic_scenario(
            difficulty=tier,
            archetype_focus=archetype_focus,
        )
        for tier in tiers[:count]
    ]
649
+
650
def _generate_synthetic_scenario(
    self,
    difficulty: Literal["easy", "medium", "hard"] = "medium",
    archetype_focus: Optional[str] = None,
) -> Scenario:
    """Create one synthetic scenario whose size scales with ``difficulty``.

    Harder tiers get more markets, news items and social posts, and a smaller
    starting portfolio balance. An unknown difficulty raises ``KeyError``.
    """
    # tier -> (markets, news items, social posts, starting balance)
    tier_profile = {
        "easy": (3, 2, 3, 15000),
        "medium": (5, 5, 6, 10000),
        "hard": (8, 8, 10, 5000),
    }
    market_count, news_count, post_count, starting_balance = tier_profile[difficulty]

    scenario_id = f"synth-{uuid4().hex[:8]}"

    # Generation order (markets, perpetuals, news, posts) is kept stable so
    # seeded runs reproduce the same scenarios.
    markets = [self._generate_random_market(idx) for idx in range(market_count)]
    perpetuals = self._generate_default_perpetuals()
    news = self._generate_random_news(news_count, difficulty)
    posts = self._generate_random_posts(post_count)

    return Scenario(
        id=scenario_id,
        source="synthetic",
        markets=markets,
        perpetuals=perpetuals,
        news=news,
        social_posts=posts,
        portfolio=PortfolioState(balance=float(starting_balance)),
        archetype_focus=archetype_focus,
        difficulty=difficulty,
    )
686
+
687
def _generate_random_market(self, index: int) -> MarketState:
    """Produce one randomized prediction market.

    A question template is drawn at random and filled in. The YES price is
    sampled uniformly from [0.2, 0.8], and NO is its complement; each is
    rounded to two decimals independently. Expiry is 1 to 7 days from now,
    in epoch milliseconds.

    Args:
        index: Zero-based position; the market id is ``market-{index + 1}``.
    """
    question_templates = [
        ("Will BTC exceed ${price}K by end of {period}?", "crypto"),
        ("Will ETH outperform BTC this {period}?", "crypto"),
        ("Will the Fed announce rate {action}?", "macro"),
        ("Will {company} stock reach new ATH?", "stocks"),
        ("Will total crypto market cap exceed ${cap}T?", "crypto"),
        ("Will {coin} flip {coin2} in market cap?", "crypto"),
        ("Will inflation be above {rate}% next month?", "macro"),
    ]
    pattern, category = random.choice(question_templates)

    # Every placeholder is drawn (in this order) whether or not the chosen
    # template uses it, keeping the random stream stable.
    fills = {
        "price": random.choice([100, 120, 150, 200]),
        "period": random.choice(["week", "month", "quarter"]),
        "action": random.choice(["cuts", "hikes", "pause"]),
        "company": random.choice(["NVIDIA", "Apple", "Microsoft", "Tesla"]),
        "cap": random.choice([3, 4, 5]),
        "coin": random.choice(["SOL", "DOGE", "AVAX"]),
        "coin2": random.choice(["ETH", "BNB"]),
        "rate": random.choice([2, 3, 4]),
    }
    question = pattern.format(**fills)

    p_yes = random.uniform(0.2, 0.8)
    volume = float(random.randint(10000, 500000))
    depth = float(random.randint(50000, 1000000))
    now_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
    expiry = now_ms + random.randint(86400000, 604800000)

    return MarketState(
        market_id=f"market-{index + 1}",
        question=question,
        yes_price=round(p_yes, 2),
        no_price=round(1 - p_yes, 2),
        volume_24h=volume,
        liquidity=depth,
        expires_at=expiry,
        category=category,
    )
723
+
724
def _generate_default_perpetuals(self) -> List[PerpetualState]:
    """Build one jittered perpetual market per default ticker.

    Each mark price is its reference price moved by up to +/-5%. The index
    price deviates from the mark by up to +/-0.1%, and the 24h high/low are
    fixed at +/-5% of the mark. Funding, open interest, volume and 24h change
    are random.
    """
    reference_prices = {"BTC": 100000, "ETH": 3500, "SOL": 180, "DOGE": 0.35, "AVAX": 40}

    result = []
    for symbol in ["BTC", "ETH", "SOL", "DOGE", "AVAX"]:
        mark = reference_prices.get(symbol, 100) * (1 + random.uniform(-0.05, 0.05))

        # NOTE(review): two-decimal rounding is coarse for low-priced tickers
        # such as DOGE (~0.35); the index/mark spread mostly rounds away.
        index_px = round(mark * (1 + random.uniform(-0.001, 0.001)), 2)
        funding = round(random.uniform(-0.001, 0.001), 6)
        open_int = float(random.randint(1000000, 50000000))
        daily_volume = float(random.randint(5000000, 100000000))
        daily_change = round(random.uniform(-0.1, 0.1), 4)

        result.append(PerpetualState(
            ticker=symbol,
            mark_price=round(mark, 2),
            index_price=index_px,
            funding_rate=funding,
            open_interest=open_int,
            volume_24h=daily_volume,
            change_24h=daily_change,
            high_24h=round(mark * 1.05, 2),
            low_24h=round(mark * 0.95, 2),
        ))

    return result
747
+
748
def _generate_random_news(
    self,
    count: int,
    difficulty: str,
) -> List[NewsItem]:
    """Produce up to ``count`` distinct randomized news headlines.

    Templates are sampled without replacement, so at most ten items are
    returned however large ``count`` is. Each template carries a default
    sentiment and impact. On ``"hard"`` difficulty, about half of the items
    get their sentiment re-drawn at random to inject conflicting signals.
    Timestamps fall within the last hour (epoch ms), and relevance scores lie
    in [0.5, 1.0].
    """
    headline_bank = [
        ("Bitcoin Approaches Key Resistance Level at ${price}K", "bullish", "high"),
        ("Federal Reserve Hints at {action} Shift in Policy", "neutral", "high"),
        ("Major Exchange Reports Record {metric} Volume", "bullish", "medium"),
        ("Regulatory Clarity Expected Next {period}", "neutral", "medium"),
        ("Whale Alert: Large {direction} Transfer Detected", "bearish", "low"),
        ("New DeFi Protocol Launches with ${tvl}M TVL", "bullish", "low"),
        ("Mining Difficulty Reaches New {direction}", "neutral", "low"),
        ("Institutional Investors {action} Crypto Holdings", "bullish", "high"),
        ("Market Analysis: Technical Indicators Show {signal}", "neutral", "medium"),
        ("Breaking: {entity} Announces Crypto {action}", "bullish", "high"),
    ]
    outlets = ["CoinDesk", "Bloomberg Crypto", "Reuters", "CryptoNews", "The Block"]

    chosen = random.sample(headline_bank, min(count, len(headline_bank)))

    items = []
    for pattern, tone, weight in chosen:
        text = pattern.format(
            price=random.choice([100, 120, 150]),
            action=random.choice(["Bullish", "Dovish", "Cautious"]),
            metric=random.choice(["Trading", "Spot", "Derivatives"]),
            period=random.choice(["Month", "Quarter"]),
            direction=random.choice(["Buy", "Sell", "High", "Low"]),
            tvl=random.randint(10, 100),
            signal=random.choice(["Bullish Breakout", "Consolidation", "Bearish Divergence"]),
            entity=random.choice(["BlackRock", "Fidelity", "Goldman Sachs"]),
        )

        # Hard scenarios: sometimes decouple the sentiment from the headline.
        if difficulty == "hard" and random.random() > 0.5:
            tone = random.choice(["bullish", "bearish", "neutral"])

        outlet = random.choice(outlets)
        published_at = int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000)
        relevance = random.uniform(0.5, 1.0)

        items.append(NewsItem(
            headline=text,
            sentiment=tone,
            impact=weight,
            source=outlet,
            timestamp=published_at,
            relevance_score=relevance,
        ))

    return items
797
+
798
def _generate_random_posts(self, count: int) -> List[SocialPost]:
    """Produce ``count`` randomized trader social posts.

    Templates are drawn with replacement, so duplicates are possible. Authors
    are ``trader_NNN`` handles, timestamps fall within the last 30 minutes
    (epoch ms), and roughly 30% of posts are marked verified.
    """
    post_bank = [
        ("Just went long on {ticker}, looking {outlook} 🚀", "bullish"),
        ("Taking profits here, market looks overextended", "bearish"),
        ("Anyone else seeing this pattern on the {period} chart?", "neutral"),
        ("New ATH incoming, calling it now 💎🙌", "bullish"),
        ("Be careful, volume is declining significantly", "bearish"),
        ("Great entry opportunity if you missed the dip", "bullish"),
        ("Liquidation cascade might be coming, stay safe", "bearish"),
        ("{ticker} breaking out of the descending wedge!", "bullish"),
        ("Funding rates getting extreme, reversal soon?", "neutral"),
        ("This is the dip you've been waiting for", "bullish"),
    ]

    generated = []
    for _ in range(count):
        body, mood = random.choice(post_bank)
        text = body.format(
            ticker=random.choice(["BTC", "ETH", "SOL"]),
            outlook=random.choice(["bullish", "strong", "good"]),
            period=random.choice(["4H", "1D", "Weekly"]),
        )

        handle = f"trader_{random.randint(100, 999)}"
        like_count = random.randint(0, 500)
        reply_count = random.randint(0, 50)
        posted_at = int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 1800000)
        is_verified = random.random() > 0.7

        generated.append(SocialPost(
            author=handle,
            content=text,
            sentiment=mood,
            likes=like_count,
            replies=reply_count,
            timestamp=posted_at,
            verified=is_verified,
        ))

    return generated
833
+
834
def _generate_contextual_news(self, markets: List[MarketState]) -> List[NewsItem]:
    """Generate news headlines tied to the topics of the first few markets.

    The first three markets are inspected, and the first matching rule wins
    for each:

    - Bitcoin keywords produce a technical-analysis headline.
    - Ethereum keywords produce a network-activity headline.
    - Fed or interest-rate keywords produce a policy headline.

    Three generic ``"medium"`` items from ``_generate_random_news`` are then
    appended.

    Keywords are matched as whole words. Plain substring checks misfired on
    ordinary words such as "whether" (contains "eth") or
    "generate"/"separate" (contain "rate").

    Args:
        markets: Markets whose questions drive the topical headlines.

    Returns:
        Topical news items followed by three generic ones.
    """
    import re

    def mentions(text: str, *keywords: str) -> bool:
        # True if any keyword occurs as a whole word in ``text``.
        return any(re.search(rf"\b{re.escape(word)}\b", text) for word in keywords)

    def recent_ms() -> int:
        # Epoch-millisecond timestamp up to one hour in the past.
        return int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000)

    news = []

    for market in markets[:3]:
        question = market.question.lower()

        if mentions(question, "btc", "bitcoin"):
            news.append(NewsItem(
                headline="Bitcoin Technical Analysis: Key Levels to Watch",
                sentiment=random.choice(["bullish", "neutral"]),
                impact="medium",
                source="CryptoNews",
                timestamp=recent_ms(),
            ))
        elif mentions(question, "eth", "ethereum"):
            news.append(NewsItem(
                headline="Ethereum Network Activity Surges to New Highs",
                sentiment="bullish",
                impact="medium",
                source="The Block",
                timestamp=recent_ms(),
            ))
        elif mentions(question, "fed", "federal", "rate", "rates"):
            news.append(NewsItem(
                headline="Fed Officials Signal Patience on Rate Decisions",
                sentiment="neutral",
                impact="high",
                source="Bloomberg",
                timestamp=recent_ms(),
            ))

    # Pad with generic, non-topical news.
    news.extend(self._generate_random_news(3, "medium"))

    return news
872
+
873
def _generate_contextual_posts(self, markets: List[MarketState]) -> List[SocialPost]:
    """Generate social posts that comment on the current market odds.

    For each of the first two markets, one verified ``analyst_N`` post is
    created. Its sentiment follows the YES price:

    - above 0.6: bullish
    - below 0.4: bearish
    - otherwise: neutral

    The post quotes the market question, truncated to 50 characters with a
    trailing "..." only when text was actually cut. Four generic posts from
    ``_generate_random_posts`` are appended afterwards.

    Args:
        markets: Markets to comment on (only the first two are used).

    Returns:
        Market-specific posts followed by four generic ones.
    """

    def preview(question: str, limit: int = 50) -> str:
        # Mark the quote as truncated only when something was cut off.
        if len(question) <= limit:
            return question
        return f"{question[:limit]}..."

    posts = []

    for market in markets[:2]:
        quoted = preview(market.question)

        if market.yes_price > 0.6:
            sentiment = "bullish"
            content = f"Market is pricing in high probability - {quoted}"
        elif market.yes_price < 0.4:
            sentiment = "bearish"
            content = f"Looks unlikely based on current odds - {quoted}"
        else:
            sentiment = "neutral"
            content = f"This one could go either way - {quoted}"

        posts.append(SocialPost(
            author=f"analyst_{random.randint(1, 100)}",
            content=content,
            sentiment=sentiment,
            likes=random.randint(10, 100),
            replies=random.randint(1, 20),
            timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 1800000),
            verified=True,
        ))

    # Pad with generic chatter.
    posts.extend(self._generate_random_posts(4))

    return posts
905
+
906
def sample(self, count: int = 1) -> "List[Scenario]":
    """Draw up to ``count`` distinct scenarios, respecting the curriculum.

    Every call adds ``count`` to an internal counter. Once the counter
    reaches ``config.refresh_interval``, it resets and the synthetic
    scenarios are regenerated; production scenarios are kept.

    With a curriculum attached:

    - Scenarios for which ``curriculum.should_skip`` is true are excluded.
    - The rest are drawn *without replacement*, with probability
      proportional to ``curriculum.get_priority``.
    - If every scenario is skipped, the curriculum is reset and the whole
      pool becomes eligible again.
    - If fewer than the requested number of scenarios have a positive
      priority (including the all-zero case), the draw falls back to
      uniform. Otherwise ``numpy.random.choice`` would raise ``ValueError``.

    Without a curriculum, the draw is uniform without replacement.

    Args:
        count: Desired number of scenarios.

    Returns:
        At most ``min(count, eligible pool size)`` distinct scenarios; an
        empty list when the pool is empty.
    """
    self._sample_counter += count

    # Periodic refresh: keep production data, regenerate synthetic data.
    if self._sample_counter >= self.config.refresh_interval:
        self._sample_counter = 0
        logger.info("Refresh interval reached, regenerating synthetic scenarios")
        production = [s for s in self.scenarios if s.source == "production"]
        # Never request a negative batch when production alone fills the pool.
        synthetic_count = max(0, self.config.max_scenarios - len(production))
        synthetic = self.generate_synthetic_batch(synthetic_count)
        self.scenarios = production + synthetic

    if not self.scenarios:
        return []

    if not self.curriculum:
        # Simple uniform sampling without replacement.
        return random.sample(self.scenarios, min(count, len(self.scenarios)))

    available = [s for s in self.scenarios if not self.curriculum.should_skip(s.id)]
    if not available:
        # Everything is solved: start the curriculum over.
        logger.info("All scenarios solved, resetting curriculum")
        self.curriculum.reset()
        available = self.scenarios

    size = min(count, len(available))
    priorities = [self.curriculum.get_priority(s.id) for s in available]
    total = sum(priorities)
    positive = sum(1 for p in priorities if p > 0)

    if total <= 0 or positive < size:
        # Weighted sampling without replacement needs at least `size`
        # non-zero weights; otherwise sample uniformly.
        probs = [1.0 / len(available)] * len(available)
    else:
        probs = [p / total for p in priorities]

    # Sampling is WITHOUT replacement, so never more than len(available).
    indices = np.random.choice(
        len(available),
        size=size,
        replace=False,
        p=probs,
    )

    return [available[i] for i in indices]
959
+
960
def record_results(
    self,
    scenario_ids: List[str],
    scores: List[float],
) -> None:
    """Feed per-scenario scores back into the curriculum, if one is attached.

    ``scenario_ids`` and ``scores`` are paired positionally. Pairing stops at
    the shorter list (``zip`` semantics). Does nothing when no curriculum is
    configured.
    """
    if not self.curriculum:
        return

    for pair in zip(scenario_ids, scores):
        self.curriculum.record_attempt(*pair)
971
+
972
def get_stats(self) -> Dict:
    """Summarize the pool: scenario counts by source and samples since refresh.

    When a curriculum is attached, its own ``get_stats()`` result is included
    under the ``"curriculum"`` key.
    """
    sources = [scenario.source for scenario in self.scenarios]

    summary = {
        "total_scenarios": len(self.scenarios),
        "production_scenarios": sources.count("production"),
        "synthetic_scenarios": sources.count("synthetic"),
        "samples_since_refresh": self._sample_counter,
    }

    if self.curriculum:
        summary["curriculum"] = self.curriculum.get_stats()

    return summary
985
+
986
def save_scenarios(self, path: str) -> None:
    """Write every scenario, serialized via ``to_dict()``, to a JSON file.

    Missing parent directories are created. An existing file is overwritten.
    Output uses 2-space indentation.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = [scenario.to_dict() for scenario in self.scenarios]
    with target.open("w") as handle:
        json.dump(payload, handle, indent=2)

    logger.info(f"Saved {len(payload)} scenarios to {path}")
997
+
998
def load_scenarios(self, path: str) -> None:
    """Replace the pool's contents with scenarios read from a JSON file.

    The file must hold a list of scenario dicts with the camelCase keys read
    below (presumably the format written by ``save_scenarios`` via
    ``to_dict()``; verify if other producers exist). Optional keys fall back
    to defaults, and a missing or null ``"portfolio"`` yields a default
    portfolio.

    All entries are parsed before the pool is touched, so a malformed file
    raises without discarding the scenarios already loaded. The existing
    ``self.scenarios`` list object is reused, so outside references to it
    see the new contents.

    Args:
        path: Path to the JSON file (read as UTF-8).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a required field is missing from an entry.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    loaded = []
    for item in data:
        markets = [MarketState(
            market_id=m["id"],
            question=m["question"],
            yes_price=m["yesPrice"],
            no_price=m["noPrice"],
            volume_24h=m["volume24h"],
            liquidity=m["liquidity"],
            expires_at=m["expiresAt"],
            category=m.get("category", "general"),
        ) for m in item.get("markets", [])]

        perpetuals = [PerpetualState(
            ticker=p["ticker"],
            mark_price=p["markPrice"],
            index_price=p["indexPrice"],
            funding_rate=p["fundingRate"],
            open_interest=p["openInterest"],
            volume_24h=p["volume24h"],
            change_24h=p["change24h"],
            high_24h=p["high24h"],
            low_24h=p["low24h"],
        ) for p in item.get("perpetuals", [])]

        news = [NewsItem(
            headline=n["headline"],
            sentiment=n["sentiment"],
            impact=n["impact"],
            source=n["source"],
            timestamp=n["timestamp"],
        ) for n in item.get("news", [])]

        posts = [SocialPost(
            author=p["author"],
            content=p["content"],
            sentiment=p["sentiment"],
            likes=p["likes"],
            replies=p["replies"],
            timestamp=p["timestamp"],
            verified=p.get("verified", False),
        ) for p in item.get("socialPosts", [])]

        # Treat an explicit null the same as an absent portfolio.
        portfolio_data = item.get("portfolio") or {}
        portfolio = PortfolioState(
            balance=portfolio_data.get("balance", 10000.0),
            positions=portfolio_data.get("positions", []),
            total_pnl=portfolio_data.get("totalPnL", 0.0),
        )

        loaded.append(Scenario(
            id=item["id"],
            source=item["source"],
            markets=markets,
            perpetuals=perpetuals,
            news=news,
            social_posts=posts,
            portfolio=portfolio,
            archetype_focus=item.get("archetypeFocus"),
            difficulty=item.get("difficulty", "medium"),
            timestamp=item.get("timestamp", int(datetime.now(timezone.utc).timestamp() * 1000)),
            ground_truth=item.get("groundTruth"),
        ))

    # Swap only after every entry parsed; mutate in place so existing
    # references to the list observe the new contents.
    self.scenarios.clear()
    self.scenarios.extend(loaded)

    logger.info(f"Loaded {len(self.scenarios)} scenarios from {path}")
1072
+