@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
"""
|
|
2
|
+
End-to-end tests for online training mode (Phase 3).
|
|
3
|
+
|
|
4
|
+
These tests verify the complete online training pipeline:
|
|
5
|
+
1. Simulation bridge client connectivity
|
|
6
|
+
2. Scenario retrieval from bridge
|
|
7
|
+
3. Online environment rollout collection
|
|
8
|
+
4. Full training loop with online rollouts
|
|
9
|
+
|
|
10
|
+
Requirements:
|
|
11
|
+
- Simulation bridge server running (make bridge-server)
|
|
12
|
+
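- Only TestLiveBridgeIntegration needs a live server; it is skipped unless SIMULATION_BRIDGE_URL is set, and the remaining tests run without network access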
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
import pytest
|
|
19
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
20
|
+
|
|
21
|
+
import sys
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
# Add src to path
|
|
25
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
26
|
+
|
|
27
|
+
from src.training.simulation_bridge import (
|
|
28
|
+
SimulationBridge,
|
|
29
|
+
Scenario,
|
|
30
|
+
MarketState,
|
|
31
|
+
PerpMarket,
|
|
32
|
+
PredictionMarket,
|
|
33
|
+
Position,
|
|
34
|
+
NewsItem,
|
|
35
|
+
SocialContext,
|
|
36
|
+
ActionOutcome,
|
|
37
|
+
)
|
|
38
|
+
from src.training.scenario_pool import (
|
|
39
|
+
Scenario as PoolScenario,
|
|
40
|
+
MarketState as PoolMarketState,
|
|
41
|
+
PortfolioState,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TestSimulationBridgeClient:
    """Tests for the Python simulation bridge client that need no network access."""

    @pytest.fixture
    def mock_response_data(self):
        """Standard mock response data from bridge (camelCase keys, as in the wire payload)."""
        return {
            "npcId": "test-npc-1",
            "archetype": "trader",
            "marketState": {
                "perpMarkets": [
                    {
                        "ticker": "BTC",
                        "currentPrice": 45000.0,
                        "changePercent24h": 2.5,
                        "volume24h": 1000000.0,
                    }
                ],
                "predictionMarkets": [
                    {
                        "id": "market-1",
                        "title": "Will BTC hit $50K?",
                        "yesPrice": 0.65,
                        "noPrice": 0.35,
                    }
                ],
            },
            "positions": [
                {
                    "id": "pos-1",
                    "marketType": "perp",
                    "ticker": "BTC",
                    "side": "long",
                    "size": 0.5,
                    "unrealizedPnL": 250.0,
                }
            ],
            "balance": 10000.0,
            "recentNews": [
                {
                    "content": "Market update: BTC rising",
                    "source": "CryptoNews",
                    "timestamp": "2025-01-01T00:00:00Z",
                }
            ],
            "socialContext": {
                "relationships": [
                    {"actorId": "actor-1", "actorName": "Whale", "sentiment": 0.8}
                ],
                "groupChats": ["traders-lounge"],
                "recentMessages": [{"from": "Whale", "content": "Bullish today!"}],
            },
        }

    def test_scenario_parsing(self, mock_response_data):
        """Test that bridge response is correctly parsed into Scenario.

        This maps the camelCase payload onto the snake_case dataclasses by hand
        (no network calls), so it checks that the dataclasses accept the mapped
        values. NOTE(review): it does not call the bridge's own parsing code,
        and the payload's ``socialContext`` is not mapped — an empty
        ``SocialContext()`` is used instead.
        """
        data = mock_response_data
        market_data = data.get("marketState", {})

        market_state = MarketState(
            perp_markets=[
                PerpMarket(
                    ticker=m["ticker"],
                    current_price=m["currentPrice"],
                    change_percent_24h=m["changePercent24h"],
                    volume_24h=m["volume24h"],
                )
                for m in market_data.get("perpMarkets", [])
            ],
            prediction_markets=[
                PredictionMarket(
                    id=m["id"],
                    # The bridge payload calls the question "title".
                    question=m["title"],
                    yes_price=m["yesPrice"],
                    no_price=m["noPrice"],
                )
                for m in market_data.get("predictionMarkets", [])
            ],
        )

        positions = [
            Position(
                id=p["id"],
                market_type=p["marketType"],
                ticker=p.get("ticker"),
                side=p["side"],
                size=p["size"],
                unrealized_pnl=p.get("unrealizedPnL", 0),
            )
            for p in data.get("positions", [])
        ]

        scenario = Scenario(
            npc_id=data["npcId"],
            archetype=data["archetype"],
            market_state=market_state,
            positions=positions,
            balance=data["balance"],
            recent_news=[
                NewsItem(
                    content=n["content"],
                    source=n["source"],
                    timestamp=n["timestamp"],
                )
                for n in data.get("recentNews", [])
            ],
            social_context=SocialContext(),
        )

        assert scenario.npc_id == "test-npc-1"
        assert scenario.archetype == "trader"
        assert scenario.balance == 10000.0
        assert len(scenario.market_state.perp_markets) == 1
        assert scenario.market_state.perp_markets[0].ticker == "BTC"
        # Previously constructed but never checked.
        assert len(scenario.market_state.prediction_markets) == 1
        assert scenario.market_state.prediction_markets[0].question == "Will BTC hit $50K?"
        assert len(scenario.positions) == 1
        assert scenario.positions[0].unrealized_pnl == 250.0
        assert len(scenario.recent_news) == 1
        assert scenario.recent_news[0].source == "CryptoNews"

    def test_scenario_to_prompt_context(self):
        """Test that scenario can be converted to prompt context"""
        market_state = MarketState(
            perp_markets=[
                PerpMarket(
                    ticker="BTC",
                    current_price=45000.0,
                    change_percent_24h=2.5,
                    volume_24h=1000000.0,
                )
            ],
            prediction_markets=[
                PredictionMarket(
                    id="market-1",
                    question="Will BTC hit $50K?",
                    yes_price=0.65,
                    no_price=0.35,
                )
            ],
        )

        scenario = Scenario(
            npc_id="test-npc-1",
            archetype="trader",
            market_state=market_state,
            positions=[],
            balance=10000.0,
            recent_news=[],
            social_context=SocialContext(),
        )

        context = scenario.to_prompt_context()

        assert "Agent ID: test-npc-1" in context
        assert "Archetype: trader" in context
        assert "Balance: $10,000.00" in context
        assert "BTC" in context
        # Price format may vary (with or without comma)
        assert "45000" in context
        assert "+2.50%" in context

    @pytest.mark.asyncio
    async def test_bridge_client_initialization(self):
        """Test that bridge client initializes correctly"""
        bridge = SimulationBridge(base_url="http://localhost:3001")

        assert bridge.base_url == "http://localhost:3001"
        assert not bridge.is_initialized
        assert bridge.npc_ids == []
        assert bridge.archetypes == {}
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class TestOnlineEnvIntegration:
    """Integration tests for online environment"""

    def test_pool_scenario_add_methods(self):
        """Test Scenario.add_market, add_perpetual, add_news methods"""
        # Uses the module-level PoolScenario / PortfolioState imports. A local
        # `Scenario` import here would shadow the bridge Scenario and went unused.
        scenario = PoolScenario(
            id="test-1",
            source="synthetic",
            archetype_focus="trader",
            difficulty="medium",
            portfolio=PortfolioState(balance=10000.0, positions=[]),
        )

        # Test add_market
        scenario.add_market({
            "id": "mkt-1",
            "question": "Will BTC hit $50K?",
            "yesPrice": 0.65,
            "noPrice": 0.35,
        })

        assert len(scenario.markets) == 1
        assert scenario.markets[0].market_id == "mkt-1"
        assert scenario.markets[0].question == "Will BTC hit $50K?"

        # Test add_perpetual
        scenario.add_perpetual({
            "ticker": "BTC",
            "markPrice": 45000.0,
            "change24h": 2.5,
        })

        assert len(scenario.perpetuals) == 1
        assert scenario.perpetuals[0].ticker == "BTC"
        assert scenario.perpetuals[0].mark_price == 45000.0

        # Test add_news
        scenario.add_news({
            "headline": "BTC is rising",
            "sentiment": "bullish",
            "impact": "high",
            "source": "CryptoNews",
        })

        assert len(scenario.news) == 1
        assert scenario.news[0].headline == "BTC is rising"

    def test_scenario_metadata(self):
        """Test Scenario.metadata field for extensibility"""
        scenario = PoolScenario(
            id="test-1",
            source="synthetic",
            archetype_focus="trader",
            difficulty="medium",
            portfolio=PortfolioState(balance=10000.0, positions=[]),
        )

        # Metadata should be empty by default
        assert scenario.metadata == {}

        # Can add arbitrary metadata, including nested values
        scenario.metadata["mode"] = "online"
        scenario.metadata["npc_id"] = "npc-1"
        scenario.metadata["bridge_scenario"] = {"npc_id": "npc-1"}

        assert scenario.metadata["mode"] == "online"
        assert scenario.metadata["npc_id"] == "npc-1"
        assert scenario.metadata["bridge_scenario"] == {"npc_id": "npc-1"}
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class TestHybridEnv:
    """Configuration tests for the hybrid environment's BabylonHybridEnvConfig."""

    @staticmethod
    def _config(**overrides):
        """Build a hybrid env config with a placeholder tokenizer plus any overrides."""
        from src.training.hybrid_env import BabylonHybridEnvConfig

        return BabylonHybridEnvConfig(tokenizer_name="test-model", **overrides)

    def test_hybrid_config_online_ratio(self):
        """An explicitly supplied online_ratio is stored unchanged."""
        cfg = self._config(online_ratio=0.3)
        assert cfg.online_ratio == 0.3

    def test_hybrid_config_defaults(self):
        """Unspecified fields fall back to their default values."""
        cfg = self._config()

        assert cfg.online_ratio == 0.2
        # use_simulation_bridge default comes from the parent config class.
        assert cfg.use_simulation_bridge is False
        assert cfg.db_url is None
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class TestModeSelection:
    """Tests for training mode selection in run_training.py"""

    def test_mode_argument_parsing(self):
        """Check that the expected training modes are listed.

        NOTE(review): placeholder — this does not import or exercise the
        argument parser in run_training.py; it only checks a locally defined
        list, so it cannot fail. Replace with a real parser test.
        """
        valid_modes = ["offline", "online", "hybrid"]
        for expected in ("offline", "online", "hybrid"):
            assert expected in valid_modes
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# Integration tests that require a running simulation bridge server.
# Skipped unless SIMULATION_BRIDGE_URL is set to a non-empty value: an empty
# string would otherwise pass an `is None` check and yield an empty URL.
@pytest.mark.skipif(
    not os.getenv("SIMULATION_BRIDGE_URL"),
    reason="Simulation bridge not configured",
)
class TestLiveBridgeIntegration:
    """Live integration tests with actual bridge server"""

    @pytest.mark.asyncio
    async def test_live_bridge_health(self):
        """Test bridge health check with live server"""
        # `or` (not a getenv default) so an empty value also falls back.
        bridge_url = os.getenv("SIMULATION_BRIDGE_URL") or "http://localhost:3001"

        async with SimulationBridge(bridge_url) as bridge:
            health = await bridge.health_check()

            assert "status" in health
            assert health["status"] == "healthy"

    @pytest.mark.asyncio
    async def test_live_bridge_init_and_scenario(self):
        """Test initializing bridge and getting scenario"""
        bridge_url = os.getenv("SIMULATION_BRIDGE_URL") or "http://localhost:3001"

        async with SimulationBridge(bridge_url) as bridge:
            # Initialize; the return value is not inspected, only bridge state.
            await bridge.initialize(num_npcs=5, archetypes=["trader", "degen"])

            assert bridge.is_initialized
            assert len(bridge.npc_ids) == 5

            # Get scenario
            npc_id = bridge.npc_ids[0]
            scenario = await bridge.get_scenario(npc_id)

            assert scenario.npc_id == npc_id
            assert scenario.archetype in ["trader", "degen"]
            assert scenario.balance > 0
|
|
365
|
+
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# Training Pipeline Integration Tests
|
|
2
|
+
#
|
|
3
|
+
# This package contains integration tests that require running infrastructure.
|
|
4
|
+
#
|
|
5
|
+
# Test Tiers:
|
|
6
|
+
# - test_json_mode_integration.py: Tests JSON-only trajectory processing (no DB)
|
|
7
|
+
# - test_db_integration.py: Tests database trajectory processing (requires PostgreSQL)
|
|
8
|
+
#
|
|
9
|
+
# Setup:
|
|
10
|
+
# docker compose -f docker-compose.test.yml up -d
|
|
11
|
+
# DATABASE_URL=postgresql://babylon_test:test_password@localhost:5434/babylon_test pytest python/tests/integration/
|
|
12
|
+
|