@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Integration tests for the Atropos RLAIF implementation.

Tests:
1. Module imports work correctly (trainer and environment imports are
   skipped when torch / wandb are not installed)
2. Trajectory-to-Atropos data conversion, including dropout
3. Reward functions produce valid outputs
4. Trainer and environment configuration defaults and overrides
5. Dropout-rate calculation
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import pytest
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from typing import Dict
|
|
14
|
+
|
|
15
|
+
# Optional heavy dependencies: probe for them once at import time so that the
# tests which need them are skipped cleanly instead of erroring on import.
try:
    import torch  # noqa: F401
except ImportError:
    HAS_TORCH = False
else:
    HAS_TORCH = True

try:
    import wandb  # noqa: F401
except ImportError:
    HAS_WANDB = False
else:
    HAS_WANDB = True

# Skip markers for tests / test classes that depend on the packages above.
requires_torch = pytest.mark.skipif(
    not HAS_TORCH,
    reason="torch not installed",
)
requires_wandb = pytest.mark.skipif(
    not HAS_WANDB,
    reason="wandb not installed",
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Test imports work
|
|
33
|
+
class TestImports:
|
|
34
|
+
"""Verify all modules can be imported"""
|
|
35
|
+
|
|
36
|
+
def test_import_models(self):
|
|
37
|
+
from src.models import (
|
|
38
|
+
TrainingTrajectory,
|
|
39
|
+
AtroposScoredGroup,
|
|
40
|
+
)
|
|
41
|
+
assert TrainingTrajectory is not None
|
|
42
|
+
assert AtroposScoredGroup is not None
|
|
43
|
+
|
|
44
|
+
def test_import_converter(self):
|
|
45
|
+
from src.data_bridge import (
|
|
46
|
+
TrajectoryToAtroposConverter,
|
|
47
|
+
ScoredGroupResult,
|
|
48
|
+
)
|
|
49
|
+
assert TrajectoryToAtroposConverter is not None
|
|
50
|
+
assert ScoredGroupResult is not None
|
|
51
|
+
|
|
52
|
+
def test_import_rewards(self):
|
|
53
|
+
from src.training.rewards import (
|
|
54
|
+
pnl_reward,
|
|
55
|
+
RewardNormalizer,
|
|
56
|
+
)
|
|
57
|
+
assert pnl_reward is not None
|
|
58
|
+
assert RewardNormalizer is not None
|
|
59
|
+
|
|
60
|
+
@requires_torch
|
|
61
|
+
def test_import_trainer(self):
|
|
62
|
+
from src.training import (
|
|
63
|
+
AtroposTrainer,
|
|
64
|
+
)
|
|
65
|
+
assert AtroposTrainer is not None
|
|
66
|
+
|
|
67
|
+
@requires_wandb
|
|
68
|
+
def test_import_environment(self):
|
|
69
|
+
from src.training import (
|
|
70
|
+
RLAIFEnv,
|
|
71
|
+
)
|
|
72
|
+
assert RLAIFEnv is not None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TestRewardFunctions:
    """Exercise the archetype-aware reward API in ``src.training.rewards``."""

    @staticmethod
    def _make_inputs(final_pnl, **extra):
        """Build ``TrajectoryRewardInputs`` for a 10000.0 starting balance.

        ``end_balance`` is derived as ``starting_balance + final_pnl``; any
        extra keyword arguments are forwarded unchanged.
        """
        from src.training.rewards import TrajectoryRewardInputs

        opening = 10000.0
        return TrajectoryRewardInputs(
            final_pnl=final_pnl,
            starting_balance=opening,
            end_balance=opening + final_pnl,
            **extra,
        )

    def test_pnl_reward_positive(self):
        from src.training.rewards import pnl_reward

        # A profitable episode must earn a strictly positive reward.
        assert pnl_reward(self._make_inputs(500.0)) > 0.0

    def test_pnl_reward_negative(self):
        from src.training.rewards import pnl_reward

        # A losing episode must earn a strictly negative reward.
        assert pnl_reward(self._make_inputs(-500.0)) < 0.0

    def test_pnl_reward_zero(self):
        from src.training.rewards import pnl_reward

        # Break-even should land close to zero.
        outcome = pnl_reward(self._make_inputs(0.0))
        assert abs(outcome) <= 0.1

    def test_archetype_composite_reward(self):
        from src.training.rewards import BehaviorMetrics, archetype_composite_reward

        inputs = self._make_inputs(500.0, format_score=0.8, reasoning_score=0.75)
        behavior = BehaviorMetrics(trades_executed=5, total_pnl=500.0, episode_length=10)

        outcome = archetype_composite_reward(inputs, "trader", behavior)
        assert 0.0 <= outcome <= 1.0

    def test_composite_reward_with_inputs(self):
        from src.training.rewards import composite_reward

        outcome = composite_reward(self._make_inputs(500.0))
        assert 0.0 <= outcome <= 1.0

    def test_relative_scores(self):
        from src.training.rewards import relative_scores

        # Raw reward floats ordered best -> worst.
        scores = relative_scores([0.8, 0.5, 0.2])

        # Normalised into [0, 1] while preserving the ranking.
        assert all(0.0 <= s <= 1.0 for s in scores)
        assert scores[0] > scores[1] > scores[2]

    def test_reward_normalizer(self):
        from src.training.rewards import RewardNormalizer

        normalizer = RewardNormalizer()
        history = (0.5, 0.6, 0.7, 0.8, 0.55, 0.65, 0.75, 0.85)
        for sample in history:
            normalizer.update(sample)

        assert isinstance(normalizer.normalize(0.65), float)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class TestConverter:
    """Test conversion of ``TrainingTrajectory`` objects into Atropos format."""

    def create_sample_trajectory(
        self,
        trajectory_id: str = "traj-1",
        record_id: str = "test-1",
    ):
        """Build a small, deterministic ``TrainingTrajectory`` for the tests.

        The trajectory has five steps. Step ``i`` has an agent balance of
        ``10000 + 100 * i``, a P&L of ``100 * i``, ``i`` open positions, one
        ``gpt-4`` LLM call and one successful ``trade`` action. The episode
        lasts 5000 ms, and ``end_time`` is derived from ``start_time`` so the
        two stay consistent with ``duration_ms``.

        Args:
            trajectory_id: Value for ``TrainingTrajectory.trajectory_id``
                (``test_convert_trajectory`` asserts on the default).
            record_id: Value for ``TrainingTrajectory.id``.

        Returns:
            A ``src.models.TrainingTrajectory`` instance.
        """
        from datetime import timedelta

        from src.models import (
            Action,
            EnvironmentState,
            LLMCall,
            TrainingTrajectory,
            TrajectoryStep,
        )

        steps = [
            TrajectoryStep(
                step_number=i,
                timestamp=1000000 + i * 1000,
                environment_state=EnvironmentState(
                    agent_balance=10000.0 + i * 100,
                    agent_pnl=i * 100.0,
                    open_positions=i,
                ),
                provider_accesses=[],
                llm_calls=[
                    LLMCall(
                        model="gpt-4",
                        system_prompt="You are a trading agent",
                        user_prompt=f"Market update {i}",
                        response=f"Action {i}",
                        temperature=0.7,
                        max_tokens=100,
                        purpose="action",
                    )
                ],
                action=Action(
                    action_type="trade",
                    parameters={"amount": 100},
                    success=True,
                ),
                reward=0.1,
            )
            for i in range(5)
        ]

        duration_ms = 5000
        start_time = datetime.now()
        return TrainingTrajectory(
            id=record_id,
            trajectory_id=trajectory_id,
            agent_id="agent-1",
            window_id="2024-01-01T00:00",
            start_time=start_time,
            # Derive end_time from start_time so it agrees with duration_ms.
            end_time=start_time + timedelta(milliseconds=duration_ms),
            duration_ms=duration_ms,
            steps=steps,
            total_reward=0.5,
            # Matches the P&L of the last step (i == 4).
            final_pnl=400.0,
            episode_length=5,
            final_status="completed",
        )

    def test_convert_trajectory(self):
        from src.data_bridge import TrajectoryToAtroposConverter

        converter = TrajectoryToAtroposConverter()
        traj = self.create_sample_trajectory()

        result = converter.convert_trajectory(traj)

        assert result is not None
        assert len(result.messages) >= 3  # system + at least one exchange
        assert result.metadata["trajectory_id"] == "traj-1"
        assert result.metadata["final_pnl"] == 400.0

    def test_convert_window_group(self):
        from src.data_bridge import TrajectoryToAtroposConverter

        converter = TrajectoryToAtroposConverter()
        # Give each group member distinct identifiers at construction time
        # rather than mutating the objects after validation.
        trajs = [
            self.create_sample_trajectory(
                trajectory_id=f"traj-{i}",
                record_id=f"test-{i}",
            )
            for i in range(4)
        ]

        # NOTE(review): the second argument is passed as None; its meaning is
        # defined by convert_window_group in src.data_bridge — confirm.
        result = converter.convert_window_group(trajs, None)

        assert result.group_size == 4
        assert len(result.scores) == 4
        assert len(result.messages) == 4

    def test_dropout(self):
        import random

        from src.data_bridge import TrajectoryToAtroposConverter

        # Seed the global RNG so the dropped count is reproducible instead of
        # occasionally landing outside the band, and restore its state so
        # other tests are unaffected.
        # NOTE(review): assumes the converter draws from the stdlib ``random``
        # module — confirm; if it uses another RNG, seed that one instead.
        saved_state = random.getstate()
        random.seed(1234)
        try:
            # With dropout_rate=0.5, convert_trajectory should return None
            # for roughly half of the inputs.
            converter = TrajectoryToAtroposConverter(dropout_rate=0.5)
            dropped_count = sum(
                converter.convert_trajectory(self.create_sample_trajectory()) is None
                for _ in range(100)
            )
        finally:
            random.setstate(saved_state)

        # Should drop roughly 50%
        assert 30 < dropped_count < 70
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@requires_torch
class TestTrainerConfig:
    """``AtroposTrainingConfig`` defaults and overrides (skipped without torch)."""

    def test_default_config(self):
        from src.training import AtroposTrainingConfig

        cfg = AtroposTrainingConfig()

        expected = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "learning_rate": 1e-5,
            "training_steps": 100,
        }
        for name, value in expected.items():
            assert getattr(cfg, name) == value, name

    def test_custom_config(self):
        from src.training import AtroposTrainingConfig

        overrides = {
            "model_name": "Qwen/Qwen2.5-7B-Instruct",
            "training_steps": 50,
            "learning_rate": 5e-6,
        }
        cfg = AtroposTrainingConfig(**overrides)

        # Every override must be reflected verbatim on the config object.
        for name, value in overrides.items():
            assert getattr(cfg, name) == value, name
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@requires_wandb
class TestEnvironmentConfig:
    """``RLAIFEnvConfig`` defaults and overrides (skipped without wandb)."""

    def test_default_config(self):
        from src.training import RLAIFEnvConfig

        cfg = RLAIFEnvConfig()

        expected = {
            "group_size": 4,
            "lookback_hours": 72,
            "min_agents_per_window": 2,
        }
        for name, value in expected.items():
            assert getattr(cfg, name) == value, name

    def test_custom_config(self):
        from src.training import RLAIFEnvConfig

        overrides = {
            "group_size": 8,
            "lookback_hours": 48,
            "judge_model": "gpt-4",
        }
        cfg = RLAIFEnvConfig(**overrides)

        # Every override must be reflected verbatim on the config object.
        for name, value in overrides.items():
            assert getattr(cfg, name) == value, name
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class TestCalculateDropoutRate:
    """Behavior of ``calculate_dropout_rate`` around its target and its cap."""

    def test_no_dropout_needed(self):
        from src.data_bridge import calculate_dropout_rate

        # Below the target count: nothing should be dropped.
        assert calculate_dropout_rate(500, target_trajectories=1000) == 0.0

    def test_dropout_needed(self):
        from src.data_bridge import calculate_dropout_rate

        # Twice the target: some dropout, bounded above by 0.3.
        computed = calculate_dropout_rate(2000, target_trajectories=1000)
        assert 0.0 < computed <= 0.3

    def test_max_dropout_cap(self):
        from src.data_bridge import calculate_dropout_rate

        # A large surplus is clamped to the explicit max_dropout.
        computed = calculate_dropout_rate(
            10000,
            target_trajectories=1000,
            max_dropout=0.2,
        )
        assert computed == 0.2
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
# Run tests with: pytest tests/test_atropos_integration.py -v
# (or run this file directly: python tests/test_atropos_integration.py)
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of exiting. Raise it so that
    # running the file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
360
|
+
|