@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,590 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for Multi-Turn Episode Manager
|
|
3
|
+
|
|
4
|
+
Covers:
|
|
5
|
+
- TurnData structure
|
|
6
|
+
- EpisodeBuffer management
|
|
7
|
+
- GAE advantage computation
|
|
8
|
+
- Reward shaping
|
|
9
|
+
- Episode collection
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from typing import List
|
|
14
|
+
|
|
15
|
+
import pytest
|
|
16
|
+
|
|
17
|
+
from src.training.multi_turn import (
|
|
18
|
+
TurnData,
|
|
19
|
+
EpisodeBuffer,
|
|
20
|
+
GAEConfig,
|
|
21
|
+
MultiTurnEpisodeManager,
|
|
22
|
+
EpisodeCollector,
|
|
23
|
+
shape_trading_rewards,
|
|
24
|
+
compute_episode_return,
|
|
25
|
+
normalize_episode_rewards,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# =============================================================================
|
|
30
|
+
# Fixtures
|
|
31
|
+
# =============================================================================
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture
def sample_turn():
    """Build one representative turn for use across TurnData tests."""
    fields = {
        "turn_number": 0,
        "episode_id": "ep-001",
        "action_type": "buy",
        "action_text": '{"action": "buy", "market": "BTC", "amount": 100}',
        "reward": 0.5,
        "format_score": 0.8,
        "reasoning_score": 0.7,
        "done": False,
    }
    return TurnData(**fields)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.fixture
def sample_episode():
    """Build a four-turn episode whose final close_perp turn is terminal."""
    turn_specs = [
        (0, 0.1, "buy", False),
        (1, 0.2, "wait", False),
        (2, -0.1, "sell", False),
        (3, 0.5, "close_perp", True),
    ]

    episode = EpisodeBuffer(episode_id="ep-001", scenario_id="scenario-1")
    for number, reward, action, is_done in turn_specs:
        episode.add_turn(
            TurnData(
                turn_number=number,
                reward=reward,
                action_type=action,
                done=is_done,
            )
        )

    return episode
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.fixture
def manager():
    """Build a MultiTurnEpisodeManager with the standard test hyperparameters."""
    return MultiTurnEpisodeManager(gamma=0.99, gae_lambda=0.95, max_turns=20)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# =============================================================================
|
|
77
|
+
# TurnData Tests
|
|
78
|
+
# =============================================================================
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class TestTurnData:
    """Tests for TurnData dataclass"""

    def test_creation(self, sample_turn):
        """The fixture's explicit field values survive construction."""
        assert (sample_turn.turn_number, sample_turn.episode_id) == (0, "ep-001")
        assert sample_turn.action_type == "buy"
        assert sample_turn.reward == 0.5

    def test_default_values(self):
        """Omitted fields fall back to their documented defaults."""
        turn = TurnData(turn_number=0)

        expected_defaults = {
            "episode_id": "",
            "reward": 0.0,
            "value": 0.0,
            "advantage": 0.0,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(turn, field_name) == expected
        assert turn.done is False

    def test_to_dict(self, sample_turn):
        """Serialization exposes the key per-turn fields with their values."""
        serialized = sample_turn.to_dict()

        for key in ("turn_number", "episode_id", "action_type", "reward", "advantage"):
            assert key in serialized
        assert serialized["turn_number"] == 0
        assert serialized["reward"] == 0.5

    def test_action_text_truncation(self):
        """Oversized action text is capped at 200 characters in to_dict."""
        turn = TurnData(turn_number=0, action_text="x" * 500)

        assert len(turn.to_dict()["action_text"]) <= 200
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# =============================================================================
|
|
123
|
+
# EpisodeBuffer Tests
|
|
124
|
+
# =============================================================================
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class TestEpisodeBuffer:
    """Tests for EpisodeBuffer"""

    def test_creation(self):
        """A fresh buffer records its identifiers and starts empty/incomplete."""
        buffer = EpisodeBuffer(
            episode_id="ep-001",
            scenario_id="scenario-1",
            archetype="trader",
        )

        assert buffer.episode_id == "ep-001"
        assert buffer.scenario_id == "scenario-1"
        assert not buffer.turns
        assert not buffer.completed

    def test_add_turn(self):
        """Adding a turn stamps it with the buffer's episode id."""
        buffer = EpisodeBuffer(episode_id="ep-001")
        buffer.add_turn(TurnData(turn_number=0, reward=0.1))

        assert len(buffer.turns) == 1
        assert buffer.turns[0].episode_id == "ep-001"

    def test_finalization_on_done(self):
        """A terminal turn marks the episode complete and totals its rewards."""
        buffer = EpisodeBuffer(episode_id="ep-001")

        buffer.add_turn(TurnData(turn_number=0, reward=0.1, done=False))
        assert not buffer.completed

        buffer.add_turn(TurnData(turn_number=1, reward=0.2, done=True))
        assert buffer.completed
        assert buffer.episode_length == 2
        assert buffer.total_reward == pytest.approx(0.3, abs=0.01)

    def test_success_determination(self):
        """Success tracks the sign of the total episode reward."""
        winning = EpisodeBuffer(episode_id="ep-001")
        winning.add_turn(TurnData(turn_number=0, reward=0.5, done=True))
        assert winning.success

        losing = EpisodeBuffer(episode_id="ep-002")
        losing.add_turn(TurnData(turn_number=0, reward=-0.5, done=True))
        assert not losing.success

    def test_get_messages(self, sample_episode):
        """Messages attached to the final turn are surfaced by get_messages."""
        sample_episode.turns[-1].messages = [
            {"role": "user", "content": "trade"},
            {"role": "assistant", "content": "done"},
        ]

        assert len(sample_episode.get_messages()) == 2

    def test_get_trajectory(self, sample_episode):
        """The trajectory summary lists action types in turn order."""
        trajectory = sample_episode.get_trajectory()

        assert len(trajectory) == 4
        first_action, last_action = trajectory[0][0], trajectory[3][0]
        assert first_action == "buy"
        assert last_action == "close_perp"

    def test_to_dict(self, sample_episode):
        """Serialization includes identifiers and every recorded turn."""
        serialized = sample_episode.to_dict()

        for key in ("episode_id", "scenario_id", "turns"):
            assert key in serialized
        assert len(serialized["turns"]) == 4
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# =============================================================================
|
|
207
|
+
# GAEConfig Tests
|
|
208
|
+
# =============================================================================
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class TestGAEConfig:
    """Tests for GAEConfig"""

    def test_default_values(self):
        """The no-argument config carries the standard GAE hyperparameters."""
        config = GAEConfig()

        assert (config.gamma, config.gae_lambda) == (0.99, 0.95)
        assert config.normalize_advantages is True

    def test_custom_values(self):
        """Explicit gamma/lambda overrides are stored as given."""
        config = GAEConfig(gamma=0.95, gae_lambda=0.9)

        assert (config.gamma, config.gae_lambda) == (0.95, 0.9)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# =============================================================================
|
|
231
|
+
# MultiTurnEpisodeManager Tests
|
|
232
|
+
# =============================================================================
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class TestMultiTurnEpisodeManager:
    """Tests for MultiTurnEpisodeManager"""

    def test_creation(self, manager):
        """Constructor arguments land in config and max_turns."""
        assert manager.config.gamma == 0.99
        assert manager.config.gae_lambda == 0.95
        assert manager.max_turns == 20

    def test_compute_advantages_single_turn(self, manager):
        """Test advantage computation for single turn."""
        turns = [TurnData(turn_number=0, reward=1.0, done=True)]

        manager.compute_advantages(turns)

        assert turns[0].return_to_go == 1.0
        assert turns[0].value == 1.0
        # For single turn, advantage should be related to TD error
        assert turns[0].advantage != 0 or turns[0].reward != 0

    def test_compute_advantages_multiple_turns(self, manager, sample_episode):
        """Test advantage computation for multiple turns."""
        turns = sample_episode.turns

        manager.compute_advantages(turns)

        # All turns should have computed values
        for turn in turns:
            assert turn.return_to_go != 0 or turn.reward == 0
            assert turn.value != 0 or turn.return_to_go == 0

        # Return-to-go should be higher for earlier turns (cumulative)
        # unless later turns have much higher rewards
        assert turns[0].return_to_go >= turns[-1].return_to_go or \
            sum(t.reward for t in turns[:2]) < sum(t.reward for t in turns[2:])

    def test_compute_advantages_empty(self, manager):
        """An empty turn list is a no-op rather than an error."""
        turns = []
        manager.compute_advantages(turns)  # Should not raise

    def test_advantage_clipping(self, manager):
        """Extreme rewards must not produce advantages beyond the clip bound."""
        manager.config.clip_advantages = True
        manager.config.advantage_clip = 5.0

        # Create turns with extreme rewards
        turns = [
            TurnData(turn_number=0, reward=100.0, done=False),
            TurnData(turn_number=1, reward=-100.0, done=True),
        ]

        manager.compute_advantages(turns)

        for turn in turns:
            assert abs(turn.advantage) <= 5.0

    def test_compute_batch_advantages(self, manager):
        """Batch computation assigns a usable advantage to every turn."""
        episodes = [
            [TurnData(turn_number=0, reward=0.5, done=True)],
            [TurnData(turn_number=0, reward=-0.3, done=True)],
            [TurnData(turn_number=0, reward=0.2, done=True)],
        ]

        manager.compute_batch_advantages(episodes)

        # Fixed: the original verification loop's body was `pass`, so this
        # test asserted nothing about the batch result. Exact values depend
        # on normalization, but every turn must come back with a real
        # (non-NaN) advantage.
        for episode in episodes:
            for turn in episode:
                assert turn.advantage == turn.advantage  # NaN check

    def test_batch_normalization(self, manager):
        """Test that batch normalization centers advantages."""
        manager.config.normalize_advantages = True

        episodes = [
            [TurnData(turn_number=0, reward=1.0, done=True)],
            [TurnData(turn_number=0, reward=0.0, done=True)],
            [TurnData(turn_number=0, reward=-1.0, done=True)],
        ]

        manager.compute_batch_advantages(episodes)

        # Collect all advantages
        advantages = [t.advantage for ep in episodes for t in ep]

        # Mean should be approximately 0 after normalization
        mean = sum(advantages) / len(advantages)
        assert abs(mean) < 0.1

    def test_get_stats(self, manager, sample_episode):
        """Stats reflect processed episode/turn counts and the configured gamma."""
        manager.compute_advantages(sample_episode.turns)

        stats = manager.get_stats()

        assert stats["episodes_processed"] == 1
        assert stats["total_turns"] == 4
        assert stats["gamma"] == 0.99
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
# =============================================================================
|
|
339
|
+
# Reward Shaping Tests
|
|
340
|
+
# =============================================================================
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
class TestRewardShaping:
    """Tests for reward shaping utilities"""

    def test_shape_trading_rewards(self):
        """Shaping runs cleanly and leaves one finite reward per turn."""
        turns = [
            TurnData(
                turn_number=0,
                reward=0.1,
                action_type="buy",
                format_score=0.8,
                reasoning_score=0.7,
            ),
            TurnData(
                turn_number=1,
                reward=-0.1,
                action_type="wait",
                format_score=0.9,
                reasoning_score=0.6,
            ),
        ]

        original_rewards = [t.reward for t in turns]

        shape_trading_rewards(turns)

        # Fixed: the original verification loop's body was `pass` and
        # `original_rewards` was never used, so this test asserted nothing.
        # Exact shaped values depend on the configured bonus weights, but
        # shaping must preserve the turn count and yield real (non-NaN)
        # rewards.
        assert len(turns) == len(original_rewards)
        for turn in turns:
            assert turn.reward == turn.reward  # NaN check

    def test_compute_episode_return(self, sample_episode):
        """Test computing discounted return."""
        returns = compute_episode_return(sample_episode.turns, gamma=0.99)

        # Return should be weighted sum of rewards
        assert returns != 0

    def test_compute_episode_return_no_discount(self, sample_episode):
        """With gamma=1.0 the return is the plain reward sum."""
        returns = compute_episode_return(sample_episode.turns, gamma=1.0)

        expected = sum(t.reward for t in sample_episode.turns)
        assert returns == pytest.approx(expected, abs=0.01)

    def test_normalize_episode_rewards(self):
        """Test normalizing rewards across episodes."""
        episodes = [
            [TurnData(turn_number=0, reward=10.0)],
            [TurnData(turn_number=0, reward=0.0)],
            [TurnData(turn_number=0, reward=-10.0)],
        ]

        normalize_episode_rewards(episodes)

        # Rewards should be normalized
        rewards = [t.reward for ep in episodes for t in ep]
        mean = sum(rewards) / len(rewards)

        # Mean should be approximately 0
        assert abs(mean) < 0.1
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
# =============================================================================
|
|
408
|
+
# EpisodeCollector Tests
|
|
409
|
+
# =============================================================================
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
class TestEpisodeCollector:
    """Tests for EpisodeCollector"""

    def test_creation(self):
        """A fresh collector honors max_episodes and holds nothing."""
        collector = EpisodeCollector(max_episodes=100)

        assert collector.max_episodes == 100
        assert not collector.episodes

    def test_start_episode(self):
        """Starting an episode records scenario/archetype and becomes current."""
        collector = EpisodeCollector()
        episode = collector.start_episode("scenario-1", "trader")

        assert episode.scenario_id == "scenario-1"
        assert episode.archetype == "trader"
        assert collector._current_episode is not None

    def test_add_turn(self):
        """Turns added after start_episode land in the current episode."""
        collector = EpisodeCollector()
        collector.start_episode("scenario-1")
        collector.add_turn(TurnData(turn_number=0, reward=0.1))

        assert len(collector._current_episode.turns) == 1

    def test_add_turn_without_episode(self):
        """Adding a turn with no active episode raises RuntimeError."""
        collector = EpisodeCollector()

        with pytest.raises(RuntimeError):
            collector.add_turn(TurnData(turn_number=0))

    def test_episode_finalization(self):
        """A terminal turn moves the current episode into the completed list."""
        collector = EpisodeCollector()
        collector.start_episode("scenario-1")

        collector.add_turn(TurnData(turn_number=0, reward=0.1, done=False))
        assert not collector.episodes

        collector.add_turn(TurnData(turn_number=1, reward=0.2, done=True))
        assert len(collector.episodes) == 1
        assert collector._current_episode is None

    def test_max_episodes_limit(self):
        """The collector never stores more than max_episodes episodes."""
        collector = EpisodeCollector(max_episodes=3)

        for index in range(5):
            collector.start_episode(f"scenario-{index}")
            collector.add_turn(TurnData(turn_number=0, done=True))

        assert len(collector.episodes) == 3

    def test_get_completed_episodes(self):
        """Only episodes finished with a done turn count as completed."""
        collector = EpisodeCollector()

        collector.start_episode("scenario-1")
        collector.add_turn(TurnData(turn_number=0, done=True))

        collector.start_episode("scenario-2")
        collector.add_turn(TurnData(turn_number=0, done=False))

        assert len(collector.get_completed_episodes()) == 1

    def test_get_successful_episodes(self):
        """Only positive-reward episodes count as successful."""
        collector = EpisodeCollector()

        collector.start_episode("scenario-1")
        collector.add_turn(TurnData(turn_number=0, reward=1.0, done=True))

        collector.start_episode("scenario-2")
        collector.add_turn(TurnData(turn_number=0, reward=-1.0, done=True))

        assert len(collector.get_successful_episodes()) == 1

    def test_clear(self):
        """clear() drops stored episodes and the current one."""
        collector = EpisodeCollector()
        collector.start_episode("scenario-1")
        collector.add_turn(TurnData(turn_number=0, done=True))

        collector.clear()

        assert not collector.episodes
        assert collector._current_episode is None

    def test_get_stats(self):
        """Stats summarize counts and success rate over stored episodes."""
        collector = EpisodeCollector()

        for index in range(3):
            collector.start_episode(f"scenario-{index}")
            collector.add_turn(TurnData(turn_number=0, reward=0.5, done=True))

        stats = collector.get_stats()

        assert stats["total_episodes"] == 3
        assert stats["completed_episodes"] == 3
        assert stats["successful_episodes"] == 3
        assert stats["success_rate"] == 1.0
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
# =============================================================================
|
|
530
|
+
# Integration Tests
|
|
531
|
+
# =============================================================================
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class TestMultiTurnIntegration:
    """Integration tests for multi-turn system"""

    def test_full_episode_workflow(self, manager):
        """Collect one episode end-to-end, then compute its advantages."""
        collector = EpisodeCollector()

        collector.start_episode("scenario-1", "trader")
        for index in range(5):
            is_last = index == 4
            collector.add_turn(TurnData(
                turn_number=index,
                reward=0.5 if is_last else 0.1 * (index + 1),
                action_type="buy" if index % 2 == 0 else "wait",
                done=is_last,
            ))

        episodes = collector.get_completed_episodes()
        assert len(episodes) == 1

        turns = episodes[0].turns
        manager.compute_advantages(turns)

        # Every turn should have picked up a computed value
        for turn in turns:
            assert turn.value != 0 or turn.reward == 0

    def test_multiple_episodes_batch(self, manager):
        """Process several collected episodes through batch advantage computation."""
        collector = EpisodeCollector()

        for episode_index in range(3):
            collector.start_episode(f"scenario-{episode_index}")
            sign = 1 if episode_index == 0 else -1
            for turn_index in range(4):
                collector.add_turn(TurnData(
                    turn_number=turn_index,
                    reward=0.1 * (turn_index + 1) * sign,
                    done=turn_index == 3,
                ))

        batch = [episode.turns for episode in collector.get_completed_episodes()]

        manager.compute_batch_advantages(batch)

        # Manager statistics should reflect the whole batch
        assert manager.get_stats()["episodes_processed"] == 3
|
|
590
|
+
|