@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,554 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JSON Mode Integration Tests
|
|
3
|
+
|
|
4
|
+
Tests the complete JSON-based training data pipeline:
|
|
5
|
+
1. Trajectory loading from JSON files
|
|
6
|
+
2. Archetype extraction from step parameters
|
|
7
|
+
3. Scoring with archetype-aware rewards
|
|
8
|
+
4. GRPO group formation
|
|
9
|
+
5. End-to-end scoring pipeline
|
|
10
|
+
|
|
11
|
+
These tests run WITHOUT a database, using only local JSON files.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Dict, List
|
|
18
|
+
|
|
19
|
+
import pytest
|
|
20
|
+
|
|
21
|
+
# Add src to path
|
|
22
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
23
|
+
|
|
24
|
+
from src.data_bridge.reader import JsonTrajectoryReader, validate_llm_calls
|
|
25
|
+
from src.training.rewards import (
|
|
26
|
+
archetype_composite_reward,
|
|
27
|
+
calculate_archetype_behavior_bonus,
|
|
28
|
+
BehaviorMetrics,
|
|
29
|
+
TrajectoryRewardInputs,
|
|
30
|
+
)
|
|
31
|
+
from src.training.rubric_loader import (
|
|
32
|
+
normalize_archetype,
|
|
33
|
+
has_custom_rubric,
|
|
34
|
+
get_rubric,
|
|
35
|
+
get_available_archetypes,
|
|
36
|
+
)
|
|
37
|
+
from tests.integration.conftest import (
|
|
38
|
+
TrajectoryFixture,
|
|
39
|
+
create_trading_step,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TestJsonTrajectoryLoading:
    """Exercise JsonTrajectoryReader against trajectory fixtures written to disk."""

    @staticmethod
    def _write_fixture(directory: Path, fixture: TrajectoryFixture) -> None:
        """Serialize *fixture* into *directory* as `<trajectory_id>.json`."""
        target = directory / f"{fixture.trajectory_id}.json"
        target.write_text(json.dumps(fixture.to_json_file_format()))

    def test_load_single_trajectory_file(
        self,
        temp_trajectory_dir: Path,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test loading a single trajectory from JSON file."""
        self._write_fixture(temp_trajectory_dir, sample_trader_trajectory)

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        discovered = reader.get_window_ids()

        assert len(discovered) >= 1
        assert sample_trader_trajectory.window_id in discovered

        loaded = reader.get_trajectories_by_window(sample_trader_trajectory.window_id)
        assert len(loaded) == 1

        # JsonTrajectoryReader hands back the trajectory payload directly (unwrapped).
        assert loaded[0].get("archetype") == "trader"

    def test_load_multiple_trajectories_same_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test loading multiple trajectories from same window."""
        for fixture in trajectory_group:
            self._write_fixture(temp_trajectory_dir, fixture)

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = reader.get_trajectories_by_window("window-test-1")

        assert len(loaded) == 3
        # Each fixture in the group carries a distinct archetype.
        assert {item.get("archetype") for item in loaded} == {"trader", "degen", "scammer"}

    def test_validate_llm_calls_in_loaded_trajectory(
        self,
        temp_trajectory_dir: Path,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test that loaded trajectories pass LLM call validation."""
        self._write_fixture(temp_trajectory_dir, sample_trader_trajectory)

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = reader.get_trajectories_by_window(sample_trader_trajectory.window_id)

        payload = loaded[0]
        raw_steps = payload.get("stepsJson", "[]")
        # stepsJson may arrive either as a serialized string or already decoded.
        steps = json.loads(raw_steps) if isinstance(raw_steps, str) else raw_steps

        is_valid, issues = validate_llm_calls(steps)
        assert is_valid, f"LLM calls should be valid: {issues}"

    def test_load_all_archetypes(
        self,
        temp_trajectory_dir: Path,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Test loading trajectories for all valid archetypes."""
        for fixture in all_archetype_trajectories.values():
            self._write_fixture(temp_trajectory_dir, fixture)

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = reader.get_trajectories_by_window("window-all-archetypes")

        # Every archetype the rubric loader knows about must round-trip.
        loaded_archetypes = {item.get("archetype") for item in loaded}
        assert loaded_archetypes == set(get_available_archetypes())

    def test_empty_directory_returns_empty_list(self, temp_trajectory_dir: Path):
        """Test that empty directory returns empty results."""
        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        assert reader.get_window_ids() == []
|
|
136
|
+
class TestArchetypeExtractionFromSteps:
    """Verify archetype extraction and normalization from step action parameters."""

    def test_extract_archetype_from_action_parameters(self):
        """Test extracting archetype from step's action parameters."""
        step = create_trading_step(0, "buy_prediction", "trader")

        # Mirrors the extraction logic used by babylon_env.py.
        extracted = step.get("action", {}).get("parameters", {}).get("archetype")

        assert extracted == "trader"

    def test_extract_archetype_from_first_step(self):
        """Test extracting archetype from first step when trajectory-level is missing."""
        steps = [
            create_trading_step(0, "buy_prediction", "degen"),
            create_trading_step(1, "hold", "degen"),
            create_trading_step(2, "sell_prediction", "degen"),
        ]

        # Scan forward and take the first step that declares an archetype.
        found = None
        for step in steps:
            candidate = step.get("action", {}).get("parameters", {}).get("archetype")
            if candidate:
                found = candidate
                break

        assert found == "degen"

    def test_normalize_extracted_archetype(self):
        """Test that extracted archetypes are normalized correctly."""
        cases = [
            ("TRADER", "trader"),
            ("Social_Butterfly", "social-butterfly"),
            ("goody_twoshoes", "goody-twoshoes"),
            ("DEGEN", "degen"),
            ("", "default"),
            (None, "default"),
        ]

        for input_val, expected in cases:
            result = normalize_archetype(input_val)
            assert result == expected, f"normalize_archetype({input_val}) = {result}, expected {expected}"

    def test_validate_extracted_archetype(self):
        """Test that extracted archetypes are validated."""
        for arch in get_available_archetypes():
            # Every advertised archetype must normalize to one with a custom rubric.
            assert has_custom_rubric(normalize_archetype(arch)), f"{arch} should have custom rubric"

    def test_fallback_for_invalid_archetype(self):
        """Test fallback when archetype is invalid."""
        normalized = normalize_archetype("invalid-fake-archetype")

        # Unknown archetypes must still resolve to a usable (default) rubric.
        rubric = get_rubric(normalized)
        assert rubric is not None
        assert len(rubric) > 0
203
|
+
class TestArchetypeAwareScoring:
    """Test scoring with archetype-specific weights and bonuses."""

    def test_trader_scores_high_on_pnl(self, sample_reward_inputs: TrajectoryRewardInputs):
        """Traders are scored primarily on PnL: higher PnL must yield a higher score."""
        # NOTE(review): the sample_reward_inputs fixture is injected but unused here;
        # kept to preserve the test signature — confirm whether it can be dropped.
        behavior = BehaviorMetrics(
            trades_executed=5,
            total_pnl=150.0,
            win_rate=0.6,
        )

        # High PnL trader
        high_pnl_inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
            format_score=0.8,
            reasoning_score=0.75,
        )
        high_score = archetype_composite_reward(high_pnl_inputs, "trader", behavior)

        # Low PnL trader
        low_pnl_inputs = TrajectoryRewardInputs(
            final_pnl=-200.0,
            starting_balance=10000.0,
            end_balance=9800.0,
            format_score=0.8,
            reasoning_score=0.75,
        )
        low_score = archetype_composite_reward(low_pnl_inputs, "trader", behavior)

        assert high_score > low_score, "Trader with high PnL should score higher"

    def test_degen_tolerates_losses_for_activity(self):
        """Test that degens can score well despite losses if active."""
        high_activity = BehaviorMetrics(
            trades_executed=30,
            pnl_variance=500,
            avg_position_size=300,
            total_pnl=-500.0,  # Loss
        )

        low_activity = BehaviorMetrics(
            trades_executed=2,
            pnl_variance=10,
            avg_position_size=50,
            total_pnl=50.0,  # Small profit
        )

        degen_loss_inputs = TrajectoryRewardInputs(
            final_pnl=-500.0,
            starting_balance=10000.0,
            end_balance=9500.0,
            format_score=0.7,
            reasoning_score=0.6,
        )

        degen_profit_inputs = TrajectoryRewardInputs(
            final_pnl=50.0,
            starting_balance=10000.0,
            end_balance=10050.0,
            format_score=0.7,
            reasoning_score=0.6,
        )

        active_degen_score = archetype_composite_reward(degen_loss_inputs, "degen", high_activity)
        passive_degen_score = archetype_composite_reward(degen_profit_inputs, "degen", low_activity)

        # Active degen with loss should not score much lower than passive degen with profit;
        # the behavior bonus should compensate for the PnL loss.
        assert active_degen_score > 0.1, "Active degen should score reasonably despite loss"

    def test_social_butterfly_scores_on_social_metrics(self):
        """Test that social butterflies are scored on social activity."""
        high_social = BehaviorMetrics(
            posts_created=10,
            comments_made=25,
            dms_initiated=5,
            unique_users_interacted=30,
            trades_executed=2,
            total_pnl=10.0,
        )

        low_social = BehaviorMetrics(
            posts_created=0,
            comments_made=1,
            dms_initiated=0,
            unique_users_interacted=2,
            trades_executed=10,
            total_pnl=200.0,  # Better PnL
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=10.0,
            starting_balance=10000.0,
            end_balance=10010.0,
            format_score=0.7,
            reasoning_score=0.7,
        )

        high_social_score = archetype_composite_reward(inputs, "social-butterfly", high_social)
        # FIX: high_social_score was previously computed but never checked; assert it is
        # a valid composite score (consistent with test_all_archetypes_produce_valid_scores).
        assert 0.0 <= high_social_score <= 1.0

        # Social butterfly behavior bonus
        high_social_bonus = calculate_archetype_behavior_bonus("social-butterfly", high_social)
        low_social_bonus = calculate_archetype_behavior_bonus("social-butterfly", low_social)

        assert high_social_bonus > low_social_bonus, "High social activity should get higher bonus"

    def test_scammer_scores_on_profit_from_manipulation(self):
        """Test that scammers score on PnL from deceptive actions."""
        scammer_behavior = BehaviorMetrics(
            posts_created=5,
            trades_executed=3,
            total_pnl=500.0,
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
            format_score=0.75,
            reasoning_score=0.7,
        )

        score = archetype_composite_reward(inputs, "scammer", scammer_behavior)
        assert score > 0.3, "Profitable scammer should score well"

    def test_all_archetypes_produce_valid_scores(
        self,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Test that all archetypes produce valid scores in [0, 1] range."""
        for archetype, traj in all_archetype_trajectories.items():
            behavior = BehaviorMetrics(
                trades_executed=3,
                total_pnl=traj.final_pnl,
                episode_length=traj.episode_length,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj.final_pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + traj.final_pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype, behavior)

            assert 0.0 <= score <= 1.0, f"{archetype} score {score} out of [0,1] range"
355
|
+
class TestGRPOGroupFormation:
    """Test GRPO group formation from trajectories."""

    @staticmethod
    def _composite_score(traj: TrajectoryFixture) -> float:
        """Score one fixture with the fixed format/reasoning inputs both tests share."""
        behavior = BehaviorMetrics(
            trades_executed=traj.episode_length,
            total_pnl=traj.final_pnl,
        )
        inputs = TrajectoryRewardInputs(
            final_pnl=traj.final_pnl,
            starting_balance=10000.0,
            end_balance=10000.0 + traj.final_pnl,
            format_score=0.7,
            reasoning_score=0.7,
        )
        return archetype_composite_reward(inputs, traj.archetype, behavior)

    def test_group_trajectories_by_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test grouping trajectories by window ID."""
        # Write trajectories to same window
        for traj in trajectory_group:
            file_path = temp_trajectory_dir / f"{traj.trajectory_id}.json"
            file_path.write_text(json.dumps(traj.to_json_file_format()))

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        trajectories = reader.get_trajectories_by_window("window-test-1")

        assert len(trajectories) >= 2, "GRPO requires at least 2 trajectories per group"

    def test_score_centering_for_grpo(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test that scores are centered around mean for GRPO stability."""
        scores = [self._composite_score(traj) for traj in trajectory_group]

        # Center scores around the group mean, as GRPO does.
        mean_score = sum(scores) / len(scores)
        centered_scores = [s - mean_score for s in scores]

        # Check that centering works
        centered_mean = sum(centered_scores) / len(centered_scores)
        assert abs(centered_mean) < 0.01, "Centered scores should have mean ~0"

    def test_relative_ordering_preserved(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test that relative ordering is preserved after centering."""
        scores = [(traj.archetype, self._composite_score(traj)) for traj in trajectory_group]
        # FIX: removed a dead `sorted_by_score = sorted(...)` — the sorted copy was
        # never used by any assertion.

        # Verify trader (high PnL) scores higher than degen (negative PnL)
        # Note: This depends on archetype weights - trader prioritizes PnL
        trader_score = next(s for a, s in scores if a == "trader")
        degen_score = next(s for a, s in scores if a == "degen")

        # Trader should score higher due to positive PnL
        assert trader_score > degen_score, "Trader with profit should beat degen with loss"
|
442
|
+
class TestEndToEndJsonPipeline:
    """Test complete end-to-end JSON mode pipeline."""

    def test_full_pipeline_single_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test full pipeline from JSON files to scores."""
        # Step 1: persist every fixture in the group.
        for fixture in trajectory_group:
            target = temp_trajectory_dir / f"{fixture.trajectory_id}.json"
            target.write_text(json.dumps(fixture.to_json_file_format()))

        # Step 2: load them back through the reader.
        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        assert "window-test-1" in reader.get_window_ids()

        trajectories = reader.get_trajectories_by_window("window-test-1")
        assert len(trajectories) == 3

        # Step 3: every loaded trajectory must carry valid LLM calls.
        # (JsonTrajectoryReader returns the trajectory payload directly.)
        for payload in trajectories:
            raw_steps = payload.get("stepsJson", "[]")
            steps = json.loads(raw_steps) if isinstance(raw_steps, str) else raw_steps
            is_valid, issues = validate_llm_calls(steps)
            assert is_valid, f"Invalid LLM calls: {issues}"

        # Step 4: extract archetypes and score each trajectory.
        scores = []
        for payload in trajectories:
            archetype_norm = normalize_archetype(payload.get("archetype", "default"))

            steps = json.loads(payload.get("stepsJson", "[]"))

            behavior = BehaviorMetrics(
                # Count every step whose action is anything other than "hold".
                trades_executed=sum(1 for s in steps if s.get("action", {}).get("actionType", "") != "hold"),
                total_pnl=payload.get("finalPnL", 0.0),
                episode_length=payload.get("episodeLength", len(steps)),
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=payload.get("finalPnL", 0.0),
                starting_balance=10000.0,
                end_balance=10000.0 + payload.get("finalPnL", 0.0),
                format_score=0.7,
                reasoning_score=0.7,
            )

            scores.append((archetype_norm, archetype_composite_reward(inputs, archetype_norm, behavior)))

        # Step 5: center scores for GRPO and confirm a ~0 mean.
        mean_score = sum(s for _, s in scores) / len(scores)
        centered_scores = [(a, s - mean_score) for a, s in scores]

        assert len(centered_scores) == 3
        centered_mean = sum(s for _, s in centered_scores) / len(centered_scores)
        assert abs(centered_mean) < 0.01

    def test_pipeline_with_all_archetypes(
        self,
        temp_trajectory_dir: Path,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Test pipeline with all archetype types."""
        # Persist one fixture per archetype.
        for fixture in all_archetype_trajectories.values():
            target = temp_trajectory_dir / f"{fixture.trajectory_id}.json"
            target.write_text(json.dumps(fixture.to_json_file_format()))

        # Load and score them all.
        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        trajectories = reader.get_trajectories_by_window("window-all-archetypes")

        assert len(trajectories) == len(get_available_archetypes())

        archetype_scores: Dict[str, float] = {}
        for payload in trajectories:
            archetype = normalize_archetype(payload.get("archetype", "default"))

            behavior = BehaviorMetrics(
                trades_executed=3,
                total_pnl=payload.get("finalPnL", 0.0),
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=payload.get("finalPnL", 0.0),
                starting_balance=10000.0,
                end_balance=10000.0 + payload.get("finalPnL", 0.0),
                format_score=0.7,
                reasoning_score=0.7,
            )

            archetype_scores[archetype] = archetype_composite_reward(inputs, archetype, behavior)

        # Every archetype must have been scored...
        assert len(archetype_scores) == len(get_available_archetypes())

        # ...and every score must be valid.
        for arch, score in archetype_scores.items():
            assert 0.0 <= score <= 1.0, f"{arch} has invalid score: {score}"