@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive tests for simulation, rollout generation, and dataset preparation.
|
|
3
|
+
|
|
4
|
+
Tests:
|
|
5
|
+
1. FastSimulator - benchmark and data generation modes
|
|
6
|
+
2. FastRolloutGenerator - rollout creation and quality validation
|
|
7
|
+
3. MultiPromptDatasetBuilder - dataset creation from trajectories
|
|
8
|
+
4. RolloutQualityValidator - validation logic
|
|
9
|
+
5. PromptTypeAnalyzer - correlation analysis
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
sys.path.insert(0, '.')
|
|
17
|
+
|
|
18
|
+
from src.models import (
|
|
19
|
+
TrainingTrajectory, TrajectoryStep, EnvironmentState,
|
|
20
|
+
Action, LLMCall, AtroposScoredGroup
|
|
21
|
+
)
|
|
22
|
+
from src.training import (
|
|
23
|
+
FastSimulator, SimulatorConfig, GameState,
|
|
24
|
+
RolloutResult, AgentTickData,
|
|
25
|
+
RolloutQualityValidator,
|
|
26
|
+
MultiPromptDatasetBuilder, PromptSample,
|
|
27
|
+
prepare_multi_prompt_training_data, PromptTypeAnalyzer,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ============================================================
|
|
32
|
+
# Fixtures
|
|
33
|
+
# ============================================================
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
def sample_env_state():
    """Provide a representative EnvironmentState for the tests below."""
    state_fields = {
        'agent_balance': 10000.0,
        'agent_pnl': 100.0,
        'open_positions': 2,
        'active_markets': 5,
    }
    return EnvironmentState(**state_fields)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.fixture
def sample_action():
    """Provide a successful 'buy' Action for the tests below."""
    order_params = {'ticker': 'BTC', 'amount': 0.1}
    return Action(action_type='buy', parameters=order_params, success=True)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.fixture
def sample_llm_call():
    """Provide a single 'action'-purpose LLMCall for the tests below."""
    call_fields = {
        'model': 'gpt-4',
        'system_prompt': 'You are a trading agent.',
        'user_prompt': 'Market update: BTC is up 5%',
        'response': 'I will buy 0.1 BTC',
        'reasoning': 'Price momentum is positive',
        'temperature': 0.7,
        'max_tokens': 100,
        'purpose': 'action',
    }
    return LLMCall(**call_fields)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.fixture
def sample_tick_data(sample_env_state, sample_action, sample_llm_call):
    """Provide one AgentTickData that combines the other fixtures."""
    tick_fields = {
        'tick_number': 0,
        'timestamp': 1000000,
        'observation': {'markets': [{'id': 'm1', 'price': 50000}]},
        'environment_state': sample_env_state,
        'llm_calls': [sample_llm_call],
        'action': sample_action,
        'reward': 0.1,
    }
    return AgentTickData(**tick_fields)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.fixture
def sample_trajectory(sample_env_state, sample_action, sample_llm_call):
    """Provide a 5-step TrainingTrajectory with alternating trade/wait actions."""

    def build_step(idx):
        # Each step advances balance/pnl/positions deterministically with idx.
        env = EnvironmentState(
            agent_balance=10000.0 + idx * 100,
            agent_pnl=idx * 100.0,
            open_positions=idx,
        )
        call = LLMCall(
            model='gpt-4',
            system_prompt='You are a trading agent.',
            user_prompt=f'Market update for step {idx}: price is moving',
            response=f'I will execute action {idx}: buying at current price level',
            reasoning=f'Reasoning for step {idx}',
            temperature=0.7,
            max_tokens=100,
            purpose='action',
        )
        move = Action(
            action_type='trade' if idx % 2 == 0 else 'wait',
            parameters={'amount': 100},
            success=True,
        )
        return TrajectoryStep(
            step_number=idx,
            timestamp=1000000 + idx * 1000,
            environment_state=env,
            provider_accesses=[],
            llm_calls=[call],
            action=move,
            reward=0.1,
        )

    return TrainingTrajectory(
        id='traj-1',
        trajectory_id='traj-1',
        agent_id='agent-1',
        window_id='2024-01-01T00:00',
        start_time=datetime.now(),
        end_time=datetime.now(),
        duration_ms=5000,
        steps=[build_step(i) for i in range(5)],
        total_reward=0.5,
        final_pnl=400.0,
        episode_length=5,
        final_status='completed',
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ============================================================
|
|
135
|
+
# GameState Tests
|
|
136
|
+
# ============================================================
|
|
137
|
+
|
|
138
|
+
class TestGameState:
    """Unit tests covering the GameState dataclass."""

    def test_default_initialization(self):
        """A freshly constructed GameState starts empty at tick/time zero."""
        fresh = GameState()
        assert fresh.tick == 0
        assert fresh.time == 0
        assert fresh.markets == []
        assert fresh.portfolios == {}

    def test_to_observation(self):
        """to_observation() mirrors tick/time/markets and exposes news."""
        populated = GameState(
            tick=5,
            time=1000000,
            markets=[{'id': 'm1'}],
            news=[{'headline': 'News 1'}, {'headline': 'News 2'}],
        )
        observation = populated.to_observation()

        assert observation['tick'] == 5
        assert observation['time'] == 1000000
        assert len(observation['markets']) == 1
        assert 'news' in observation

    def test_get_env_state(self):
        """get_env_state() pulls balance/pnl/positions from the portfolio map."""
        holdings = {'agent-1': {'balance': 15000.0, 'pnl': 500.0, 'positions': 3}}
        populated = GameState(portfolios=holdings)

        env_state = populated.get_env_state('agent-1')
        assert env_state.agent_balance == 15000.0
        assert env_state.agent_pnl == 500.0
        assert env_state.open_positions == 3

    def test_get_env_state_unknown_agent(self):
        """Unknown agents fall back to the documented default portfolio values."""
        env_state = GameState().get_env_state('unknown-agent')

        assert env_state.agent_balance == 10000.0
        assert env_state.agent_pnl == 0.0
        assert env_state.open_positions == 0
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# ============================================================
|
|
189
|
+
# SimulatorConfig Tests
|
|
190
|
+
# ============================================================
|
|
191
|
+
|
|
192
|
+
class TestSimulatorConfig:
    """Unit tests for SimulatorConfig defaults and benchmark setup."""

    def test_default_config(self):
        """Defaults target data generation with the documented knob values."""
        defaults = SimulatorConfig()
        assert defaults.mode == 'data_generation'
        assert defaults.max_concurrent_agents == 8
        assert defaults.batch_size == 4
        assert defaults.ticks_per_window == 60
        assert defaults.min_actions_per_trajectory == 5

    def test_benchmark_mode_config(self):
        """Benchmark mode carries the snapshot and ground-truth payloads."""
        bench = SimulatorConfig(
            mode='benchmark',
            benchmark_snapshot={'ticks': []},
            ground_truth={'marketOutcomes': {}},
        )
        assert bench.mode == 'benchmark'
        assert bench.benchmark_snapshot is not None
        assert bench.ground_truth is not None
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# ============================================================
|
|
217
|
+
# FastSimulator Tests
|
|
218
|
+
# ============================================================
|
|
219
|
+
|
|
220
|
+
class TestFastSimulator:
    """Unit tests for FastSimulator construction and tick handling."""

    def test_for_benchmark(self):
        """for_benchmark() wires snapshot ticks and the initial market list."""
        initial_state = {
            'predictionMarkets': [{'id': 'm1'}],
            'currentTime': 1000,
        }
        snapshot = {
            'ticks': [
                {'state': {'currentTime': 1000}},
                {'state': {'currentTime': 2000}},
            ],
            'groundTruth': {'marketOutcomes': {}},
            'initialState': initial_state,
        }

        sim = FastSimulator.for_benchmark(snapshot)

        assert sim.config.mode == 'benchmark'
        assert len(sim.benchmark_ticks) == 2
        assert sim.game_state.markets == [{'id': 'm1'}]

    def test_is_complete_benchmark(self):
        """Benchmark runs finish once every snapshot tick is consumed."""
        snapshot = {'ticks': [{}] * 5, 'groundTruth': {}, 'initialState': {}}
        sim = FastSimulator.for_benchmark(snapshot)

        assert not sim.is_complete()
        sim.current_tick = 5
        assert sim.is_complete()

    def test_is_complete_data_generation(self):
        """Data-generation runs finish when max_ticks is reached."""
        sim = FastSimulator(SimulatorConfig(max_ticks=100))

        assert not sim.is_complete()
        sim.current_tick = 100
        assert sim.is_complete()

    def test_advance_tick(self):
        """_advance_tick bumps the tick counter and adds 1000 to game time."""
        sim = FastSimulator(SimulatorConfig())
        tick_before = sim.current_tick
        time_before = sim.game_state.time

        sim._advance_tick()

        assert sim.current_tick == tick_before + 1
        assert sim.game_state.time == time_before + 1000
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# ============================================================
|
|
276
|
+
# AgentTickData Tests
|
|
277
|
+
# ============================================================
|
|
278
|
+
|
|
279
|
+
class TestAgentTickData:
    """Unit tests for AgentTickData context rendering."""

    def test_get_full_context(self, sample_tick_data):
        """Context includes observation, LLM call, and action sections."""
        rendered = sample_tick_data.get_full_context()

        assert '=== OBSERVATION' in rendered
        assert '=== LLM CALL 1' in rendered
        assert '=== ACTION ===' in rendered
        assert 'buy' in rendered.lower()

    def test_get_full_context_no_action(self, sample_env_state, sample_llm_call):
        """When the tick has no action, the ACTION section is omitted."""
        actionless = AgentTickData(
            tick_number=0,
            timestamp=1000000,
            observation={},
            environment_state=sample_env_state,
            llm_calls=[sample_llm_call],
            action=None,
            reward=0.0,
        )

        rendered = actionless.get_full_context()
        assert '=== OBSERVATION' in rendered
        assert '=== ACTION ===' not in rendered
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# ============================================================
|
|
309
|
+
# RolloutQualityValidator Tests
|
|
310
|
+
# ============================================================
|
|
311
|
+
|
|
312
|
+
class TestRolloutQualityValidator:
    """Unit tests for RolloutQualityValidator.validate_rollout."""

    @staticmethod
    def _make_result(trajectory, **overrides):
        """Build a RolloutResult with nominal defaults, overridable per test."""
        fields = {
            'agent_id': 'test',
            'trajectory_id': 'test-traj',
            'ticks_completed': 10,
            'total_duration_ms': 5000,
            'avg_tick_duration_ms': 500.0,
            'total_llm_calls': 15,
            'total_reward': 5.0,
            'final_pnl': 500.0,
            'quality_score': 0.7,
            'trajectory': trajectory,
        }
        fields.update(overrides)
        return RolloutResult(**fields)

    def test_validate_valid_rollout(self, sample_trajectory):
        """A nominal rollout yields a (bool, list) verdict from the validator."""
        result = self._make_result(sample_trajectory)

        is_valid, issues = RolloutQualityValidator.validate_rollout(result)

        # Per-step LLM-call requirements may still surface issues here,
        # so only the result types are pinned.
        assert isinstance(is_valid, bool)
        assert isinstance(issues, list)

    def test_validate_no_trajectory(self):
        """A rollout with no trajectory data is an immediate hard failure."""
        result = self._make_result(None)

        is_valid, issues = RolloutQualityValidator.validate_rollout(result)

        assert not is_valid
        assert 'No trajectory data' in issues

    def test_validate_too_few_ticks(self, sample_trajectory):
        """Fewer than 5 completed ticks is flagged as an issue."""
        result = self._make_result(
            sample_trajectory,
            ticks_completed=3,  # below the minimum of 5
            total_duration_ms=1500,
            total_llm_calls=3,
            total_reward=0.3,
            final_pnl=100.0,
            quality_score=0.3,
        )

        is_valid, issues = RolloutQualityValidator.validate_rollout(result)

        assert any('Too few ticks' in issue for issue in issues)

    def test_validate_low_quality_score(self, sample_trajectory):
        """A quality score below the 0.5 threshold is flagged as an issue."""
        result = self._make_result(
            sample_trajectory,
            total_llm_calls=10,
            total_reward=1.0,
            final_pnl=100.0,
            quality_score=0.3,  # below the 0.5 threshold
        )

        is_valid, issues = RolloutQualityValidator.validate_rollout(result)

        assert any('Quality score too low' in issue for issue in issues)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
# ============================================================
|
|
398
|
+
# MultiPromptDatasetBuilder Tests
|
|
399
|
+
# ============================================================
|
|
400
|
+
|
|
401
|
+
class TestMultiPromptDatasetBuilder:
    """Unit tests for MultiPromptDatasetBuilder."""

    def test_initialization(self):
        """A new builder exposes the four purpose datasets, all empty."""
        fresh = MultiPromptDatasetBuilder()

        assert len(fresh.datasets) == 4
        for purpose in ('action', 'reasoning', 'evaluation', 'response'):
            assert purpose in fresh.datasets
        assert fresh.total_trajectories == 0

    def test_add_trajectory(self, sample_trajectory):
        """Adding a 5-step trajectory yields one sample per step."""
        builder = MultiPromptDatasetBuilder()

        added = builder.add_trajectory(sample_trajectory, trajectory_score=0.8)

        assert added == 5  # one per step
        assert builder.total_trajectories == 1
        assert builder.total_steps == 5
        assert builder.total_samples == 5

    def test_get_statistics(self, sample_trajectory):
        """Statistics report overall totals plus a per-purpose breakdown."""
        builder = MultiPromptDatasetBuilder()
        builder.add_trajectory(sample_trajectory, trajectory_score=0.8)

        stats = builder.get_statistics()

        assert stats['total_trajectories'] == 1
        assert stats['total_samples'] == 5
        assert 'by_purpose' in stats
        assert 'action' in stats['by_purpose']

    def test_build_training_data(self, sample_trajectory):
        """Building grouped data returns a list of AtroposScoredGroup."""
        builder = MultiPromptDatasetBuilder()
        # Add several copies with varying scores so grouping has variance.
        for idx in range(4):
            builder.add_trajectory(sample_trajectory, trajectory_score=0.5 + idx * 0.1)

        groups = builder.build_training_data(purpose='action', group_size=4)

        assert isinstance(groups, list)
        if groups:
            assert isinstance(groups[0], AtroposScoredGroup)
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
# ============================================================
|
|
455
|
+
# PromptSample Tests
|
|
456
|
+
# ============================================================
|
|
457
|
+
|
|
458
|
+
class TestPromptSample:
    """Unit tests for PromptSample serialization and scoring."""

    @staticmethod
    def _build_sample(**overrides):
        """Construct a PromptSample with default fields, overridable per test."""
        fields = {
            'trajectory_id': 't1',
            'step_number': 0,
            'call_index': 0,
            'system_prompt': 'sys',
            'user_prompt': 'user',
            'response': 'resp',
            'purpose': 'action',
            'action_type': 'buy',
            'model': 'gpt-4',
            'temperature': 0.7,
            'trajectory_score': 0.8,
            'step_reward': 0.1,
            'action_success': True,
            'environment_context': {},
            'previous_actions': [],
        }
        fields.update(overrides)
        return PromptSample(**fields)

    def test_to_messages(self):
        """to_messages() emits system/user/assistant roles in order."""
        sample = self._build_sample(
            system_prompt='You are a trading agent.',
            user_prompt='What should I do?',
            response='Buy BTC',
            environment_context={'balance': 10000},
            previous_actions=['wait'],
        )

        messages = sample.to_messages()

        assert len(messages) == 3
        assert messages[0]['role'] == 'system'
        assert messages[1]['role'] == 'user'
        assert messages[2]['role'] == 'assistant'

    def test_get_weighted_score(self):
        """Success bonus and step reward lift the score above the 0.8 base."""
        score = self._build_sample().get_weighted_score()

        assert score > 0.8
        assert score <= 1.0
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
# ============================================================
|
|
516
|
+
# PromptTypeAnalyzer Tests
|
|
517
|
+
# ============================================================
|
|
518
|
+
|
|
519
|
+
class TestPromptTypeAnalyzer:
    """Unit tests for PromptTypeAnalyzer.analyze_correlation."""

    def test_analyze_correlation(self, sample_trajectory):
        """Analysis output exposes the four expected top-level keys."""
        analysis = PromptTypeAnalyzer.analyze_correlation([sample_trajectory], [0.8])

        for key in (
            'prompt_count_by_purpose',
            'avg_length_by_purpose',
            'high_score_characteristics',
            'low_score_characteristics',
        ):
            assert key in analysis

    def test_analyze_high_low_scores(self, sample_trajectory):
        """One high and one low score populate both characteristic buckets."""
        analysis = PromptTypeAnalyzer.analyze_correlation(
            [sample_trajectory, sample_trajectory],
            [0.9, 0.2],  # one high, one low
        )

        assert len(analysis['high_score_characteristics']) > 0
        assert len(analysis['low_score_characteristics']) > 0
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
# ============================================================
|
|
548
|
+
# Integration Tests
|
|
549
|
+
# ============================================================
|
|
550
|
+
|
|
551
|
+
class TestIntegration:
    """Integration tests over prepare_multi_prompt_training_data."""

    def test_prepare_multi_prompt_training_data(self, sample_trajectory):
        """The convenience wrapper returns a purpose-keyed dict."""
        result = prepare_multi_prompt_training_data(
            trajectories=[sample_trajectory] * 4,
            scores=[0.8, 0.6, 0.4, 0.9],
            group_size=4,
        )

        assert isinstance(result, dict)
        # Whether each purpose has groups depends on score variance,
        # so only the container type is pinned here.

    def test_trajectory_count_score_mismatch(self, sample_trajectory):
        """Mismatched trajectory/score counts raise ValueError."""
        with pytest.raises(ValueError, match='Trajectory count'):
            prepare_multi_prompt_training_data(
                trajectories=[sample_trajectory],
                scores=[0.8, 0.6],  # wrong count
            )
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
# Run tests
|
|
579
|
+
# Allow direct execution: hand this file to pytest in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
|
581
|
+
|