@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,522 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Babylon Hybrid Environment for GRPO Training
|
|
3
|
+
|
|
4
|
+
Combines offline (database) and online (simulation bridge) rollouts.
|
|
5
|
+
This provides the best of both worlds:
|
|
6
|
+
- Offline: Large, diverse dataset from historical trajectories
|
|
7
|
+
- Online: Fresh rollouts from current policy interacting with simulation
|
|
8
|
+
|
|
9
|
+
Usage:
|
|
10
|
+
make train-hybrid # 80% offline, 20% online by default
|
|
11
|
+
|
|
12
|
+
# Or with custom ratio
|
|
13
|
+
python scripts/run_training.py --mode hybrid --hybrid-online-ratio 0.3
|
|
14
|
+
|
|
15
|
+
The online ratio determines what fraction of rollouts come from the
|
|
16
|
+
simulation bridge vs the database.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import copy
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import random
|
|
24
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
25
|
+
|
|
26
|
+
from pydantic import Field
|
|
27
|
+
|
|
28
|
+
from atroposlib.envs.base import APIServerConfig, BaseEnv, ScoredDataGroup
|
|
29
|
+
|
|
30
|
+
from .babylon_env import BabylonEnvConfig, BabylonRLAIFEnv
|
|
31
|
+
from .online_env import BabylonOnlineEnv, BabylonOnlineEnvConfig, Scenario
|
|
32
|
+
from .simulation_bridge import SimulationBridge
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BabylonHybridEnvConfig(BabylonOnlineEnvConfig):
    """
    Configuration for hybrid environment.

    Inherits from BabylonOnlineEnvConfig and adds offline ratio control plus
    the database settings used to serve offline trajectories.
    """

    # Probability that a single item is drawn from the online simulation bridge
    # instead of the offline trajectory cache. Constrained to [0, 1] so an
    # out-of-range value (e.g. a typo in HYBRID_ONLINE_RATIO) is rejected when
    # the config is built rather than silently meaning "always"/"never" online.
    online_ratio: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Ratio of rollouts from online simulation (0.0 = all offline, 1.0 = all online)",
    )

    # Database settings for offline mode (same as BabylonEnvConfig)
    db_url: Optional[str] = Field(
        default=None,
        description="PostgreSQL connection URL for offline trajectories",
    )
    # Used as the LIMIT of the trajectory query (most recent rows first).
    trajectory_window_size: int = Field(
        default=1000,
        description="Number of trajectories to cache in memory",
    )
    # Below this count only a warning is logged; offline mode is still used.
    min_trajectories: int = Field(
        default=10,
        description="Minimum trajectories required to start offline training",
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class BabylonHybridEnv(BaseEnv):
    """
    Hybrid environment that mixes offline and online rollouts.

    Architecture:
    - Maintains both an offline trajectory cache and online bridge connection
    - For each get_next_item() call, randomly selects offline vs online,
      falling back to whichever source is actually available
    - Collects trajectories using the appropriate mode
    - Scores and returns consistent ScoredDataGroup format

    Items produced by get_next_item() are ``(data, archetype)`` tuples. Online
    items are scenario objects tagged with ``metadata["mode"] == "online"``;
    offline items are dict copies of cached database rows tagged with
    ``"source": "offline"``. collect_trajectories() dispatches on that tag.

    Benefits:
    - Stability from large offline dataset
    - Adaptability from on-policy online rollouts
    - Smooth transition from offline to online training
    """

    name = "babylon_hybrid_env"

    # Reward weights shared by both collection paths so online and offline
    # groups are scored on the same scale.
    _FORMAT_WEIGHT = 0.4
    _REASONING_WEIGHT = 0.3
    _VALID_ACTION_BONUS = 0.3
    # Half-width of uniform noise added to offline scores before centering so a
    # group of identical scores does not collapse to all-zero advantages.
    _OFFLINE_SCORE_JITTER = 0.01
    # Mean-centering a group needs at least two completions to carry signal.
    _MIN_GROUP_SIZE = 2

    def __init__(
        self,
        config: BabylonHybridEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        """
        Initialize state for both rollout sources.

        Connections (database pool, simulation bridge) and the tokenizer are
        created later in setup(); here they are only placeholders.
        """
        super().__init__(config, server_configs, slurm, testing)
        self.config: BabylonHybridEnvConfig = config
        # Kept so the offline path can reach the vLLM base_url directly.
        self._server_configs = server_configs

        # Offline components (from BabylonRLAIFEnv)
        self.db_pool = None
        self.trajectory_cache: List[Dict] = []
        self.current_cache_idx: int = 0

        # Online components (from BabylonOnlineEnv)
        self.simulation_bridge: Optional[SimulationBridge] = None
        self.scenario_pool = None
        self._bridge_npc_index: int = 0

        # Hybrid control
        self.online_ratio = config.online_ratio
        self.iter = 0
        self.online_count = 0
        self.offline_count = 0

        # Tokenizer (set in setup)
        self.tokenizer = None

        logger.info(f"HybridEnv initialized with online_ratio={self.online_ratio:.0%}")

    @classmethod
    def config_init(cls) -> Tuple[BabylonHybridEnvConfig, List[APIServerConfig]]:
        """
        Create default config.

        HYBRID_ONLINE_RATIO, SIMULATION_BRIDGE_URL and DATABASE_URL environment
        variables override the corresponding defaults.
        """
        env_config = BabylonHybridEnvConfig(
            tokenizer_name="Qwen/Qwen2.5-3B-Instruct",
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=16,
            online_ratio=float(os.getenv("HYBRID_ONLINE_RATIO", "0.2")),
            use_simulation_bridge=True,
            simulation_bridge_url=os.getenv("SIMULATION_BRIDGE_URL", "http://localhost:3001"),
            db_url=os.getenv("DATABASE_URL"),
        )

        server_configs = [
            APIServerConfig(
                model_name="Qwen/Qwen2.5-3B-Instruct",
                base_url="http://localhost:9001/v1",
            )
        ]

        return env_config, server_configs

    async def setup(self):
        """
        Initialize both offline and online components.

        Without a db_url the ratio is forced to 1.0 (online only); with the
        simulation bridge disabled it is forced to 0.0 (offline only).
        """
        from transformers import AutoTokenizer

        logger.info("Setting up hybrid environment...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)

        # Setup offline component (database)
        if self.config.db_url:
            await self._setup_offline()
        else:
            logger.warning("No DATABASE_URL set, hybrid will only use online rollouts")
            self.online_ratio = 1.0

        # Setup online component (simulation bridge)
        if self.config.use_simulation_bridge:
            await self._setup_online()
        else:
            logger.warning("Simulation bridge disabled, hybrid will only use offline rollouts")
            self.online_ratio = 0.0

        logger.info(f"Hybrid setup complete: online_ratio={self.online_ratio:.0%}, "
                    f"offline_trajectories={len(self.trajectory_cache)}, "
                    f"bridge_npcs={len(self.simulation_bridge.npc_ids) if self.simulation_bridge else 0}")

    async def _setup_offline(self):
        """Create the asyncpg pool and load the initial trajectory window."""
        import asyncpg

        logger.info("Connecting to database for offline trajectories...")

        self.db_pool = await asyncpg.create_pool(
            self.config.db_url,
            min_size=2,
            max_size=10,
        )

        # Load initial trajectory window
        await self._load_trajectory_window()

        if len(self.trajectory_cache) < self.config.min_trajectories:
            # Only a warning: offline sampling still uses whatever was loaded.
            logger.warning(f"Only {len(self.trajectory_cache)} trajectories in DB, "
                           f"need {self.config.min_trajectories}")

    async def _load_trajectory_window(self):
        """
        Replace the cache with the newest trajectories from the database.

        Loads up to ``trajectory_window_size`` rows that have a non-null
        ``model_response``, newest first, and resets the round-robin cursor.
        No-op when no pool has been created.
        """
        if not self.db_pool:
            return

        async with self.db_pool.acquire() as conn:
            # Load trajectories with reasoning
            rows = await conn.fetch(
                """
                SELECT
                    id, archetype, scenario_context, model_response,
                    reasoning, metrics, created_at
                FROM trajectories
                WHERE model_response IS NOT NULL
                ORDER BY created_at DESC
                LIMIT $1
                """,
                self.config.trajectory_window_size,
            )

            self.trajectory_cache = [dict(row) for row in rows]
            self.current_cache_idx = 0

            logger.info(f"Loaded {len(self.trajectory_cache)} trajectories from database")

    async def _setup_online(self):
        """
        Open the simulation bridge and spawn NPCs for every configured archetype.

        If initialization fails the bridge session is closed before the error
        propagates, so no connection is leaked and ``simulation_bridge`` stays
        unset.
        """
        logger.info(f"Connecting to simulation bridge at {self.config.simulation_bridge_url}...")

        bridge = SimulationBridge(
            base_url=self.config.simulation_bridge_url,
        )
        await bridge.__aenter__()
        try:
            # Initialize with archetypes
            archetypes = list(self.config.archetype_distribution.keys())
            await bridge.initialize(
                num_npcs=self.config.bridge_num_npcs,
                archetypes=archetypes,
            )
        except BaseException:
            # Release the session opened by __aenter__ before re-raising.
            await bridge.__aexit__(None, None, None)
            raise
        self.simulation_bridge = bridge

        logger.info(f"Simulation bridge connected with {len(self.simulation_bridge.npc_ids)} NPCs")

    def _bridge_ready(self) -> bool:
        """Return True when the simulation bridge is initialized and has NPCs to query."""
        bridge = self.simulation_bridge
        return bool(bridge and bridge.is_initialized and bridge.npc_ids)

    async def get_next_item(self) -> Tuple[Any, str]:
        """
        Get next item for training.

        Randomly decides between offline and online based on online_ratio and
        falls back to the other source when the chosen one is unavailable.

        Raises:
            RuntimeError: when neither the simulation bridge nor the offline
                trajectory cache can supply an item.
        """
        self.iter += 1

        online_available = self._bridge_ready()
        offline_available = bool(self.trajectory_cache)
        if not (online_available or offline_available):
            raise RuntimeError(
                "HybridEnv has no rollout source: simulation bridge is unavailable "
                "and the offline trajectory cache is empty"
            )

        # Decide online vs offline based on ratio
        use_online = random.random() < self.online_ratio

        if use_online and not online_available:
            # Online selected but not available, fall back to offline
            use_online = False
        elif not use_online and not offline_available:
            # Offline selected but no trajectories, use online
            use_online = True

        if use_online:
            self.online_count += 1
            return await self._get_online_item()
        self.offline_count += 1
        return self._get_offline_item()

    async def _get_online_item(self) -> Tuple["PoolScenario", str]:
        """
        Fetch a fresh scenario from the simulation bridge.

        NPCs are visited round-robin. The bridge scenario is converted into the
        scenario-pool format used for prompting and is tagged with
        ``metadata["mode"] = "online"`` so collect_trajectories() routes it to
        the online path.

        Raises:
            RuntimeError: if the bridge reports no NPCs.
        """
        from .scenario_pool import Scenario as PoolScenario, PortfolioState

        npc_ids = self.simulation_bridge.npc_ids
        if not npc_ids:
            raise RuntimeError("Simulation bridge has no NPCs to request scenarios for")
        npc_id = npc_ids[self._bridge_npc_index % len(npc_ids)]
        self._bridge_npc_index += 1

        bridge_scenario = await self.simulation_bridge.get_scenario(npc_id)
        archetype = bridge_scenario.archetype

        # Convert to Scenario format used by scoring
        scenario = PoolScenario(
            id=f"bridge-{npc_id}-{self.iter}",
            source="production",
            archetype_focus=archetype,
            difficulty="medium",
            portfolio=PortfolioState(
                balance=bridge_scenario.balance,
                positions=[],
            ),
        )

        # Add market data from bridge
        for m in bridge_scenario.market_state.prediction_markets:
            scenario.add_market({
                "id": m.id,
                "question": m.question,
                "yesPrice": m.yes_price,
                "noPrice": m.no_price,
            })

        for m in bridge_scenario.market_state.perp_markets:
            scenario.add_perpetual({
                "ticker": m.ticker,
                "markPrice": m.current_price,
                "change24h": m.change_percent_24h,
            })

        # Store bridge scenario for action execution
        scenario.metadata["bridge_scenario"] = bridge_scenario
        scenario.metadata["npc_id"] = npc_id
        scenario.metadata["mode"] = "online"

        return (scenario, archetype)

    def _get_offline_item(self) -> Tuple[Dict, str]:
        """
        Return a deep copy of the next cached trajectory, round-robin.

        The copy is tagged ``"source": "offline"`` so the cache itself is never
        mutated. A missing or NULL archetype falls back to ``"trader"``.

        Raises:
            RuntimeError: if the cache is empty.
        """
        if not self.trajectory_cache:
            raise RuntimeError("No trajectories in cache")

        # Round-robin through cache
        traj = self.trajectory_cache[self.current_cache_idx]
        self.current_cache_idx = (self.current_cache_idx + 1) % len(self.trajectory_cache)

        # `or` also covers rows whose archetype column is NULL.
        archetype = traj.get("archetype") or "trader"

        # Add source metadata
        traj_copy = copy.deepcopy(traj)
        traj_copy["source"] = "offline"

        return (traj_copy, archetype)

    async def collect_trajectories(self, item: Tuple[Any, str]) -> Tuple[Optional[ScoredDataGroup], List]:
        """
        Collect and score trajectories.

        Items tagged ``metadata["mode"] == "online"`` (as produced by
        _get_online_item) go to the online handler; everything else (the dict
        items from _get_offline_item) goes to the offline handler.
        """
        data, archetype = item

        if hasattr(data, "metadata") and data.metadata.get("mode") == "online":
            return await self._collect_online(data, archetype)
        return await self._collect_offline(data, archetype)

    async def _collect_online(self, scenario: "Scenario", archetype: str) -> Tuple[Optional[ScoredDataGroup], List]:
        """
        Collect online rollouts via the simulation bridge scenario.

        Generates ``group_size`` completions through the managed server, which
        also records the token/mask nodes used for training. Each completion is
        scored and the group is mean-centered.

        Returns:
            ``(ScoredDataGroup, [])``, or ``(None, [])`` when fewer than two
            usable completions were produced.
        """
        from .online_env import build_trading_system_prompt, build_observation_prompt

        # Build messages
        messages = [
            {"role": "system", "content": build_trading_system_prompt(archetype)},
            {"role": "user", "content": build_observation_prompt(scenario)},
        ]

        # Generate completions using managed_server
        async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
            chat_completions = await managed.chat_completion(
                messages=messages,
                n=self.config.group_size,
                max_tokens=self.config.max_response_tokens,
                temperature=self.config.temperature,
            )

            state = managed.get_state()
            nodes = state["nodes"]

        if not nodes or len(nodes) < self._MIN_GROUP_SIZE:
            logger.warning("Insufficient nodes from managed_server")
            return None, []

        # Completions and nodes are paired by position; zip stops at the shorter.
        rollout_data = []
        for choice, node in zip(chat_completions.choices, nodes):
            response_content = choice.message.content or ""
            rollout_data.append({
                "tokens": node.tokens,
                "masks": node.masked_tokens,
                "score": self._score_completion(response_content, archetype),
            })

        return self._build_scored_group(rollout_data, jitter=0.0)

    async def _collect_offline(self, traj: Dict, archetype: str) -> Tuple[Optional[ScoredDataGroup], List]:
        """
        Collect fresh completions for a prompt reconstructed from a DB trajectory.

        The stored ``model_response`` is only used to skip empty rows; the
        policy regenerates ``group_size`` completions for the same prompt via
        the vLLM OpenAI-compatible ``/chat/completions`` endpoint, using the
        same sampling settings as the online path.

        Returns:
            ``(ScoredDataGroup, [])``, or ``(None, [])`` when the row has no
            response, the request fails, or fewer than two completions return.
        """
        import aiohttp

        from .tokenization_utils import tokenize_for_trainer

        if not traj.get("model_response"):
            return None, []

        # NULL scenario_context would otherwise become the literal prompt "None".
        scenario_context = traj.get("scenario_context") or {}
        prompt_messages = [
            {"role": "system", "content": f"You are a {archetype} trading agent."},
            {"role": "user", "content": str(scenario_context)},
        ]

        # Get vLLM URL for generation
        configured_url = self._server_configs[0].base_url if self._server_configs else None
        vllm_base_url = (configured_url or "http://localhost:9001/v1").rstrip("/")
        payload = {
            "model": self.config.tokenizer_name,
            "messages": prompt_messages,
            "max_tokens": self.config.max_response_tokens,
            "n": self.config.group_size,
            "temperature": self.config.temperature,
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(f"{vllm_base_url}/chat/completions", json=payload) as resp:
                    if resp.status != 200:
                        logger.warning(f"vLLM request failed: {resp.status}")
                        return None, []
                    result = await resp.json()
        except aiohttp.ClientError as exc:
            # Treat transport/decoding errors like a non-200: skip this item.
            logger.warning(f"vLLM request error: {exc}")
            return None, []

        choices = result.get("choices") or []
        if len(choices) < self._MIN_GROUP_SIZE:
            return None, []

        rollout_data = []
        for choice in choices:
            # JSON content can be null; normalize to "" like the online path.
            response_content = (choice.get("message") or {}).get("content") or ""

            full_messages = copy.deepcopy(prompt_messages)
            full_messages.append({"role": "assistant", "content": response_content})

            # Tokenize with proper masking
            token_result = tokenize_for_trainer(
                self.tokenizer,
                full_messages,
                train_on_all_assistant_turns=True,
            )

            rollout_data.append({
                "tokens": token_result["input_ids"],
                "masks": token_result["masks"],
                "score": self._score_completion(response_content, archetype),
            })

        return self._build_scored_group(rollout_data, jitter=self._OFFLINE_SCORE_JITTER)

    def _score_completion(self, response_content: str, archetype: str) -> float:
        """
        Score one completion.

        Score = 0.4 * combined_format_score + 0.3 * reasoning_score, plus a 0.3
        bonus when the response passes format validation. Actions are not
        executed (``execute_action=False``).
        """
        from .format_validator import validate_response_format
        from .quality_scorer import score_response

        quality = score_response(
            response=response_content,
            archetype=archetype,
            execute_action=False,
        )
        format_result = validate_response_format(response_content)

        base_score = (
            quality.combined_format_score * self._FORMAT_WEIGHT
            + quality.reasoning_score * self._REASONING_WEIGHT
        )
        action_bonus = self._VALID_ACTION_BONUS if format_result.is_valid else 0.0
        return base_score + action_bonus

    def _build_scored_group(self, rollout_data: List[Dict], jitter: float = 0.0) -> Tuple[Optional[ScoredDataGroup], List]:
        """
        Mean-center rollout scores and pack them into a ScoredDataGroup.

        Args:
            rollout_data: dicts with ``tokens``, ``masks`` and ``score`` keys.
            jitter: half-width of uniform noise added to each score before
                centering; 0.0 disables it (and draws no random numbers).

        Returns:
            ``(ScoredDataGroup, [])``, or ``(None, [])`` when there are too few
            rollouts for centering to carry any signal.
        """
        if len(rollout_data) < self._MIN_GROUP_SIZE:
            logger.warning(f"Too few rollouts to build a scored group: {len(rollout_data)}")
            return None, []

        if jitter:
            scores = [r["score"] + random.uniform(-jitter, jitter) for r in rollout_data]
        else:
            scores = [r["score"] for r in rollout_data]
        mean_score = sum(scores) / len(scores)

        scored_group = ScoredDataGroup(
            tokens=[r["tokens"] for r in rollout_data],
            masks=[r["masks"] for r in rollout_data],
            scores=[s - mean_score for s in scores],
        )
        return scored_group, []

    async def cleanup(self):
        """
        Clean up resources.

        Resets and closes the simulation bridge, then closes the database pool.
        The bridge is closed even if reset() raises, and the pool is closed and
        stats are logged even if tearing down the bridge fails.
        """
        try:
            bridge = self.simulation_bridge
            if bridge:
                logger.info("Cleaning up simulation bridge...")
                self.simulation_bridge = None
                try:
                    await bridge.reset()
                finally:
                    await bridge.__aexit__(None, None, None)
        finally:
            if self.db_pool:
                logger.info("Closing database pool...")
                pool, self.db_pool = self.db_pool, None
                await pool.close()

            logger.info(f"Hybrid stats: online={self.online_count}, offline={self.offline_count}")

    async def evaluate(self):
        """Periodic evaluation logging of the realized online/offline split."""
        total = self.online_count + self.offline_count
        if total > 0:
            actual_online_ratio = self.online_count / total
            logger.info(f"Hybrid stats: total={total}, online={self.online_count} ({actual_online_ratio:.1%}), "
                        f"offline={self.offline_count} ({1-actual_online_ratio:.1%})")
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# Script entry point: hand control to the environment's CLI runner
# (only when this module is executed directly, not when imported).
if __name__ == "__main__":
    BabylonHybridEnv.cli()
|
|
522
|
+
|