@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,554 @@
1
+ """
2
+ JSON Mode Integration Tests
3
+
4
+ Tests the complete JSON-based training data pipeline:
5
+ 1. Trajectory loading from JSON files
6
+ 2. Archetype extraction from step parameters
7
+ 3. Scoring with archetype-aware rewards
8
+ 4. GRPO group formation
9
+ 5. End-to-end scoring pipeline
10
+
11
+ These tests run WITHOUT a database, using only local JSON files.
12
+ """
13
+
14
+ import json
15
+ import sys
16
+ from pathlib import Path
17
+ from typing import Dict, List
18
+
19
+ import pytest
20
+
21
+ # Add src to path
22
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
23
+
24
+ from src.data_bridge.reader import JsonTrajectoryReader, validate_llm_calls
25
+ from src.training.rewards import (
26
+ archetype_composite_reward,
27
+ calculate_archetype_behavior_bonus,
28
+ BehaviorMetrics,
29
+ TrajectoryRewardInputs,
30
+ )
31
+ from src.training.rubric_loader import (
32
+ normalize_archetype,
33
+ has_custom_rubric,
34
+ get_rubric,
35
+ get_available_archetypes,
36
+ )
37
+ from tests.integration.conftest import (
38
+ TrajectoryFixture,
39
+ create_trading_step,
40
+ )
41
+
42
+
43
class TestJsonTrajectoryLoading:
    """Exercise loading trajectories from JSON files on disk."""

    def test_load_single_trajectory_file(
        self,
        temp_trajectory_dir: Path,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """A single trajectory file is discoverable and readable by window id."""
        # Persist the fixture as one JSON file named after its trajectory id.
        out_file = temp_trajectory_dir / (sample_trader_trajectory.trajectory_id + ".json")
        out_file.write_text(json.dumps(sample_trader_trajectory.to_json_file_format()))

        loader = JsonTrajectoryReader(str(temp_trajectory_dir))
        windows = loader.get_window_ids()

        assert len(windows) >= 1
        assert sample_trader_trajectory.window_id in windows

        loaded = loader.get_trajectories_by_window(sample_trader_trajectory.window_id)
        assert len(loaded) == 1

        # JsonTrajectoryReader hands back the trajectory payload directly (unwrapped).
        assert loaded[0].get("archetype") == "trader"

    def test_load_multiple_trajectories_same_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Several trajectories sharing a window id all come back together."""
        for fixture in trajectory_group:
            target = temp_trajectory_dir / (fixture.trajectory_id + ".json")
            target.write_text(json.dumps(fixture.to_json_file_format()))

        loader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = loader.get_trajectories_by_window("window-test-1")

        assert len(loaded) == 3
        # Trajectory payloads come back directly from the reader.
        assert {item.get("archetype") for item in loaded} == {"trader", "degen", "scammer"}

    def test_validate_llm_calls_in_loaded_trajectory(
        self,
        temp_trajectory_dir: Path,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Steps round-tripped through a JSON file still pass LLM-call validation."""
        target = temp_trajectory_dir / (sample_trader_trajectory.trajectory_id + ".json")
        target.write_text(json.dumps(sample_trader_trajectory.to_json_file_format()))

        loader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = loader.get_trajectories_by_window(sample_trader_trajectory.window_id)

        # stepsJson may arrive as a serialized string or an already-parsed list.
        raw_steps = loaded[0].get("stepsJson", "[]")
        if isinstance(raw_steps, str):
            steps = json.loads(raw_steps)
        else:
            steps = raw_steps

        is_valid, issues = validate_llm_calls(steps)
        assert is_valid, f"LLM calls should be valid: {issues}"

    def test_load_all_archetypes(
        self,
        temp_trajectory_dir: Path,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Every supported archetype survives a write/read round trip."""
        for fixture in all_archetype_trajectories.values():
            target = temp_trajectory_dir / (fixture.trajectory_id + ".json")
            target.write_text(json.dumps(fixture.to_json_file_format()))

        loader = JsonTrajectoryReader(str(temp_trajectory_dir))
        loaded = loader.get_trajectories_by_window("window-all-archetypes")

        seen = {item.get("archetype") for item in loaded}
        assert seen == set(get_available_archetypes())

    def test_empty_directory_returns_empty_list(self, temp_trajectory_dir: Path):
        """An empty trajectory directory yields no window ids."""
        loader = JsonTrajectoryReader(str(temp_trajectory_dir))
        assert loader.get_window_ids() == []
134
+
135
+
136
class TestArchetypeExtractionFromSteps:
    """Verify archetype extraction from step action parameters."""

    def test_extract_archetype_from_action_parameters(self):
        """The archetype embedded in a step's action parameters is recoverable."""
        step = create_trading_step(0, "buy_prediction", "trader")

        # Mirrors the extraction logic in babylon_env.py.
        extracted = step.get("action", {}).get("parameters", {}).get("archetype")
        assert extracted == "trader"

    def test_extract_archetype_from_first_step(self):
        """With no trajectory-level archetype, the first step carrying one wins."""
        steps = [
            create_trading_step(0, "buy_prediction", "degen"),
            create_trading_step(1, "hold", "degen"),
            create_trading_step(2, "sell_prediction", "degen"),
        ]

        # First truthy archetype found while scanning the steps in order.
        found = next(
            (
                step.get("action", {}).get("parameters", {}).get("archetype")
                for step in steps
                if step.get("action", {}).get("parameters", {}).get("archetype")
            ),
            None,
        )
        assert found == "degen"

    def test_normalize_extracted_archetype(self):
        """Raw archetype strings normalize to canonical kebab-case names."""
        cases = {
            "TRADER": "trader",
            "Social_Butterfly": "social-butterfly",
            "goody_twoshoes": "goody-twoshoes",
            "DEGEN": "degen",
            "": "default",
            None: "default",
        }

        for raw, expected in cases.items():
            got = normalize_archetype(raw)
            assert got == expected, f"normalize_archetype({raw}) = {got}, expected {expected}"

    def test_validate_extracted_archetype(self):
        """Every advertised archetype normalizes to one with a custom rubric."""
        for arch in get_available_archetypes():
            assert has_custom_rubric(normalize_archetype(arch)), f"{arch} should have custom rubric"

    def test_fallback_for_invalid_archetype(self):
        """Unknown archetypes fall back to a non-empty default rubric."""
        normalized = normalize_archetype("invalid-fake-archetype")

        # Should get default rubric for invalid archetypes.
        rubric = get_rubric(normalized)
        assert rubric is not None
        assert len(rubric) > 0
201
+
202
+
203
class TestArchetypeAwareScoring:
    """Test scoring with archetype-specific weights and bonuses."""

    def test_trader_scores_high_on_pnl(self, sample_reward_inputs: TrajectoryRewardInputs):
        """Traders are scored primarily on PnL: profit must beat loss."""
        behavior = BehaviorMetrics(
            trades_executed=5,
            total_pnl=150.0,
            win_rate=0.6,
        )

        # Identical format/reasoning scores so only PnL differentiates the two.
        high_pnl_inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
            format_score=0.8,
            reasoning_score=0.75,
        )
        high_score = archetype_composite_reward(high_pnl_inputs, "trader", behavior)

        # Low PnL trader.
        low_pnl_inputs = TrajectoryRewardInputs(
            final_pnl=-200.0,
            starting_balance=10000.0,
            end_balance=9800.0,
            format_score=0.8,
            reasoning_score=0.75,
        )
        low_score = archetype_composite_reward(low_pnl_inputs, "trader", behavior)

        assert high_score > low_score, "Trader with high PnL should score higher"

    def test_degen_tolerates_losses_for_activity(self):
        """Degens can score reasonably despite losses if highly active."""
        high_activity = BehaviorMetrics(
            trades_executed=30,
            pnl_variance=500,
            avg_position_size=300,
            total_pnl=-500.0,  # Loss
        )

        low_activity = BehaviorMetrics(
            trades_executed=2,
            pnl_variance=10,
            avg_position_size=50,
            total_pnl=50.0,  # Small profit
        )

        degen_loss_inputs = TrajectoryRewardInputs(
            final_pnl=-500.0,
            starting_balance=10000.0,
            end_balance=9500.0,
            format_score=0.7,
            reasoning_score=0.6,
        )

        degen_profit_inputs = TrajectoryRewardInputs(
            final_pnl=50.0,
            starting_balance=10000.0,
            end_balance=10050.0,
            format_score=0.7,
            reasoning_score=0.6,
        )

        active_degen_score = archetype_composite_reward(degen_loss_inputs, "degen", high_activity)
        passive_degen_score = archetype_composite_reward(degen_profit_inputs, "degen", low_activity)

        # The behavior bonus should compensate for the PnL loss.
        assert active_degen_score > 0.1, "Active degen should score reasonably despite loss"
        # FIX: passive_degen_score used to be computed but never checked
        # (dead local). At minimum it must be a valid reward in [0, 1]
        # (the same invariant test_all_archetypes_produce_valid_scores asserts).
        assert 0.0 <= passive_degen_score <= 1.0, "Passive degen score out of [0,1] range"

    def test_social_butterfly_scores_on_social_metrics(self):
        """Social butterflies earn their behavior bonus from social activity."""
        high_social = BehaviorMetrics(
            posts_created=10,
            comments_made=25,
            dms_initiated=5,
            unique_users_interacted=30,
            trades_executed=2,
            total_pnl=10.0,
        )

        low_social = BehaviorMetrics(
            posts_created=0,
            comments_made=1,
            dms_initiated=0,
            unique_users_interacted=2,
            trades_executed=10,
            total_pnl=200.0,  # Better PnL
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=10.0,
            starting_balance=10000.0,
            end_balance=10010.0,
            format_score=0.7,
            reasoning_score=0.7,
        )

        high_social_score = archetype_composite_reward(inputs, "social-butterfly", high_social)
        # FIX: high_social_score used to be computed but never checked
        # (dead local). Assert it is at least a valid reward in [0, 1].
        assert 0.0 <= high_social_score <= 1.0, "Social butterfly score out of [0,1] range"

        # Social butterfly behavior bonus.
        high_social_bonus = calculate_archetype_behavior_bonus("social-butterfly", high_social)
        low_social_bonus = calculate_archetype_behavior_bonus("social-butterfly", low_social)

        assert high_social_bonus > low_social_bonus, "High social activity should get higher bonus"

    def test_scammer_scores_on_profit_from_manipulation(self):
        """Profitable scammers score well on PnL from deceptive actions."""
        scammer_behavior = BehaviorMetrics(
            posts_created=5,
            trades_executed=3,
            total_pnl=500.0,
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=500.0,
            starting_balance=10000.0,
            end_balance=10500.0,
            format_score=0.75,
            reasoning_score=0.7,
        )

        score = archetype_composite_reward(inputs, "scammer", scammer_behavior)
        assert score > 0.3, "Profitable scammer should score well"

    def test_all_archetypes_produce_valid_scores(
        self,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Every archetype's composite reward lands in the [0, 1] range."""
        for archetype, traj in all_archetype_trajectories.items():
            behavior = BehaviorMetrics(
                trades_executed=3,
                total_pnl=traj.final_pnl,
                episode_length=traj.episode_length,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj.final_pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + traj.final_pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype, behavior)

            assert 0.0 <= score <= 1.0, f"{archetype} score {score} out of [0,1] range"
353
+
354
+
355
class TestGRPOGroupFormation:
    """Test GRPO group formation from trajectories."""

    def test_group_trajectories_by_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Trajectories sharing a window id form one GRPO group."""
        # Write trajectories to same window.
        for traj in trajectory_group:
            file_path = temp_trajectory_dir / f"{traj.trajectory_id}.json"
            file_path.write_text(json.dumps(traj.to_json_file_format()))

        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        trajectories = reader.get_trajectories_by_window("window-test-1")

        assert len(trajectories) >= 2, "GRPO requires at least 2 trajectories per group"

    def test_score_centering_for_grpo(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Scores centered on the group mean have mean ~0 (GRPO stability)."""
        scores = []

        for traj in trajectory_group:
            behavior = BehaviorMetrics(
                trades_executed=traj.episode_length,
                total_pnl=traj.final_pnl,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj.final_pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + traj.final_pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            scores.append(archetype_composite_reward(inputs, traj.archetype, behavior))

        # Center scores around the group mean.
        mean_score = sum(scores) / len(scores)
        centered_scores = [s - mean_score for s in scores]

        # Check that centering works.
        centered_mean = sum(centered_scores) / len(centered_scores)
        assert abs(centered_mean) < 0.01, "Centered scores should have mean ~0"

    def test_relative_ordering_preserved(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Mean-centering is monotonic, so the raw score ordering is preserved."""
        scores = []

        for traj in trajectory_group:
            behavior = BehaviorMetrics(
                trades_executed=traj.episode_length,
                total_pnl=traj.final_pnl,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj.final_pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + traj.final_pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, traj.archetype, behavior)
            scores.append((traj.archetype, score))

        # FIX: removed an intermediate sorted() list that was computed but
        # never used (dead local); the ordering claim is checked directly.
        # Verify trader (high PnL) scores higher than degen (negative PnL).
        # Note: This depends on archetype weights - trader prioritizes PnL.
        trader_score = next(s for a, s in scores if a == "trader")
        degen_score = next(s for a, s in scores if a == "degen")

        # Trader should score higher due to positive PnL.
        assert trader_score > degen_score, "Trader with profit should beat degen with loss"
440
+
441
+
442
class TestEndToEndJsonPipeline:
    """Test complete end-to-end JSON mode pipeline."""

    def test_full_pipeline_single_window(
        self,
        temp_trajectory_dir: Path,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Full pipeline: write files -> load -> validate -> score -> center."""
        # Step 1: Write trajectories to files.
        for traj in trajectory_group:
            file_path = temp_trajectory_dir / f"{traj.trajectory_id}.json"
            file_path.write_text(json.dumps(traj.to_json_file_format()))

        # Step 2: Load trajectories.
        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        window_ids = reader.get_window_ids()
        assert "window-test-1" in window_ids

        trajectories = reader.get_trajectories_by_window("window-test-1")
        assert len(trajectories) == 3

        # Step 3: Validate LLM calls. stepsJson may be a serialized string or
        # an already-parsed list, so guard the json.loads call.
        for traj_data in trajectories:
            steps_json = traj_data.get("stepsJson", "[]")
            steps = json.loads(steps_json) if isinstance(steps_json, str) else steps_json
            is_valid, issues = validate_llm_calls(steps)
            assert is_valid, f"Invalid LLM calls: {issues}"

        # Step 4: Extract archetypes and score.
        scores = []
        for traj_data in trajectories:
            archetype_norm = normalize_archetype(traj_data.get("archetype", "default"))

            # FIX: use the same isinstance guard as Step 3 — the previous bare
            # json.loads(...) raised TypeError when stepsJson was already a list.
            steps_json = traj_data.get("stepsJson", "[]")
            steps = json.loads(steps_json) if isinstance(steps_json, str) else steps_json

            behavior = BehaviorMetrics(
                trades_executed=len(
                    [s for s in steps if s.get("action", {}).get("actionType", "") != "hold"]
                ),
                total_pnl=traj_data.get("finalPnL", 0.0),
                episode_length=traj_data.get("episodeLength", len(steps)),
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj_data.get("finalPnL", 0.0),
                starting_balance=10000.0,
                end_balance=10000.0 + traj_data.get("finalPnL", 0.0),
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype_norm, behavior)
            scores.append((archetype_norm, score))

        # Step 5: Center scores for GRPO.
        mean_score = sum(s for _, s in scores) / len(scores)
        centered_scores = [(a, s - mean_score) for a, s in scores]

        # Verify results.
        assert len(centered_scores) == 3
        centered_mean = sum(s for _, s in centered_scores) / len(centered_scores)
        assert abs(centered_mean) < 0.01

    def test_pipeline_with_all_archetypes(
        self,
        temp_trajectory_dir: Path,
        all_archetype_trajectories: Dict[str, TrajectoryFixture],
    ):
        """Pipeline handles every archetype and yields valid [0, 1] scores."""
        # Write all trajectories.
        for archetype, traj in all_archetype_trajectories.items():
            file_path = temp_trajectory_dir / f"{traj.trajectory_id}.json"
            file_path.write_text(json.dumps(traj.to_json_file_format()))

        # Load and score all.
        reader = JsonTrajectoryReader(str(temp_trajectory_dir))
        trajectories = reader.get_trajectories_by_window("window-all-archetypes")

        assert len(trajectories) == len(get_available_archetypes())

        # Score each trajectory (the reader returns payload dicts directly).
        archetype_scores: Dict[str, float] = {}
        for traj_data in trajectories:
            archetype = normalize_archetype(traj_data.get("archetype", "default"))

            behavior = BehaviorMetrics(
                trades_executed=3,
                total_pnl=traj_data.get("finalPnL", 0.0),
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=traj_data.get("finalPnL", 0.0),
                starting_balance=10000.0,
                end_balance=10000.0 + traj_data.get("finalPnL", 0.0),
                format_score=0.7,
                reasoning_score=0.7,
            )

            archetype_scores[archetype] = archetype_composite_reward(inputs, archetype, behavior)

        # Verify all archetypes got scored.
        assert len(archetype_scores) == len(get_available_archetypes())

        # All scores should be valid.
        for arch, score in archetype_scores.items():
            assert 0.0 <= score <= 1.0, f"{arch} has invalid score: {score}"
554
+