@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,1072 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Scenario Pool for Online GRPO Training
|
|
3
|
+
|
|
4
|
+
Manages scenario sampling for online rollout generation.
|
|
5
|
+
|
|
6
|
+
Sources:
|
|
7
|
+
1. Production snapshots - Real market states from database
|
|
8
|
+
2. Synthetic scenarios - Generated for curriculum learning
|
|
9
|
+
3. Edge cases - Hand-crafted for robustness testing
|
|
10
|
+
|
|
11
|
+
Features:
|
|
12
|
+
- Curriculum learning with difficulty tracking
|
|
13
|
+
- Archetype-specific scenario generation
|
|
14
|
+
- Periodic refresh from production data
|
|
15
|
+
- Serializable state for checkpointing
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
import random
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Dict, List, Literal, Optional, Set
|
|
25
|
+
from uuid import uuid4
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
from pydantic import BaseModel, Field
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# Scenario Data Structures
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class MarketState:
    """Snapshot of a single prediction market at scenario time."""
    market_id: str
    question: str
    yes_price: float
    no_price: float
    volume_24h: float
    liquidity: float
    expires_at: int  # Unix timestamp ms
    category: str = "general"
    status: str = "active"

    def to_dict(self) -> Dict:
        """Serialize to the camelCase dict shape used downstream."""
        pairs = [
            ("id", self.market_id),
            ("question", self.question),
            ("yesPrice", self.yes_price),
            ("noPrice", self.no_price),
            ("volume24h", self.volume_24h),
            ("liquidity", self.liquidity),
            ("expiresAt", self.expires_at),
            ("category", self.category),
            ("status", self.status),
        ]
        return dict(pairs)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class PerpetualState:
    """Snapshot of a perpetual futures market."""
    ticker: str
    mark_price: float
    index_price: float
    funding_rate: float
    open_interest: float
    volume_24h: float
    change_24h: float
    high_24h: float
    low_24h: float

    # Attribute name -> camelCase JSON key, in output order.
    _FIELD_MAP = (
        ("ticker", "ticker"),
        ("mark_price", "markPrice"),
        ("index_price", "indexPrice"),
        ("funding_rate", "fundingRate"),
        ("open_interest", "openInterest"),
        ("volume_24h", "volume24h"),
        ("change_24h", "change24h"),
        ("high_24h", "high24h"),
        ("low_24h", "low24h"),
    )

    def to_dict(self) -> Dict:
        """Serialize to a camelCase dict for interchange."""
        return {camel: getattr(self, attr) for attr, camel in self._FIELD_MAP}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class NewsItem:
    """A single news headline presented to the agent in a scenario."""
    headline: str
    sentiment: Literal["bullish", "bearish", "neutral"]
    impact: Literal["high", "medium", "low"]
    source: str
    timestamp: int
    relevance_score: float = 1.0

    def to_dict(self) -> Dict:
        """Serialize to a camelCase dict for interchange."""
        payload: Dict = {
            "headline": self.headline,
            "sentiment": self.sentiment,
            "impact": self.impact,
            "source": self.source,
            "timestamp": self.timestamp,
        }
        payload["relevanceScore"] = self.relevance_score
        return payload
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class SocialPost:
    """A single social-media post presented to the agent in a scenario."""
    author: str
    content: str
    sentiment: Literal["bullish", "bearish", "neutral"]
    likes: int
    replies: int
    timestamp: int
    verified: bool = False

    def to_dict(self) -> Dict:
        """Serialize to a dict; all keys here are already camelCase-safe."""
        return dict(
            author=self.author,
            content=self.content,
            sentiment=self.sentiment,
            likes=self.likes,
            replies=self.replies,
            timestamp=self.timestamp,
            verified=self.verified,
        )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclass
class PortfolioState:
    """The agent's starting portfolio for a scenario."""
    balance: float
    positions: List[Dict] = field(default_factory=list)
    total_pnl: float = 0.0

    def to_dict(self) -> Dict:
        """Serialize to a camelCase dict for interchange."""
        snapshot: Dict = dict(balance=self.balance, positions=self.positions)
        snapshot["totalPnL"] = self.total_pnl
        return snapshot
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@dataclass
class Scenario:
    """
    Complete scenario for agent rollout.

    Contains all information an agent needs to make decisions:
    - Market state (prediction markets, perpetuals)
    - Information sources (news, social)
    - Agent's portfolio
    - Metadata for curriculum
    """
    id: str
    source: Literal["production", "synthetic", "edge_case"]

    # Market data
    markets: List[MarketState] = field(default_factory=list)
    perpetuals: List[PerpetualState] = field(default_factory=list)

    # Information sources
    news: List[NewsItem] = field(default_factory=list)
    social_posts: List[SocialPost] = field(default_factory=list)

    # Agent state
    portfolio: PortfolioState = field(default_factory=lambda: PortfolioState(balance=10000.0))

    # Metadata
    archetype_focus: Optional[str] = None
    difficulty: Literal["easy", "medium", "hard"] = "medium"
    timestamp: int = field(default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000))

    # Ground truth for evaluation (optional)
    ground_truth: Optional[Dict] = None

    # Extensible metadata for runtime data (e.g., bridge scenario reference).
    # NOTE(review): not included in to_dict()/to_observation() output —
    # presumably intentional (runtime-only data); confirm before relying on it.
    metadata: Dict = field(default_factory=dict)

    def add_market(self, market_dict: Dict) -> None:
        """Add a prediction market from camelCase dict data; missing keys get defaults."""
        self.markets.append(MarketState(
            market_id=market_dict.get("id", f"market-{len(self.markets)}"),
            question=market_dict.get("question", "Unknown"),
            yes_price=market_dict.get("yesPrice", 0.5),
            no_price=market_dict.get("noPrice", 0.5),
            volume_24h=market_dict.get("volume24h", 0),
            liquidity=market_dict.get("liquidity", 0),
            expires_at=market_dict.get("expiresAt", 0),
            category=market_dict.get("category", "general"),
        ))

    def add_perpetual(self, perp_dict: Dict) -> None:
        """Add a perpetual market from camelCase dict data; missing keys get defaults."""
        self.perpetuals.append(PerpetualState(
            ticker=perp_dict.get("ticker", "UNKNOWN"),
            mark_price=perp_dict.get("markPrice", 0),
            # Fall back to the mark price when no separate index price is given
            index_price=perp_dict.get("indexPrice", perp_dict.get("markPrice", 0)),
            funding_rate=perp_dict.get("fundingRate", 0),
            open_interest=perp_dict.get("openInterest", 0),
            volume_24h=perp_dict.get("volume24h", 0),
            change_24h=perp_dict.get("change24h", 0),
            high_24h=perp_dict.get("high24h", 0),
            low_24h=perp_dict.get("low24h", 0),
        ))

    def add_news(self, news_dict: Dict) -> None:
        """
        Add a news item from dict data, normalizing free-form fields.

        Numeric sentiment is mapped by sign; unknown sentiment/impact strings
        fall back to "neutral"/"medium" so NewsItem's Literal contracts hold.
        """
        sentiment_raw = news_dict.get("sentiment", "neutral")
        if isinstance(sentiment_raw, (int, float)):
            sentiment = "bullish" if sentiment_raw > 0 else "bearish" if sentiment_raw < 0 else "neutral"
        else:
            sentiment = sentiment_raw if sentiment_raw in ("bullish", "bearish", "neutral") else "neutral"

        # Fix: validate impact against its allowed literals too. Previously it
        # was passed through unchecked, so arbitrary values could violate the
        # NewsItem.impact Literal contract.
        impact_raw = news_dict.get("impact", "medium")
        impact = impact_raw if impact_raw in ("high", "medium", "low") else "medium"

        # Fix: guard against a present-but-None headline/content so the stored
        # headline is always a string (previously `None[:100]` could raise, or
        # None could be stored in a str field).
        headline = news_dict.get("headline")
        if headline is None:
            headline = (news_dict.get("content") or "")[:100]

        self.news.append(NewsItem(
            headline=headline,
            sentiment=sentiment,
            impact=impact,
            source=news_dict.get("source", "Unknown"),
            timestamp=news_dict.get("timestamp", int(datetime.now(timezone.utc).timestamp() * 1000)),
        ))

    def to_dict(self) -> Dict:
        """Serialize the full scenario (camelCase keys), including ground truth."""
        return {
            "id": self.id,
            "source": self.source,
            "markets": [m.to_dict() for m in self.markets],
            "perpetuals": [p.to_dict() for p in self.perpetuals],
            "news": [n.to_dict() for n in self.news],
            "socialPosts": [s.to_dict() for s in self.social_posts],
            "portfolio": self.portfolio.to_dict(),
            "archetypeFocus": self.archetype_focus,
            "difficulty": self.difficulty,
            "timestamp": self.timestamp,
            "groundTruth": self.ground_truth,
        }

    def to_observation(self) -> Dict:
        """
        Convert to agent observation format.

        This is what the agent sees as context. Unlike to_dict(), it omits
        id/source/difficulty/ground truth and adds an aggregate summary.
        """
        return {
            "timestamp": self.timestamp,
            "markets": [m.to_dict() for m in self.markets],
            "perpetuals": [p.to_dict() for p in self.perpetuals],
            "news": [n.to_dict() for n in self.news],
            "socialFeed": [s.to_dict() for s in self.social_posts],
            "portfolio": self.portfolio.to_dict(),
            "marketSummary": {
                "totalMarkets": len(self.markets),
                "totalPerpetuals": len(self.perpetuals),
                "avgSentiment": self._calculate_avg_sentiment(),
            },
        }

    def _calculate_avg_sentiment(self) -> str:
        """Average news + social sentiment into one label (±0.3 deadband -> neutral)."""
        sentiment_scores = {"bullish": 1, "neutral": 0, "bearish": -1}
        scores = []

        for item in self.news:
            scores.append(sentiment_scores.get(item.sentiment, 0))
        for post in self.social_posts:
            scores.append(sentiment_scores.get(post.sentiment, 0))

        if not scores:
            return "neutral"

        avg = sum(scores) / len(scores)
        if avg > 0.3:
            return "bullish"
        elif avg < -0.3:
            return "bearish"
        return "neutral"
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# =============================================================================
|
|
288
|
+
# Curriculum Manager
|
|
289
|
+
# =============================================================================
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class CurriculumState(BaseModel):
    """Serializable snapshot of CurriculumManager state, persisted as JSON."""
    # scenario_id -> number of recorded attempts
    attempts: Dict[str, int] = Field(default_factory=dict)
    # scenario_id -> score history (trimmed to a max length by the manager)
    scores: Dict[str, List[float]] = Field(default_factory=dict)
    # scenario ids considered solved; a list (not a set) for JSON round-tripping
    solved: List[str] = Field(default_factory=list)
    # ISO-8601 UTC timestamp, set when this snapshot object is constructed
    # (i.e., at each checkpoint save)
    last_updated: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class CurriculumManager:
    """
    Adaptive curriculum for scenario selection.

    Tracks:
    - Per-scenario attempt counts
    - Per-scenario score history
    - Solved/unsolved status

    Prioritizes:
    - Unsolved scenarios
    - Difficult scenarios (low avg score)
    - Underexplored scenarios
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        solve_threshold: float = 0.8,
        min_attempts_for_solved: int = 3,
        max_avg_for_skip: float = 0.85,
        max_history_per_scenario: int = 10,
    ):
        """
        Args:
            checkpoint_path: Optional JSON file for persisting state; when None,
                load/save are no-ops.
            solve_threshold: Recent-average score at/above which a scenario is
                marked solved.
            min_attempts_for_solved: Size of the recent-score window used for
                the solved check (and the skip check).
            max_avg_for_skip: Recent-average score above which a scenario is
                skipped as too easy even if not formally solved.
            max_history_per_scenario: Cap on retained scores per scenario.
        """
        self.checkpoint_path = Path(checkpoint_path) if checkpoint_path else None
        self.solve_threshold = solve_threshold
        self.min_attempts_for_solved = min_attempts_for_solved
        self.max_avg_for_skip = max_avg_for_skip
        self.max_history_per_scenario = max_history_per_scenario

        # State
        self.attempts: Dict[str, int] = {}
        self.scores: Dict[str, List[float]] = {}
        self.solved: Set[str] = set()

        self._load_checkpoint()

    def _load_checkpoint(self) -> None:
        """Load curriculum state from checkpoint; missing/corrupt files are non-fatal."""
        if not self.checkpoint_path or not self.checkpoint_path.exists():
            return

        try:
            with open(self.checkpoint_path) as f:
                state = CurriculumState.model_validate_json(f.read())
            self.attempts = state.attempts
            self.scores = state.scores
            self.solved = set(state.solved)
            logger.info(f"Loaded curriculum state: {len(self.solved)} solved scenarios")
        except Exception as e:
            # Best-effort: a bad checkpoint should not abort training
            logger.warning(f"Failed to load curriculum checkpoint: {e}")

    def _save_checkpoint(self) -> None:
        """Save curriculum state to checkpoint; failures are logged, not raised."""
        if not self.checkpoint_path:
            return

        try:
            self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            state = CurriculumState(
                attempts=self.attempts,
                scores=self.scores,
                solved=list(self.solved),
            )
            with open(self.checkpoint_path, "w") as f:
                f.write(state.model_dump_json(indent=2))
        except Exception as e:
            logger.warning(f"Failed to save curriculum checkpoint: {e}")

    def record_attempt(self, scenario_id: str, score: float) -> None:
        """Record an attempt, update solved status, and persist the checkpoint."""
        self.attempts[scenario_id] = self.attempts.get(scenario_id, 0) + 1

        if scenario_id not in self.scores:
            self.scores[scenario_id] = []

        self.scores[scenario_id].append(score)

        # Trim history to the configured cap
        if len(self.scores[scenario_id]) > self.max_history_per_scenario:
            self.scores[scenario_id] = self.scores[scenario_id][-self.max_history_per_scenario:]

        # Check if solved: average over the recent window, once it is full
        recent = self.scores[scenario_id][-self.min_attempts_for_solved:]
        if len(recent) >= self.min_attempts_for_solved:
            avg = sum(recent) / len(recent)
            if avg >= self.solve_threshold:
                self.solved.add(scenario_id)
                logger.debug(f"Scenario {scenario_id} marked as solved (avg: {avg:.2f})")

        self._save_checkpoint()

    def should_skip(self, scenario_id: str) -> bool:
        """Check if scenario should be skipped (already solved, or too easy)."""
        if scenario_id in self.solved:
            return True

        scores = self.scores.get(scenario_id, [])
        if len(scores) < 2:
            return False

        # Fix: use the same recent window as the solved check. This was a
        # hard-coded 3, inconsistent with the configurable
        # min_attempts_for_solved parameter.
        recent = scores[-self.min_attempts_for_solved:]
        avg = sum(recent) / len(recent)
        return avg > self.max_avg_for_skip

    def get_priority(self, scenario_id: str) -> float:
        """
        Get priority score for scenario (higher = more priority).

        Combines:
        - Difficulty (low scores = high priority)
        - Exploration bonus (few attempts = high priority)

        Solved scenarios always have priority 0.
        """
        if scenario_id in self.solved:
            return 0.0

        scores = self.scores.get(scenario_id, [])
        attempts = self.attempts.get(scenario_id, 0)

        # Difficulty priority: lower scores = higher priority
        if scores:
            avg = sum(scores) / len(scores)
            difficulty_priority = 1.0 - avg
        else:
            difficulty_priority = 1.0  # Unexplored = high priority

        # Exploration bonus: fewer attempts = higher priority
        exploration_bonus = 1.0 / (1.0 + attempts)

        return difficulty_priority + exploration_bonus * 0.5

    def reset(self) -> None:
        """Reset solved status (attempt counts and score history are kept)."""
        self.solved.clear()
        logger.info("Curriculum reset: all scenarios marked unsolved")
        self._save_checkpoint()

    def get_stats(self) -> Dict:
        """Get aggregate curriculum statistics for monitoring."""
        total_scenarios = len(self.attempts)
        total_attempts = sum(self.attempts.values())

        all_scores = [s for scores in self.scores.values() for s in scores]
        avg_score = sum(all_scores) / len(all_scores) if all_scores else 0.0

        return {
            "total_scenarios": total_scenarios,
            "solved_scenarios": len(self.solved),
            "total_attempts": total_attempts,
            "avg_score": avg_score,
            "solve_rate": len(self.solved) / total_scenarios if total_scenarios > 0 else 0.0,
        }
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# =============================================================================
|
|
454
|
+
# Scenario Pool
|
|
455
|
+
# =============================================================================
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class ScenarioPoolConfig(BaseModel):
    """
    Configuration for ScenarioPool.

    Controls pool sizing, refresh cadence, the production/synthetic mix,
    curriculum persistence, and the synthetic difficulty distribution.
    """

    # Pool size
    # NOTE(review): min_scenarios and refresh_interval are not referenced in
    # the visible pool code — confirm they are consumed elsewhere.
    max_scenarios: int = Field(default=500, description="Maximum scenarios to keep in pool")
    min_scenarios: int = Field(default=50, description="Minimum scenarios before refresh")

    # Refresh settings
    refresh_interval: int = Field(default=1000, description="Refresh every N samples")
    # Fraction of the pool drawn from production snapshots; the remainder is
    # filled with synthetic scenarios (see ScenarioPool.initialize)
    production_ratio: float = Field(default=0.6, description="Ratio of production vs synthetic")

    # Curriculum settings
    use_curriculum: bool = Field(default=True, description="Enable curriculum learning")
    # Only used when use_curriculum is True
    curriculum_checkpoint_path: str = Field(
        default="./curriculum_state.json",
        description="Path to curriculum state checkpoint",
    )

    # Generation settings: relative weights keyed by difficulty label
    synthetic_difficulty_distribution: Dict[str, float] = Field(
        default_factory=lambda: {"easy": 0.3, "medium": 0.5, "hard": 0.2},
        description="Distribution of synthetic scenario difficulties",
    )
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class ScenarioPool:
|
|
484
|
+
"""
|
|
485
|
+
Manages scenario sampling for online rollouts.
|
|
486
|
+
|
|
487
|
+
Features:
|
|
488
|
+
- Load production snapshots from database
|
|
489
|
+
- Generate synthetic scenarios
|
|
490
|
+
- Curriculum-aware sampling
|
|
491
|
+
- Automatic refresh
|
|
492
|
+
"""
|
|
493
|
+
|
|
494
|
+
def __init__(
|
|
495
|
+
self,
|
|
496
|
+
config: ScenarioPoolConfig,
|
|
497
|
+
database_url: Optional[str] = None,
|
|
498
|
+
):
|
|
499
|
+
self.config = config
|
|
500
|
+
self.database_url = database_url
|
|
501
|
+
|
|
502
|
+
self.scenarios: List[Scenario] = []
|
|
503
|
+
self._sample_counter = 0
|
|
504
|
+
|
|
505
|
+
# Curriculum manager
|
|
506
|
+
self.curriculum = CurriculumManager(
|
|
507
|
+
checkpoint_path=config.curriculum_checkpoint_path if config.use_curriculum else None,
|
|
508
|
+
) if config.use_curriculum else None
|
|
509
|
+
|
|
510
|
+
async def initialize(self) -> None:
|
|
511
|
+
"""Initialize scenario pool"""
|
|
512
|
+
logger.info("Initializing scenario pool...")
|
|
513
|
+
|
|
514
|
+
# Load production scenarios if database available
|
|
515
|
+
if self.database_url:
|
|
516
|
+
production_count = int(self.config.max_scenarios * self.config.production_ratio)
|
|
517
|
+
await self.load_production_snapshots(limit=production_count)
|
|
518
|
+
|
|
519
|
+
# Fill remaining with synthetic
|
|
520
|
+
remaining = self.config.max_scenarios - len(self.scenarios)
|
|
521
|
+
if remaining > 0:
|
|
522
|
+
synthetic = self.generate_synthetic_batch(count=remaining)
|
|
523
|
+
self.scenarios.extend(synthetic)
|
|
524
|
+
|
|
525
|
+
logger.info(f"Scenario pool initialized with {len(self.scenarios)} scenarios")
|
|
526
|
+
|
|
527
|
+
async def load_production_snapshots(
    self,
    limit: int = 200,
    min_quality: float = 0.5,
) -> None:
    """
    Load high-quality scenarios from production games.

    Extracts market states from recent game windows.

    Args:
        limit: Maximum number of production scenarios (game windows) to add.
        min_quality: Intended quality floor for snapshots.
            NOTE(review): currently unused — no quality filter is applied
            anywhere in this method; confirm intent or wire it up.

    Best-effort: a missing database URL, a missing asyncpg package, or any
    connection/query error is logged as a warning and the method returns
    without raising, leaving the pool unchanged.
    """
    if not self.database_url:
        logger.warning("No database URL configured, skipping production snapshots")
        return

    # asyncpg is an optional dependency; degrade gracefully without it.
    try:
        import asyncpg
    except ImportError:
        logger.warning("asyncpg not installed, skipping production snapshots")
        return

    try:
        pool = await asyncpg.create_pool(
            self.database_url,
            min_size=1,
            max_size=5,
            command_timeout=30,
        )
    except Exception as e:
        logger.warning(f"Failed to connect to database: {e}")
        return

    try:
        async with pool.acquire() as conn:
            # Query recent game states with market data
            rows = await conn.fetch("""
                SELECT
                    w.id as window_id,
                    w."startTime" as start_time,
                    w."endTime" as end_time,
                    m.id as market_id,
                    m.question,
                    m."yesPrice",
                    m."noPrice",
                    m."totalVolume" as volume,
                    m.category
                FROM "GameWindow" w
                JOIN "Question" m ON m."gameWindowId" = w.id
                WHERE w."createdAt" > NOW() - INTERVAL '7 days'
                AND m.status = 'active'
                ORDER BY w."createdAt" DESC
                LIMIT $1
            """, limit * 5)  # Get more rows to group into scenarios

            # Group by window
            windows: Dict[str, List[Dict]] = {}
            for row in rows:
                window_id = str(row["window_id"])
                if window_id not in windows:
                    windows[window_id] = []
                windows[window_id].append(dict(row))

            # Create scenarios from windows
            for window_id, market_rows in list(windows.items())[:limit]:
                markets = []
                for row in market_rows[:10]:  # Max 10 markets per scenario
                    markets.append(MarketState(
                        market_id=str(row.get("market_id", uuid4())),
                        question=row.get("question", "Unknown question"),
                        yes_price=float(row.get("yesPrice", 0.5)),
                        # no_price is derived as the complement of yes_price;
                        # the stored "noPrice" column is intentionally ignored.
                        no_price=1.0 - float(row.get("yesPrice", 0.5)),
                        volume_24h=float(row.get("volume", 0)),
                        # Liquidity is not stored; approximated as 10x volume.
                        liquidity=float(row.get("volume", 0)) * 10,
                        # Expiry fixed at 24h from now (ms epoch), not the
                        # window's real end time.
                        expires_at=int(datetime.now(timezone.utc).timestamp() * 1000) + 86400000,
                        category=row.get("category", "general"),
                    ))

                if markets:
                    scenario = Scenario(
                        id=f"prod-{window_id}",
                        source="production",
                        markets=markets,
                        # Perps, news, and posts are synthesized around the
                        # real markets — only the markets are production data.
                        perpetuals=self._generate_default_perpetuals(),
                        news=self._generate_contextual_news(markets),
                        social_posts=self._generate_contextual_posts(markets),
                        difficulty="medium",
                    )
                    self.scenarios.append(scenario)

            # NOTE(review): reports the window count, which can exceed the
            # number of scenarios actually appended (limit slice / empty
            # market lists).
            logger.info(f"Loaded {len(windows)} production scenarios")

    except Exception as e:
        logger.warning(f"Error loading production snapshots: {e}")
    finally:
        await pool.close()
|
|
621
|
+
|
|
622
|
+
def generate_synthetic_batch(
    self,
    count: int,
    archetype_focus: Optional[str] = None,
) -> List[Scenario]:
    """Generate a batch of `count` synthetic scenarios.

    Difficulty labels follow config.synthetic_difficulty_distribution
    (integer-truncated quotas, padded with "medium"), shuffled so the
    batch is not grouped by difficulty.
    """
    # Build the difficulty quota list from the configured distribution.
    labels = []
    for level, share in self.config.synthetic_difficulty_distribution.items():
        labels.extend([level] * int(count * share))

    # int() truncation can leave the list short of `count`; pad it.
    labels.extend(["medium"] * max(0, count - len(labels)))

    random.shuffle(labels)

    return [
        self._generate_synthetic_scenario(
            difficulty=level,
            archetype_focus=archetype_focus,
        )
        for level in labels[:count]
    ]
|
|
649
|
+
|
|
650
|
+
def _generate_synthetic_scenario(
    self,
    difficulty: Literal["easy", "medium", "hard"] = "medium",
    archetype_focus: Optional[str] = None,
) -> Scenario:
    """Build one synthetic scenario.

    Harder settings mean more markets, more (noisier) news/social
    signals, and a smaller starting balance.
    """
    # Per-difficulty knobs: (markets, news items, social posts, balance).
    knobs = {
        "easy": (3, 2, 3, 15000),
        "medium": (5, 5, 6, 10000),
        "hard": (8, 8, 10, 5000),
    }
    n_markets, n_news, n_posts, start_balance = knobs[difficulty]

    return Scenario(
        id=f"synth-{uuid4().hex[:8]}",
        source="synthetic",
        markets=[self._generate_random_market(idx) for idx in range(n_markets)],
        perpetuals=self._generate_default_perpetuals(),
        news=self._generate_random_news(n_news, difficulty),
        social_posts=self._generate_random_posts(n_posts),
        portfolio=PortfolioState(balance=float(start_balance)),
        archetype_focus=archetype_focus,
        difficulty=difficulty,
    )
|
|
686
|
+
|
|
687
|
+
def _generate_random_market(self, index: int) -> MarketState:
    """Build one random binary prediction market from a question template."""
    templates = [
        ("Will BTC exceed ${price}K by end of {period}?", "crypto"),
        ("Will ETH outperform BTC this {period}?", "crypto"),
        ("Will the Fed announce rate {action}?", "macro"),
        ("Will {company} stock reach new ATH?", "stocks"),
        ("Will total crypto market cap exceed ${cap}T?", "crypto"),
        ("Will {coin} flip {coin2} in market cap?", "crypto"),
        ("Will inflation be above {rate}% next month?", "macro"),
    ]

    chosen, category = random.choice(templates)
    # Every placeholder value is drawn up front; str.format ignores the
    # keys the chosen template does not use.
    fills = dict(
        price=random.choice([100, 120, 150, 200]),
        period=random.choice(["week", "month", "quarter"]),
        action=random.choice(["cuts", "hikes", "pause"]),
        company=random.choice(["NVIDIA", "Apple", "Microsoft", "Tesla"]),
        cap=random.choice([3, 4, 5]),
        coin=random.choice(["SOL", "DOGE", "AVAX"]),
        coin2=random.choice(["ETH", "BNB"]),
        rate=random.choice([2, 3, 4]),
    )
    question = chosen.format(**fills)

    p_yes = random.uniform(0.2, 0.8)

    return MarketState(
        market_id=f"market-{index + 1}",
        question=question,
        yes_price=round(p_yes, 2),
        no_price=round(1 - p_yes, 2),
        volume_24h=float(random.randint(10000, 500000)),
        liquidity=float(random.randint(50000, 1000000)),
        # Expiry between 1 and 7 days out (ms epoch).
        expires_at=int(datetime.now(timezone.utc).timestamp() * 1000) + random.randint(86400000, 604800000),
        category=category,
    )
|
|
723
|
+
|
|
724
|
+
def _generate_default_perpetuals(self) -> List[PerpetualState]:
    """Build the default perpetual-futures board with jittered prices."""
    reference = {"BTC": 100000, "ETH": 3500, "SOL": 180, "DOGE": 0.35, "AVAX": 40}

    board = []
    for symbol in ["BTC", "ETH", "SOL", "DOGE", "AVAX"]:
        anchor = reference.get(symbol, 100)
        # Jitter the mark price +/-5% around the reference level.
        mark = anchor * (1 + random.uniform(-0.05, 0.05))

        board.append(PerpetualState(
            ticker=symbol,
            mark_price=round(mark, 2),
            # Index tracks mark within +/-0.1%.
            index_price=round(mark * (1 + random.uniform(-0.001, 0.001)), 2),
            funding_rate=round(random.uniform(-0.001, 0.001), 6),
            open_interest=float(random.randint(1000000, 50000000)),
            volume_24h=float(random.randint(5000000, 100000000)),
            change_24h=round(random.uniform(-0.1, 0.1), 4),
            high_24h=round(mark * 1.05, 2),
            low_24h=round(mark * 0.95, 2),
        ))

    return board
|
|
747
|
+
|
|
748
|
+
def _generate_random_news(
    self,
    count: int,
    difficulty: str,
) -> List[NewsItem]:
    """Build up to `count` random news items.

    Headlines are sampled without replacement, so the result is capped at
    the number of templates (10) even for larger `count`. On "hard"
    difficulty roughly half the items get their sentiment re-rolled,
    producing conflicting signals.
    """
    templates = [
        ("Bitcoin Approaches Key Resistance Level at ${price}K", "bullish", "high"),
        ("Federal Reserve Hints at {action} Shift in Policy", "neutral", "high"),
        ("Major Exchange Reports Record {metric} Volume", "bullish", "medium"),
        ("Regulatory Clarity Expected Next {period}", "neutral", "medium"),
        ("Whale Alert: Large {direction} Transfer Detected", "bearish", "low"),
        ("New DeFi Protocol Launches with ${tvl}M TVL", "bullish", "low"),
        ("Mining Difficulty Reaches New {direction}", "neutral", "low"),
        ("Institutional Investors {action} Crypto Holdings", "bullish", "high"),
        ("Market Analysis: Technical Indicators Show {signal}", "neutral", "medium"),
        ("Breaking: {entity} Announces Crypto {action}", "bullish", "high"),
    ]

    outlets = ["CoinDesk", "Bloomberg Crypto", "Reuters", "CryptoNews", "The Block"]

    items = []
    for template, tone, impact in random.sample(templates, min(count, len(templates))):
        headline = template.format(
            price=random.choice([100, 120, 150]),
            action=random.choice(["Bullish", "Dovish", "Cautious"]),
            metric=random.choice(["Trading", "Spot", "Derivatives"]),
            period=random.choice(["Month", "Quarter"]),
            direction=random.choice(["Buy", "Sell", "High", "Low"]),
            tvl=random.randint(10, 100),
            signal=random.choice(["Bullish Breakout", "Consolidation", "Bearish Divergence"]),
            entity=random.choice(["BlackRock", "Fidelity", "Goldman Sachs"]),
        )

        # Hard scenarios randomize sentiment ~50% of the time so the
        # feed carries conflicting signals.
        if difficulty == "hard" and random.random() > 0.5:
            tone = random.choice(["bullish", "bearish", "neutral"])

        items.append(NewsItem(
            headline=headline,
            sentiment=tone,
            impact=impact,
            source=random.choice(outlets),
            # Backdated up to 1h (ms epoch) to look organic.
            timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000),
            relevance_score=random.uniform(0.5, 1.0),
        ))

    return items
|
|
797
|
+
|
|
798
|
+
def _generate_random_posts(self, count: int) -> List[SocialPost]:
    """Build `count` random social posts from canned trader chatter."""
    templates = [
        ("Just went long on {ticker}, looking {outlook} 🚀", "bullish"),
        ("Taking profits here, market looks overextended", "bearish"),
        ("Anyone else seeing this pattern on the {period} chart?", "neutral"),
        ("New ATH incoming, calling it now 💎🙌", "bullish"),
        ("Be careful, volume is declining significantly", "bearish"),
        ("Great entry opportunity if you missed the dip", "bullish"),
        ("Liquidation cascade might be coming, stay safe", "bearish"),
        ("{ticker} breaking out of the descending wedge!", "bullish"),
        ("Funding rates getting extreme, reversal soon?", "neutral"),
        ("This is the dip you've been waiting for", "bullish"),
    ]

    feed = []
    for _ in range(count):
        text_template, tone = random.choice(templates)
        body = text_template.format(
            ticker=random.choice(["BTC", "ETH", "SOL"]),
            outlook=random.choice(["bullish", "strong", "good"]),
            period=random.choice(["4H", "1D", "Weekly"]),
        )

        feed.append(SocialPost(
            author=f"trader_{random.randint(100, 999)}",
            content=body,
            sentiment=tone,
            likes=random.randint(0, 500),
            replies=random.randint(0, 50),
            # Backdated up to 30 minutes (ms epoch).
            timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 1800000),
            # ~30% of authors are marked verified.
            verified=random.random() > 0.7,
        ))

    return feed
|
|
833
|
+
|
|
834
|
+
def _generate_contextual_news(self, markets: List[MarketState]) -> List[NewsItem]:
    """Generate news relevant to the markets.

    Scans up to the first three market questions for keywords
    (btc/bitcoin, eth/ethereum, fed/rate) and emits one themed headline
    per match, then pads the feed with three generic medium-difficulty
    items.
    """
    news = []

    for market in markets[:3]:
        # Match on keywords extracted from the market question.
        question_lower = market.question.lower()

        if "btc" in question_lower or "bitcoin" in question_lower:
            news.append(NewsItem(
                # Plain literal (was an f-string with no placeholders).
                headline="Bitcoin Technical Analysis: Key Levels to Watch",
                sentiment=random.choice(["bullish", "neutral"]),
                impact="medium",
                source="CryptoNews",
                # Backdated up to 1h (ms epoch) to look organic.
                timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000),
            ))
        elif "eth" in question_lower or "ethereum" in question_lower:
            news.append(NewsItem(
                headline="Ethereum Network Activity Surges to New Highs",
                sentiment="bullish",
                impact="medium",
                source="The Block",
                timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000),
            ))
        elif "fed" in question_lower or "rate" in question_lower:
            news.append(NewsItem(
                headline="Fed Officials Signal Patience on Rate Decisions",
                sentiment="neutral",
                impact="high",
                source="Bloomberg",
                timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 3600000),
            ))

    # Add some generic news
    generic_news = self._generate_random_news(3, "medium")
    news.extend(generic_news)

    return news
|
|
872
|
+
|
|
873
|
+
def _generate_contextual_posts(self, markets: List[MarketState]) -> List[SocialPost]:
    """Generate social posts relevant to the markets.

    Emits one sentiment-matched analyst post per market for up to two
    markets (sentiment follows the market's implied probability), then
    pads with four generic posts.
    """
    posts = []

    for market in markets[:2]:
        # Derive the post's sentiment from the market's implied
        # probability. (Removed an unused `question_lower` local.)
        if market.yes_price > 0.6:
            sentiment = "bullish"
            content = f"Market is pricing in high probability - {market.question[:50]}..."
        elif market.yes_price < 0.4:
            sentiment = "bearish"
            content = f"Looks unlikely based on current odds - {market.question[:50]}..."
        else:
            sentiment = "neutral"
            content = f"This one could go either way - {market.question[:50]}..."

        posts.append(SocialPost(
            author=f"analyst_{random.randint(1, 100)}",
            content=content,
            sentiment=sentiment,
            likes=random.randint(10, 100),
            replies=random.randint(1, 20),
            # Backdated up to 30 minutes (ms epoch).
            timestamp=int(datetime.now(timezone.utc).timestamp() * 1000) - random.randint(0, 1800000),
            verified=True,
        ))

    # Add generic posts
    generic_posts = self._generate_random_posts(4)
    posts.extend(generic_posts)

    return posts
|
|
905
|
+
|
|
906
|
+
def sample(self, count: int = 1) -> List[Scenario]:
    """
    Sample scenarios respecting curriculum.

    Uses priority-weighted sampling when curriculum is enabled.

    Side effects: every call advances the internal sample counter; once
    it reaches config.refresh_interval, the synthetic portion of the
    pool is regenerated (production scenarios are kept).

    Note: sampling is WITHOUT replacement in both branches, so the
    result may contain fewer than `count` scenarios when the pool (or
    the non-skipped curriculum subset) is smaller than `count`.
    """
    self._sample_counter += count

    # Check if refresh needed
    if self._sample_counter >= self.config.refresh_interval:
        self._sample_counter = 0
        logger.info("Refresh interval reached, regenerating synthetic scenarios")
        # Keep production, regenerate synthetic
        production = [s for s in self.scenarios if s.source == "production"]
        synthetic_count = self.config.max_scenarios - len(production)
        synthetic = self.generate_synthetic_batch(synthetic_count)
        self.scenarios = production + synthetic

    if not self.scenarios:
        return []

    if self.curriculum:
        # Filter out scenarios that should be skipped
        available = [s for s in self.scenarios if not self.curriculum.should_skip(s.id)]

        if not available:
            # All solved, reset curriculum
            logger.info("All scenarios solved, resetting curriculum")
            self.curriculum.reset()
            available = self.scenarios

        # Calculate priorities
        priorities = [self.curriculum.get_priority(s.id) for s in available]

        # Normalize to probabilities
        total = sum(priorities)
        if total == 0:
            probs = [1.0 / len(available)] * len(available)
        else:
            probs = [p / total for p in priorities]

        # Weighted sample WITHOUT replacement, capped at len(available)
        # even when count is larger (size is clamped and replace=False).
        # NOTE(review): np.random.choice requires probs to sum to ~1;
        # negative priorities would make it raise — assumes priorities
        # are non-negative. Confirm with CurriculumManager.get_priority.
        indices = np.random.choice(
            len(available),
            size=min(count, len(available)),
            replace=False,
            p=probs,
        )

        return [available[i] for i in indices]
    else:
        # Simple random sampling
        return random.sample(self.scenarios, min(count, len(self.scenarios)))
|
|
959
|
+
|
|
960
|
+
def record_results(
    self,
    scenario_ids: List[str],
    scores: List[float],
) -> None:
    """Feed rollout scores back into the curriculum manager.

    No-op when curriculum sampling is disabled. Pairs are truncated to
    the shorter of the two input lists (standard zip semantics).
    """
    curriculum = self.curriculum
    if not curriculum:
        return

    for sid, result in zip(scenario_ids, scores):
        curriculum.record_attempt(sid, result)
|
|
971
|
+
|
|
972
|
+
def get_stats(self) -> Dict:
    """Return pool composition stats, plus curriculum stats when enabled."""
    # Count scenarios by origin in a single pass over the pool.
    by_source = {"production": 0, "synthetic": 0}
    for scenario in self.scenarios:
        if scenario.source in by_source:
            by_source[scenario.source] += 1

    stats = {
        "total_scenarios": len(self.scenarios),
        "production_scenarios": by_source["production"],
        "synthetic_scenarios": by_source["synthetic"],
        "samples_since_refresh": self._sample_counter,
    }

    if self.curriculum:
        stats["curriculum"] = self.curriculum.get_stats()

    return stats
|
|
985
|
+
|
|
986
|
+
def save_scenarios(self, path: str) -> None:
    """Serialize the current pool to a JSON file, creating parent dirs."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Each scenario provides its own dict representation.
    data = [scenario.to_dict() for scenario in self.scenarios]

    with open(target, "w") as fh:
        json.dump(data, fh, indent=2)

    logger.info(f"Saved {len(data)} scenarios to {path}")
|
|
997
|
+
|
|
998
|
+
def load_scenarios(self, path: str) -> None:
    """Load scenarios from JSON file.

    Replaces the current pool contents with the deserialized scenarios.
    Expects the camelCase key schema ("yesPrice", "expiresAt", ...) —
    presumably the output of Scenario.to_dict / save_scenarios; verify
    the two stay in sync when either side changes.

    Raises: OSError, json.JSONDecodeError, or KeyError on a missing
    file, malformed JSON, or missing required keys (nothing is caught
    here).
    """
    with open(path) as f:
        data = json.load(f)

    # Clear existing
    self.scenarios.clear()

    for item in data:
        # Reconstruct scenario from dict
        # Note the key mapping: JSON "id" -> market_id, camelCase ->
        # snake_case fields.
        markets = [MarketState(
            market_id=m["id"],
            question=m["question"],
            yes_price=m["yesPrice"],
            no_price=m["noPrice"],
            volume_24h=m["volume24h"],
            liquidity=m["liquidity"],
            expires_at=m["expiresAt"],
            category=m.get("category", "general"),
        ) for m in item.get("markets", [])]

        perpetuals = [PerpetualState(
            ticker=p["ticker"],
            mark_price=p["markPrice"],
            index_price=p["indexPrice"],
            funding_rate=p["fundingRate"],
            open_interest=p["openInterest"],
            volume_24h=p["volume24h"],
            change_24h=p["change24h"],
            high_24h=p["high24h"],
            low_24h=p["low24h"],
        ) for p in item.get("perpetuals", [])]

        news = [NewsItem(
            headline=n["headline"],
            sentiment=n["sentiment"],
            impact=n["impact"],
            source=n["source"],
            timestamp=n["timestamp"],
        ) for n in item.get("news", [])]

        posts = [SocialPost(
            author=p["author"],
            content=p["content"],
            sentiment=p["sentiment"],
            likes=p["likes"],
            replies=p["replies"],
            timestamp=p["timestamp"],
            verified=p.get("verified", False),
        ) for p in item.get("socialPosts", [])]

        # Portfolio is optional; defaults to a fresh $10k balance.
        portfolio_data = item.get("portfolio", {})
        portfolio = PortfolioState(
            balance=portfolio_data.get("balance", 10000.0),
            positions=portfolio_data.get("positions", []),
            total_pnl=portfolio_data.get("totalPnL", 0.0),
        )

        scenario = Scenario(
            id=item["id"],
            source=item["source"],
            markets=markets,
            perpetuals=perpetuals,
            news=news,
            social_posts=posts,
            portfolio=portfolio,
            archetype_focus=item.get("archetypeFocus"),
            difficulty=item.get("difficulty", "medium"),
            # Missing timestamps default to "now" (ms epoch).
            timestamp=item.get("timestamp", int(datetime.now(timezone.utc).timestamp() * 1000)),
            ground_truth=item.get("groundTruth"),
        )
        self.scenarios.append(scenario)

    logger.info(f"Loaded {len(self.scenarios)} scenarios from {path}")
|
|
1072
|
+
|