@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,649 @@
1
+ """
2
+ Database Integration Tests
3
+
4
+ Tests the complete database-based training data pipeline:
5
+ 1. Trajectory insertion with archetype
6
+ 2. Querying trajectories by archetype
7
+ 3. Scoring database trajectories
8
+ 4. GRPO group formation from database
9
+ 5. End-to-end database pipeline
10
+
11
+ These tests REQUIRE a running PostgreSQL database.
12
+ Use docker compose -f docker-compose.test.yml up -d before running.
13
+ """
14
+
15
+ import json
16
+ import os
17
+ import sys
18
+ from pathlib import Path
19
+ from typing import Dict, List
20
+ from datetime import datetime
21
+ import uuid
22
+
23
+ import pytest
24
+
25
+ # Add src to path
26
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
27
+
28
+ from src.training.rewards import (
29
+ archetype_composite_reward,
30
+ BehaviorMetrics,
31
+ TrajectoryRewardInputs,
32
+ )
33
+ from src.training.rubric_loader import (
34
+ normalize_archetype,
35
+ has_custom_rubric,
36
+ get_available_archetypes,
37
+ )
38
+ from tests.integration.conftest import (
39
+ TrajectoryFixture,
40
+ skip_if_no_database,
41
+ is_database_available,
42
+ )
43
+
44
+
45
+ # Skip all tests in this module if database is not available
46
+ pytestmark = skip_if_no_database()
47
+
48
+
49
def generate_test_id() -> str:
    """Return a short random identifier for namespacing test data.

    The value is the literal ``test-`` followed by the first eight hex
    digits of a freshly generated UUID4, which makes collisions between
    test runs unlikely.
    """
    random_suffix = uuid.uuid4().hex[:8]
    return "test-" + random_suffix
52
+
53
+
54
class TestDatabaseTrajectoryOperations:
    """Exercise insert and query operations against the ``trajectories`` table.

    Every row a test creates carries a per-test random prefix in its
    ``trajectoryId`` so the fixture teardown can delete exactly those rows.
    """

    # Parameterized INSERT shared by every test in this class. Values are
    # always bound by the driver, never interpolated into the SQL text.
    _INSERT_SQL = '''
        INSERT INTO trajectories (
            "id", "trajectoryId", "agentId", "archetype", "windowId",
            "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
            "finalPnL", "episodeLength", "totalReward",
            "finalStatus", "isTrainingData", "startTime", "endTime",
            "durationMs", "createdAt", "updatedAt"
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
    '''

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a connection for the test and delete its rows afterwards."""
        import psycopg2

        self.conn = psycopg2.connect(database_url)
        self.test_prefix = generate_test_id()
        yield
        try:
            # A test that failed mid-statement leaves the transaction
            # aborted; roll back first so the cleanup DELETE can still run.
            self.conn.rollback()
            with self.conn.cursor() as cur:
                cur.execute(
                    'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                    (f"{self.test_prefix}%",),
                )
            self.conn.commit()
        finally:
            # Always release the connection, even if the cleanup itself fails.
            self.conn.close()

    def _insert_trajectory(
        self,
        trajectory_id: str,
        agent_id: str,
        archetype: "str | None",
        window_id: str,
        steps: List[Dict],
        final_pnl: float,
        episode_length: int,
    ) -> None:
        """Insert one completed training trajectory and commit it.

        Args:
            trajectory_id: Stored in both the ``id`` and ``trajectoryId`` columns.
            agent_id: Value for the ``agentId`` column.
            archetype: Value for the ``archetype`` column; ``None`` is bound
                by the driver as SQL ``NULL``.
            window_id: Value for the ``windowId`` column.
            steps: Step dictionaries, serialized with ``json.dumps``.
            final_pnl: Value for the ``finalPnL`` column.
            episode_length: Value for the ``episodeLength`` column.

        The remaining columns receive fixed placeholder values: empty JSON
        objects, ``totalReward`` 0.5, status ``"completed"``,
        ``isTrainingData`` true, ``durationMs`` 5000 and the current time
        for every timestamp.
        """
        now = datetime.now()
        params = (
            trajectory_id,
            trajectory_id,
            agent_id,
            archetype,
            window_id,
            json.dumps(steps),
            "{}",  # rewardComponentsJson
            "{}",  # metricsJson
            "{}",  # metadataJson
            final_pnl,
            episode_length,
            0.5,
            "completed",
            True,
            now,
            now,
            5000,
            now,
            now,
        )
        # The context manager closes the cursor even if execute() raises.
        with self.conn.cursor() as cur:
            cur.execute(self._INSERT_SQL, params)
        self.conn.commit()

    def test_insert_trajectory_with_archetype(
        self,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test inserting a trajectory with archetype."""
        traj_id = f"{self.test_prefix}-trader-001"

        self._insert_trajectory(
            trajectory_id=traj_id,
            agent_id="test-agent",
            archetype="trader",
            window_id=f"{self.test_prefix}-window-1",
            steps=sample_trader_trajectory.steps,
            final_pnl=sample_trader_trajectory.final_pnl,
            episode_length=sample_trader_trajectory.episode_length,
        )

        # Verify insertion
        with self.conn.cursor() as cur:
            cur.execute(
                'SELECT "archetype" FROM trajectories WHERE "trajectoryId" = %s',
                (traj_id,),
            )
            row = cur.fetchone()

        assert row is not None
        assert row[0] == "trader"

    def test_query_trajectories_by_archetype(self):
        """Test querying trajectories filtered by archetype."""
        window_id = f"{self.test_prefix}-window-multi"

        # Insert multiple archetypes
        for i, archetype in enumerate(["trader", "degen", "scammer"]):
            self._insert_trajectory(
                trajectory_id=f"{self.test_prefix}-{archetype}-{i}",
                agent_id=f"agent-{archetype}",
                archetype=archetype,
                window_id=window_id,
                steps=[],
                final_pnl=100.0 * (i + 1),
                episode_length=3,
            )

        # Query traders only
        with self.conn.cursor() as cur:
            cur.execute(
                '''
                SELECT "trajectoryId", "archetype" FROM trajectories
                WHERE "windowId" = %s AND "archetype" = %s
                ''',
                (window_id, "trader"),
            )
            rows = cur.fetchall()

        assert len(rows) == 1
        assert rows[0][1] == "trader"

    def test_query_trajectories_by_window_with_archetypes(self):
        """Test querying all trajectories in window with their archetypes."""
        window_id = f"{self.test_prefix}-window-group"
        archetypes = ["trader", "degen", "scammer", "social-butterfly"]

        # Insert multiple archetypes
        for i, archetype in enumerate(archetypes):
            self._insert_trajectory(
                trajectory_id=f"{self.test_prefix}-group-{i}",
                agent_id=f"agent-{archetype}",
                archetype=archetype,
                window_id=window_id,
                steps=[],
                final_pnl=100.0 * (i + 1),
                episode_length=3,
            )

        # Query all in window
        with self.conn.cursor() as cur:
            cur.execute(
                '''
                SELECT "trajectoryId", "archetype", "finalPnL" FROM trajectories
                WHERE "windowId" = %s AND "isTrainingData" = true
                ORDER BY "archetype"
                ''',
                (window_id,),
            )
            rows = cur.fetchall()

        assert len(rows) == 4
        db_archetypes = {row[1] for row in rows}
        assert db_archetypes == set(archetypes)

    def test_null_archetype_defaults_to_default(self):
        """Test that NULL archetype is handled correctly."""
        traj_id = f"{self.test_prefix}-null-arch"

        # Insert with NULL archetype (None is bound as SQL NULL)
        self._insert_trajectory(
            trajectory_id=traj_id,
            agent_id="test-agent",
            archetype=None,
            window_id=f"{self.test_prefix}-window-null",
            steps=[],
            final_pnl=100.0,
            episode_length=3,
        )

        # Query and normalize
        with self.conn.cursor() as cur:
            cur.execute(
                'SELECT "archetype" FROM trajectories WHERE "trajectoryId" = %s',
                (traj_id,),
            )
            row = cur.fetchone()

        # Fail with a clear message instead of a TypeError on row[0].
        assert row is not None, "inserted trajectory was not found"
        assert row[0] is None
        assert normalize_archetype(row[0]) == "default"

    def test_archetype_case_normalization_in_scoring(self):
        """Test that archetype case variations are normalized correctly."""
        test_cases = [
            ("TRADER", "trader"),
            ("Social_Butterfly", "social-butterfly"),
            ("goody_twoshoes", "goody-twoshoes"),
        ]

        for db_value, expected_normalized in test_cases:
            normalized = normalize_archetype(db_value)
            assert normalized == expected_normalized
            assert has_custom_rubric(normalized)
269
+
270
+
271
class TestDatabaseScoring:
    """Score trajectories that were written to and read back from the database.

    Every row a test creates carries a per-test random prefix in its
    ``trajectoryId`` so the fixture teardown can delete exactly those rows.
    """

    # Parameterized INSERT shared by every test in this class. It returns
    # the columns the scoring tests read back.
    _INSERT_SQL = '''
        INSERT INTO trajectories (
            "id", "trajectoryId", "agentId", "archetype", "windowId",
            "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
            "finalPnL", "episodeLength", "totalReward",
            "finalStatus", "isTrainingData", "startTime", "endTime",
            "durationMs", "createdAt", "updatedAt"
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
        RETURNING "trajectoryId", "archetype", "stepsJson", "finalPnL", "episodeLength"
    '''

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a connection for the test and delete its rows afterwards."""
        import psycopg2

        self.conn = psycopg2.connect(database_url)
        self.test_prefix = generate_test_id()
        yield
        try:
            # A test that failed mid-statement leaves the transaction
            # aborted; roll back first so the cleanup DELETE can still run.
            self.conn.rollback()
            with self.conn.cursor() as cur:
                cur.execute(
                    'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                    (f"{self.test_prefix}%",),
                )
            self.conn.commit()
        finally:
            # Always release the connection, even if the cleanup itself fails.
            self.conn.close()

    def _insert_row(
        self,
        trajectory_id: str,
        agent_id: str,
        archetype: str,
        window_id: str,
        steps: List[Dict],
        final_pnl: float,
        episode_length: int,
    ):
        """Insert one completed training trajectory, commit, and return the row.

        ``trajectory_id`` is stored in both ``id`` and ``trajectoryId``.
        The remaining columns receive fixed placeholder values: empty JSON
        objects, ``totalReward`` 0.5, status ``"completed"``,
        ``isTrainingData`` true, ``durationMs`` 5000 and the current time
        for every timestamp.

        Returns:
            The ``RETURNING`` tuple: trajectoryId, archetype, stepsJson,
            finalPnL, episodeLength.
        """
        now = datetime.now()
        params = (
            trajectory_id, trajectory_id, agent_id, archetype, window_id,
            json.dumps(steps), "{}", "{}", "{}", final_pnl, episode_length,
            0.5, "completed", True,
            now, now, 5000, now, now,
        )
        # The context manager closes the cursor even if execute() raises.
        with self.conn.cursor() as cur:
            cur.execute(self._INSERT_SQL, params)
            row = cur.fetchone()
        self.conn.commit()
        return row

    def _insert_and_fetch_trajectory(
        self,
        archetype: str,
        final_pnl: float,
        steps: List[Dict],
    ) -> Dict:
        """Insert a trajectory and return the stored values as a dictionary.

        The trajectory id gets a random four-hex-digit suffix, the window
        is the shared ``<prefix>-scoring-window`` and ``episodeLength`` is
        ``len(steps)``.

        Returns:
            Dict with keys ``trajectory_id``, ``archetype``, ``steps``
            (decoded from JSON), ``final_pnl`` (as float) and
            ``episode_length``.
        """
        traj_id = f"{self.test_prefix}-{archetype}-{uuid.uuid4().hex[:4]}"
        window_id = f"{self.test_prefix}-scoring-window"

        row = self._insert_row(
            trajectory_id=traj_id,
            agent_id=f"agent-{archetype}",
            archetype=archetype,
            window_id=window_id,
            steps=steps,
            final_pnl=final_pnl,
            episode_length=len(steps),
        )

        return {
            "trajectory_id": row[0],
            "archetype": row[1],
            "steps": json.loads(row[2]),
            "final_pnl": float(row[3]),
            "episode_length": row[4],
        }

    def test_score_trader_from_database(
        self,
        sample_trader_trajectory: TrajectoryFixture,
    ):
        """Test scoring a trader trajectory from database."""
        traj_data = self._insert_and_fetch_trajectory(
            archetype="trader",
            final_pnl=150.0,
            steps=sample_trader_trajectory.steps,
        )

        behavior = BehaviorMetrics(
            trades_executed=3,
            total_pnl=traj_data["final_pnl"],
            episode_length=traj_data["episode_length"],
        )

        inputs = TrajectoryRewardInputs(
            final_pnl=traj_data["final_pnl"],
            starting_balance=10000.0,
            end_balance=10000.0 + traj_data["final_pnl"],
            format_score=0.8,
            reasoning_score=0.75,
        )

        score = archetype_composite_reward(
            inputs,
            normalize_archetype(traj_data["archetype"]),
            behavior,
        )

        assert 0.0 <= score <= 1.0
        assert score > 0.3  # Profitable trader should score reasonably

    def test_score_multiple_archetypes_from_database(self):
        """Test scoring multiple archetype trajectories from database."""
        archetypes_pnl = [
            ("trader", 200.0),
            ("degen", -500.0),
            ("scammer", 300.0),
            ("social-butterfly", 20.0),
        ]

        scores = {}
        for archetype, pnl in archetypes_pnl:
            # Only the database round trip matters here; the returned
            # data is not needed for scoring.
            self._insert_and_fetch_trajectory(
                archetype=archetype,
                final_pnl=pnl,
                steps=[],
            )

            behavior = BehaviorMetrics(
                trades_executed=5,
                total_pnl=pnl,
                episode_length=5,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            scores[archetype] = archetype_composite_reward(
                inputs,
                normalize_archetype(archetype),
                behavior,
            )

        # All scores should be valid
        for archetype, score in scores.items():
            assert 0.0 <= score <= 1.0, f"{archetype} has invalid score: {score}"

    def test_grpo_group_formation_from_database(self):
        """Test forming GRPO groups from database trajectories."""
        window_id = f"{self.test_prefix}-grpo-window"
        archetypes = ["trader", "degen", "scammer"]

        # Insert multiple trajectories to same window
        for i, archetype in enumerate(archetypes):
            self._insert_row(
                trajectory_id=f"{self.test_prefix}-grpo-{i}",
                agent_id=f"agent-{i}",
                archetype=archetype,
                window_id=window_id,
                steps=[],
                final_pnl=100.0 * (i + 1),
                episode_length=3,
            )

        # Query group
        with self.conn.cursor() as cur:
            cur.execute(
                '''
                SELECT "trajectoryId", "archetype", "finalPnL", "stepsJson", "episodeLength"
                FROM trajectories
                WHERE "windowId" = %s AND "isTrainingData" = true
                ''',
                (window_id,),
            )
            rows = cur.fetchall()

        assert len(rows) >= 2  # GRPO requires at least 2

        # Score all trajectories
        scores = []
        for row in rows:
            archetype = normalize_archetype(row[1])
            pnl = float(row[2])

            behavior = BehaviorMetrics(trades_executed=3, total_pnl=pnl)
            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            scores.append(archetype_composite_reward(inputs, archetype, behavior))

        # Center for GRPO
        mean_score = sum(scores) / len(scores)
        centered = [s - mean_score for s in scores]
        centered_mean = sum(centered) / len(centered)

        assert abs(centered_mean) < 0.01
477
+
478
+
479
class TestEndToEndDatabasePipeline:
    """Run the insert -> query -> score -> center pipeline against the database.

    Every row a test creates carries a per-test random prefix in its
    ``trajectoryId`` so the fixture teardown can delete exactly those rows.
    """

    # Parameterized INSERT shared by every test in this class.
    _INSERT_SQL = '''
        INSERT INTO trajectories (
            "id", "trajectoryId", "agentId", "archetype", "windowId",
            "stepsJson", "rewardComponentsJson", "metricsJson", "metadataJson",
            "finalPnL", "episodeLength", "totalReward",
            "finalStatus", "isTrainingData", "startTime", "endTime",
            "durationMs", "createdAt", "updatedAt"
        ) VALUES (
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
    '''

    @pytest.fixture(autouse=True)
    def setup_db(self, database_url: str):
        """Open a connection for the test and delete its rows afterwards."""
        import psycopg2

        self.conn = psycopg2.connect(database_url)
        self.test_prefix = generate_test_id()
        yield
        try:
            # A test that failed mid-statement leaves the transaction
            # aborted; roll back first so the cleanup DELETE can still run.
            self.conn.rollback()
            with self.conn.cursor() as cur:
                cur.execute(
                    'DELETE FROM trajectories WHERE "trajectoryId" LIKE %s',
                    (f"{self.test_prefix}%",),
                )
            self.conn.commit()
        finally:
            # Always release the connection, even if the cleanup itself fails.
            self.conn.close()

    def _insert_trajectory(
        self,
        cur,
        trajectory_id: str,
        agent_id: str,
        archetype: str,
        window_id: str,
        steps: List[Dict],
        final_pnl: float,
        episode_length: int,
        total_reward: float = 0.5,
    ) -> None:
        """Execute the shared INSERT on ``cur`` without committing.

        ``trajectory_id`` is stored in both ``id`` and ``trajectoryId`` and
        ``steps`` is serialized with ``json.dumps``. The remaining columns
        receive fixed placeholder values: empty JSON objects, status
        ``"completed"``, ``isTrainingData`` true, ``durationMs`` 5000 and
        the current time for every timestamp. The caller commits once all
        rows are inserted.
        """
        now = datetime.now()
        cur.execute(
            self._INSERT_SQL,
            (
                trajectory_id, trajectory_id, agent_id, archetype, window_id,
                json.dumps(steps), "{}", "{}", "{}", final_pnl, episode_length,
                total_reward, "completed", True,
                now, now, 5000, now, now,
            ),
        )

    def test_full_database_pipeline(
        self,
        trajectory_group: List[TrajectoryFixture],
    ):
        """Test full pipeline: insert → query → score → center."""
        window_id = f"{self.test_prefix}-full-pipeline"

        # Step 1: Insert trajectories
        with self.conn.cursor() as cur:
            for i, traj in enumerate(trajectory_group):
                self._insert_trajectory(
                    cur,
                    # The index keeps ids unique even if two fixtures
                    # share an archetype.
                    trajectory_id=f"{self.test_prefix}-{traj.archetype}-{i}",
                    agent_id=traj.agent_id,
                    archetype=traj.archetype,
                    window_id=window_id,
                    steps=traj.steps,
                    final_pnl=traj.final_pnl,
                    episode_length=traj.episode_length,
                    total_reward=traj.total_reward,
                )
        self.conn.commit()

        # Step 2: Query trajectories
        with self.conn.cursor() as cur:
            cur.execute(
                '''
                SELECT "trajectoryId", "archetype", "stepsJson", "finalPnL", "episodeLength"
                FROM trajectories
                WHERE "windowId" = %s AND "isTrainingData" = true
                ''',
                (window_id,),
            )
            rows = cur.fetchall()

        # Every inserted trajectory must come back, whatever the fixture size.
        assert len(rows) == len(trajectory_group)

        # Step 3: Score each trajectory
        scored_trajectories = []
        for traj_id, archetype, steps_json, raw_pnl, episode_length in rows:
            archetype_norm = normalize_archetype(archetype)
            steps = json.loads(steps_json)
            pnl = float(raw_pnl)

            # Count every step whose action type is anything other than "hold".
            trades = sum(
                1 for s in steps
                if s.get("action", {}).get("actionType") != "hold"
            )
            behavior = BehaviorMetrics(
                trades_executed=trades,
                total_pnl=pnl,
                episode_length=episode_length,
            )

            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, archetype_norm, behavior)
            scored_trajectories.append({
                "trajectory_id": traj_id,
                "archetype": archetype_norm,
                "pnl": pnl,
                "score": score,
            })

        # Step 4: Center scores for GRPO
        mean_score = sum(t["score"] for t in scored_trajectories) / len(scored_trajectories)
        for t in scored_trajectories:
            t["centered_score"] = t["score"] - mean_score

        # Verify results
        centered_mean = sum(t["centered_score"] for t in scored_trajectories) / len(scored_trajectories)
        assert abs(centered_mean) < 0.01

        # Verify trader scores higher than degen (positive PnL vs negative)
        trader = next((t for t in scored_trajectories if t["archetype"] == "trader"), None)
        degen = next((t for t in scored_trajectories if t["archetype"] == "degen"), None)
        # Fail with a readable message rather than a bare StopIteration.
        assert trader is not None, "trajectory_group must contain a trader"
        assert degen is not None, "trajectory_group must contain a degen"
        assert trader["score"] > degen["score"]

    def test_pipeline_with_all_archetypes(self):
        """Test pipeline handles all valid archetypes."""
        window_id = f"{self.test_prefix}-all-archetypes"
        archetypes = get_available_archetypes()

        # Insert one trajectory per archetype
        with self.conn.cursor() as cur:
            for i, archetype in enumerate(archetypes):
                self._insert_trajectory(
                    cur,
                    trajectory_id=f"{self.test_prefix}-all-{i}",
                    agent_id=f"agent-{archetype}",
                    archetype=archetype,
                    window_id=window_id,
                    steps=[],
                    final_pnl=100.0 + i * 10,
                    episode_length=3,
                )
        self.conn.commit()

        # Query and score all
        with self.conn.cursor() as cur:
            cur.execute(
                '''
                SELECT "archetype", "finalPnL" FROM trajectories
                WHERE "windowId" = %s
                ''',
                (window_id,),
            )
            rows = cur.fetchall()

        assert len(rows) == len(archetypes)

        # Score each
        for archetype, raw_pnl in rows:
            normalized = normalize_archetype(archetype)
            assert has_custom_rubric(normalized), f"{archetype} should have rubric"

            pnl = float(raw_pnl)
            behavior = BehaviorMetrics(trades_executed=3, total_pnl=pnl)
            inputs = TrajectoryRewardInputs(
                final_pnl=pnl,
                starting_balance=10000.0,
                end_balance=10000.0 + pnl,
                format_score=0.7,
                reasoning_score=0.7,
            )

            score = archetype_composite_reward(inputs, normalized, behavior)
            assert 0.0 <= score <= 1.0
649
+