@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,393 @@
1
+ """
2
+ Babylon Trajectory Reader
3
+
4
+ Reads trajectories from PostgreSQL database or local JSON files for training.
5
+ Validates LLM call quality to ensure training data authenticity.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from dataclasses import dataclass
11
+ from typing import Optional, List, Dict
12
+ from pathlib import Path
13
+ import logging
14
+
15
+ # Handle optional psycopg2 import for JSON-only workflows.
16
+ try:
17
+ import psycopg2
18
+ except ImportError:
19
+ psycopg2 = None
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@dataclass
class TrajectoryRow:
    """Raw trajectory data from database. Used by PostgresTrajectoryReader.

    Field order matches the SELECT column order used by the reader functions
    in this module; each instance is built positionally from one DB row.
    """

    trajectory_id: str
    agent_id: str
    window_id: str
    # Raw JSON text columns — parse with json.loads before use.
    steps_json: str
    metrics_json: str
    metadata_json: str
    total_reward: float
    episode_length: int
    final_status: str
    # Nullable columns in the trajectories table.
    final_pnl: Optional[float]
    trades_executed: Optional[int]
    ai_judge_reward: Optional[float]
    archetype: Optional[str]
41
+
42
+
43
def get_connection():
    """
    Get PostgreSQL connection from environment.

    Returns:
        psycopg2 connection

    Raises:
        ValueError: If DATABASE_URL not set
        ImportError: If psycopg2 is not installed
    """
    # Dependency check first, so a missing driver is reported before a
    # missing configuration value.
    if psycopg2 is None:
        raise ImportError(
            "psycopg2 is not installed. Please install it with 'pip install psycopg2-binary'")
    if not (db_url := os.environ.get("DATABASE_URL")):
        raise ValueError("DATABASE_URL environment variable required")
    return psycopg2.connect(db_url)
61
+
62
+
63
def validate_llm_calls(steps: list, min_steps_with_llm: int = 3) -> tuple[bool, list[str]]:
    """
    Validate trajectory steps contain real LLM calls.

    Training data MUST have actual LLM calls with real prompts and responses.
    Synthetic or placeholder data will cause training failures.

    Args:
        steps: List of trajectory steps
        min_steps_with_llm: Minimum steps with valid LLM calls

    Returns:
        Tuple of (is_valid, list of issue descriptions)
    """
    problems: list[str] = []

    if not steps:
        problems.append("Trajectory has no steps.")
        return False, problems

    covered_steps = 0
    for step_idx, step in enumerate(steps):
        # Accept both camelCase and snake_case keys from different producers.
        calls = step.get("llmCalls") or step.get("llm_calls") or []
        if not calls:
            continue

        step_has_valid_call = False
        for call_idx, call in enumerate(calls):
            sys_text = call.get("systemPrompt") or call.get("system_prompt") or ""
            usr_text = call.get("userPrompt") or call.get("user_prompt") or ""
            out_text = call.get("response") or ""

            # A trivially short field is treated as placeholder content.
            faults = [
                label
                for label, text in (
                    ("system_prompt too short", sys_text),
                    ("user_prompt too short", usr_text),
                    ("response too short", out_text),
                )
                if len(text) < 20
            ]
            if faults:
                problems.append(
                    f"Step {step_idx}, Call {call_idx}: " + ", ".join(faults))
            else:
                step_has_valid_call = True

        if step_has_valid_call:
            covered_steps += 1

    if covered_steps < min_steps_with_llm:
        problems.append(
            f"Only {covered_steps}/{len(steps)} steps have valid LLM calls (need at least {min_steps_with_llm})")

    # Valid only when there are zero recorded problems (any bad call anywhere
    # fails the trajectory, even if enough steps are covered).
    return len(problems) == 0, problems
119
+
120
+
121
class PostgresTrajectoryReader:
    """Reads Babylon trajectories from a PostgreSQL database.

    Intended use:
        async with PostgresTrajectoryReader(db_url) as reader:
            windows = await reader.get_window_ids()

    Methods are declared ``async`` to match the caller's interface, even
    though psycopg2 itself is synchronous.
    """

    def __init__(self, database_url: str):
        """Validate dependencies and store the URL; no I/O happens here."""
        if psycopg2 is None:
            raise ImportError(
                "psycopg2 is not installed for PostgresTrajectoryReader. Please install it with 'pip install psycopg2-binary'")
        if not database_url:
            raise ValueError(
                "DATABASE_URL must be provided for PostgresTrajectoryReader")
        self.db_url = database_url
        self.conn = None

    async def __aenter__(self):
        """Connect to the database upon entering the async context."""
        # Re-check to satisfy static analysis; __init__ already guarantees this.
        if psycopg2 is None:
            raise ImportError(
                "psycopg2 is not installed, cannot connect to database.")
        self.conn = psycopg2.connect(self.db_url)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the database connection upon exiting the context."""
        if self.conn:
            self.conn.close()

    async def get_window_ids(self, limit: int = 100, only_scored: bool = True, lookback_hours: int = 168, min_agents: int = 1) -> list[str]:
        """Return distinct window IDs with training data, newest first.

        Args:
            limit: Maximum number of window IDs to return.
            only_scored: Restrict to windows with an "aiJudgeReward" score.
            lookback_hours: Only consider rows newer than this many hours.
            min_agents: Accepted for interface compatibility; currently unused.

        Raises:
            ConnectionError: If called outside the async context manager.
        """
        if not self.conn:
            raise ConnectionError("Database not connected.")
        with self.conn.cursor() as cur:
            # NOTE: the %s inside the INTERVAL literal is substituted
            # client-side by psycopg2, producing e.g. INTERVAL '168 hours'.
            query = """
                SELECT DISTINCT "windowId" FROM trajectories
                WHERE "isTrainingData" = true AND "createdAt" > NOW() - INTERVAL '%s hours'
            """
            params = [lookback_hours]
            if only_scored:
                query += ' AND "aiJudgeReward" IS NOT NULL'
            query += ' ORDER BY "windowId" DESC LIMIT %s'
            params.append(limit)
            cur.execute(query, tuple(params))
            return [row[0] for row in cur.fetchall() if row[0]]

    async def get_trajectories_by_window(
        self, window_id: str, min_score: Optional[float] = None,
        validate: bool = True, min_actions: int = 1
    ) -> list[TrajectoryRow]:
        """Fetch trajectories for one window, optionally validating LLM calls.

        Args:
            window_id: Window to query.
            min_score: Optional minimum "aiJudgeReward" filter.
            validate: Skip trajectories whose steps lack real LLM calls.
            min_actions: Minimum "episodeLength" a row must have.

        Raises:
            ConnectionError: If called outside the async context manager.
        """
        if not self.conn:
            raise ConnectionError("Database not connected.")
        with self.conn.cursor() as cur:
            query = """
                SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
                       "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
                       "aiJudgeReward", "archetype"
                FROM trajectories WHERE "windowId" = %s AND "isTrainingData" = true AND "episodeLength" >= %s
            """
            params: list = [window_id, min_actions]
            if min_score is not None:
                query += ' AND "aiJudgeReward" >= %s'
                params.append(min_score)
            cur.execute(query, tuple(params))
            rows = cur.fetchall()

        results = []
        for row in rows:
            trajectory = TrajectoryRow(
                trajectory_id=row[0], agent_id=row[1], window_id=row[2], steps_json=row[3],
                metrics_json=row[4], metadata_json=row[5],
                total_reward=float(row[6] or 0.0),
                episode_length=int(row[7] or 0), final_status=row[8] or "unknown",
                # Use `is not None` so legitimate zero values (0.0 PnL, 0.0
                # judge score, 0 trades) are preserved instead of being
                # collapsed to None by truthiness.
                final_pnl=float(row[9]) if row[9] is not None else None,
                trades_executed=int(row[10]) if row[10] is not None else None,
                ai_judge_reward=float(row[11]) if row[11] is not None else None,
                archetype=row[12],
            )
            if validate:
                try:
                    steps = json.loads(trajectory.steps_json)
                    is_valid, issues = validate_llm_calls(steps)
                    if not is_valid:
                        logger.debug(
                            f"Skipping DB trajectory {trajectory.trajectory_id}: {issues}")
                        continue
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        f"Could not parse steps_json for trajectory {trajectory.trajectory_id}")
                    continue
            results.append(trajectory)
        return results
208
+
209
+
210
class JsonTrajectoryReader:
    """Reads Babylon trajectories from a local directory of JSON files."""

    def __init__(self, directory_path: str):
        """Index every *.json file under *directory_path* by window ID."""
        self._directory = Path(directory_path)
        self._trajectories_by_window: Dict[str, List[Dict]] = {}

        if not self._directory.is_dir():
            raise FileNotFoundError(
                f"Source directory not found: {self._directory.resolve()}")

        self._scan_files()
        logger.info(
            f"Found {len(self._trajectories_by_window)} windows in {self._directory}")

    def _scan_files(self):
        """One-time scan: parse each JSON file and group by "windowId"."""
        file_count = 0
        for file_path in self._directory.glob("*.json"):
            file_count += 1
            try:
                data = json.loads(file_path.read_text(encoding='utf-8'))
                # Files may wrap the record in a top-level 'trajectory' key.
                payload = data.get('trajectory', data)
                window_id = payload.get("windowId", "default_window")
                self._trajectories_by_window.setdefault(window_id, []).append(payload)
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logger.warning(f"Skipping invalid JSON file {file_path}: {e}")

        if file_count == 0:
            logger.warning(
                f"No JSON files found in directory: {self._directory}")

    def get_window_ids(self) -> List[str]:
        """Window IDs discovered during the initial scan."""
        return list(self._trajectories_by_window.keys())

    def get_trajectories_by_window(self, window_id: str) -> List[Dict]:
        """All trajectory dicts for *window_id* (empty list if unknown)."""
        return self._trajectories_by_window.get(window_id, [])
249
+
250
+
251
def get_window_ids(limit: int = 100, only_scored: bool = True) -> list[str]:
    """
    Get distinct window IDs with training data.

    Args:
        limit: Maximum windows to return
        only_scored: Only return windows with scored trajectories

    Returns:
        List of window IDs
    """
    # Assemble the statement before touching the database.
    query = 'SELECT DISTINCT "windowId" FROM trajectories WHERE "isTrainingData" = true'
    if only_scored:
        query += ' AND "aiJudgeReward" IS NOT NULL'
    query += ' ORDER BY "windowId" DESC LIMIT %s'

    conn = get_connection()
    cur = conn.cursor()
    cur.execute(query, (limit,))
    fetched = cur.fetchall()
    cur.close()
    conn.close()
    # Single-column result; drop NULL/empty IDs.
    return [window_id for (window_id,) in fetched if window_id]
273
+
274
+
275
def get_trajectories_by_window(
    window_id: str,
    min_score: Optional[float] = None,
    validate: bool = True,
) -> list[TrajectoryRow]:
    """
    Get trajectories for a specific window.

    Args:
        window_id: Window ID to query
        min_score: Optional minimum AI judge score
        validate: Whether to validate LLM calls (invalid or unparseable
            trajectories are skipped, not raised)

    Returns:
        List of trajectory rows
    """
    conn = get_connection()
    cur = conn.cursor()
    query = """
        SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
               "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
               "aiJudgeReward", "archetype"
        FROM trajectories WHERE "windowId" = %s AND "isTrainingData" = true
    """
    params: list = [window_id]
    if min_score is not None:
        query += ' AND "aiJudgeReward" >= %s'
        params.append(min_score)
    cur.execute(query, params)
    rows = cur.fetchall()
    cur.close()
    conn.close()
    results: list[TrajectoryRow] = []
    for row in rows:
        trajectory = TrajectoryRow(
            trajectory_id=row[0], agent_id=row[1], window_id=row[2], steps_json=row[3],
            metrics_json=row[4], metadata_json=row[5],
            total_reward=float(row[6] or 0.0),
            episode_length=int(row[7] or 0), final_status=row[8] or "unknown",
            # Use `is not None` so legitimate zero values (0.0 PnL, 0.0 judge
            # score, 0 trades) survive instead of collapsing to None.
            final_pnl=float(row[9]) if row[9] is not None else None,
            trades_executed=int(row[10]) if row[10] is not None else None,
            ai_judge_reward=float(row[11]) if row[11] is not None else None,
            archetype=row[12],
        )
        if validate:
            # Guard the parse, matching PostgresTrajectoryReader: one bad
            # steps_json should skip that row, not abort the whole read.
            try:
                steps = json.loads(trajectory.steps_json)
            except (json.JSONDecodeError, TypeError):
                logger.warning(
                    f"Could not parse steps_json for trajectory {trajectory.trajectory_id}")
                continue
            is_valid, _ = validate_llm_calls(steps)
            if not is_valid:
                continue
        results.append(trajectory)
    return results
324
+
325
+
326
def get_all_training_trajectories(
    limit: int = 1000,
    min_score: Optional[float] = None,
    archetype: Optional[str] = None,
) -> list[TrajectoryRow]:
    """
    Get all training trajectories.

    Args:
        limit: Maximum trajectories to return
        min_score: Optional minimum AI judge score
        archetype: Optional filter by archetype

    Returns:
        List of trajectory rows, newest first (ordered by "createdAt")
    """
    conn = get_connection()
    cur = conn.cursor()
    query = """
        SELECT "trajectoryId", "agentId", "windowId", "stepsJson", "metricsJson", "metadataJson",
               "totalReward", "episodeLength", "finalStatus", "finalPnL", "tradesExecuted",
               "aiJudgeReward", "archetype"
        FROM trajectories WHERE "isTrainingData" = true
    """
    params: list = []
    if min_score is not None:
        query += ' AND "aiJudgeReward" >= %s'
        params.append(min_score)
    if archetype is not None:
        query += ' AND "archetype" = %s'
        params.append(archetype)
    query += ' ORDER BY "createdAt" DESC LIMIT %s'
    params.append(limit)
    cur.execute(query, params)
    rows = cur.fetchall()
    cur.close()
    conn.close()
    return [TrajectoryRow(
        trajectory_id=r[0], agent_id=r[1], window_id=r[2], steps_json=r[3],
        metrics_json=r[4], metadata_json=r[5], total_reward=float(r[6] or 0.0),
        episode_length=int(r[7] or 0), final_status=r[8] or "unknown",
        # Use `is not None` so legitimate zero values (0.0 PnL, 0.0 judge
        # score, 0 trades) survive instead of collapsing to None.
        final_pnl=float(r[9]) if r[9] is not None else None,
        trades_executed=int(r[10]) if r[10] is not None else None,
        ai_judge_reward=float(r[11]) if r[11] is not None else None,
        archetype=r[12],
    ) for r in rows]
370
+
371
+
372
def get_trajectory_stats() -> dict:
    """Return aggregate statistics over all training trajectories.

    Returns:
        Dict with keys: total, scored, avg_score, min_score, max_score,
        archetypes. Zeroed out when the aggregate query yields no row.
    """
    conn = get_connection()
    cur = conn.cursor()
    cur.execute("""
        SELECT COUNT(*), COUNT("aiJudgeReward"), AVG("aiJudgeReward"),
               MIN("aiJudgeReward"), MAX("aiJudgeReward"), COUNT(DISTINCT "archetype")
        FROM trajectories WHERE "isTrainingData" = true
    """)
    row = cur.fetchone()
    cur.close()
    conn.close()

    if row is None:
        return {"total": 0, "scored": 0, "avg_score": 0.0, "min_score": 0.0, "max_score": 0.0, "archetypes": 0}

    total, scored, avg_score, min_score, max_score, archetypes = row
    return {
        "total": total or 0,
        "scored": scored or 0,
        "avg_score": float(avg_score) if avg_score else 0.0,
        "min_score": float(min_score) if min_score else 0.0,
        "max_score": float(max_score) if max_score else 0.0,
        "archetypes": archetypes or 0,
    }
@@ -0,0 +1,283 @@
1
+ """
2
+ Shared type definitions for RL training.
3
+ Strong, validated types.
4
+ """
5
+
6
+ from typing import Dict, List, Literal
7
+ from pydantic import BaseModel, ConfigDict, Field
8
+ from datetime import datetime
9
+ from pydantic.alias_generators import to_camel
10
+
11
# Type alias for JSON-serializable values. Values are deliberately `object`,
# so callers must narrow the type before use.
JsonDict = Dict[str, object]

# Type alias for chat messages with known structure
ChatMessage = Dict[str, str]  # {"role": str, "content": str}

# Base config for camelCase conversion, to be used by all models.
# populate_by_name=True lets models be constructed from either the
# snake_case field names or their camelCase aliases.
camel_case_config = ConfigDict(
    alias_generator=to_camel,
    populate_by_name=True,
)
22
+
23
+
24
+ class EnvironmentState(BaseModel):
25
+ """Environment state at a given point"""
26
+ model_config = camel_case_config
27
+
28
+ agent_balance: float
29
+ # Explicit alias for the 'agentPnL' field from the JSON data
30
+ agent_pnl: float = Field(..., alias='agentPnL')
31
+ open_positions: int
32
+ active_markets: int = 0
33
+
34
+
35
+ class ProviderAccess(BaseModel):
36
+ """Data accessed from a provider"""
37
+ # Combines camelCase conversion with allowing extra fields
38
+ model_config = ConfigDict(
39
+ alias_generator=to_camel,
40
+ populate_by_name=True,
41
+ extra="allow"
42
+ )
43
+
44
+ provider_name: str
45
+ data: JsonDict
46
+ purpose: str
47
+
48
+
49
+ class LLMCall(BaseModel):
50
+ """
51
+ Single LLM call record.
52
+ Matches the TypeScript LLMCall interface in plugin-trajectory-logger/types.ts
53
+ """
54
+ model_config = camel_case_config
55
+
56
+ model: str
57
+ model_version: str | None = None
58
+ system_prompt: str
59
+ user_prompt: str
60
+ response: str
61
+ reasoning: str | None = None
62
+ temperature: float
63
+ max_tokens: int
64
+ latency_ms: int | None = None
65
+ prompt_tokens: int | None = None
66
+ completion_tokens: int | None = None
67
+ purpose: Literal['action', 'reasoning', 'evaluation', 'response', 'other']
68
+ action_type: str | None = None
69
+
70
+
71
+ class Action(BaseModel):
72
+ """Action taken by agent"""
73
+ # Combines camelCase conversion with allowing extra fields
74
+ model_config = ConfigDict(
75
+ alias_generator=to_camel,
76
+ populate_by_name=True,
77
+ extra="allow"
78
+ )
79
+
80
+ action_type: str
81
+ parameters: JsonDict
82
+ success: bool
83
+ result: JsonDict | None = None
84
+ error: str | None = None
85
+ reasoning: str | None = None
86
+
87
+
88
+ class TrajectoryStep(BaseModel):
89
+ """Single step in a trajectory"""
90
+ model_config = camel_case_config
91
+
92
+ step_number: int
93
+ timestamp: int
94
+ environment_state: EnvironmentState
95
+ provider_accesses: List[ProviderAccess] = Field(default_factory=list)
96
+ llm_calls: List[LLMCall] = Field(default_factory=list)
97
+ action: Action | None = None
98
+ reward: float = 0.0
99
+
100
+
101
+ class TrainingTrajectory(BaseModel):
102
+ """Complete trajectory from database"""
103
+ # Combines camelCase conversion with mutability
104
+ model_config = ConfigDict(
105
+ alias_generator=to_camel,
106
+ populate_by_name=True,
107
+ frozen=False
108
+ )
109
+
110
+ trajectory_id: str
111
+ agent_id: str
112
+
113
+ id: str = ""
114
+ window_id: str = "default"
115
+ start_time: datetime | None = None
116
+ end_time: datetime | None = None
117
+ duration_ms: int = 0
118
+ scenario_id: str | None = None
119
+ episode_id: str | None = None
120
+ steps: List[TrajectoryStep] = Field(default_factory=list)
121
+ total_reward: float = 0.0
122
+ final_pnl: float = 0.0
123
+ final_balance: float | None = None
124
+ trades_executed: int = 0
125
+ successful_trades: int = 0
126
+ failed_trades: int = 0
127
+ posts_created: int = 0
128
+ provider_accesses: int = 0
129
+ episode_length: int = 0
130
+ final_status: str = "completed"
131
+ archetype: str | None = None
132
+
133
+
134
+ class StockOutcome(BaseModel):
135
+ """Market outcome for a stock"""
136
+ model_config = camel_case_config
137
+ ticker: str
138
+ start_price: float
139
+ end_price: float
140
+ change_percent: float
141
+ sentiment: Literal['BULLISH', 'BEARISH', 'NEUTRAL'] | None = None
142
+ news_events: List[str] = Field(default_factory=list)
143
+
144
+
145
+ class PredictionOutcome(BaseModel):
146
+ """Outcome for a prediction market"""
147
+ model_config = camel_case_config
148
+ market_id: str
149
+ question: str
150
+ outcome: Literal['YES', 'NO', 'UNRESOLVED']
151
+ final_probability: float
152
+
153
+
154
+ class MarketOutcomes(BaseModel):
155
+ """All market outcomes for a window"""
156
+ model_config = camel_case_config
157
+ window_id: str
158
+ window_start: datetime
159
+ window_end: datetime
160
+ stocks: dict[str, StockOutcome] = Field(default_factory=dict)
161
+ predictions: dict[str, PredictionOutcome] = Field(default_factory=dict)
162
+ overall_trend: Literal['BULLISH', 'BEARISH', 'NEUTRAL'] | None = None
163
+ volatility: Literal['HIGH', 'MEDIUM', 'LOW'] | None = None
164
+
165
+
166
+ class WindowStatistics(BaseModel):
167
+ """Statistics for a training window"""
168
+ model_config = camel_case_config
169
+ window_id: str
170
+ agent_count: int
171
+ trajectory_count: int
172
+ total_actions: int
173
+ avg_pnl: float
174
+ min_pnl: float
175
+ max_pnl: float
176
+ start_time: datetime
177
+ end_time: datetime
178
+
179
+
180
+ class TrainingBatchSummary(BaseModel):
181
+ """Summary of a training batch"""
182
+ model_config = camel_case_config
183
+ windows: int
184
+ total_trajectories: int
185
+ avg_trajectories_per_window: float
186
+ score_min: float
187
+ score_max: float
188
+ score_avg: float
189
+ pnl_min: float
190
+ pnl_max: float
191
+ pnl_avg: float
192
+
193
+
194
+ # =============================================================================
195
+ # Atropos-compatible types
196
+ # =============================================================================
197
+
198
+
199
+ class AtroposScoredItem(BaseModel):
200
+ """Single scored item for Atropos training"""
201
+ model_config = camel_case_config
202
+ tokens: List[int]
203
+ masks: List[int]
204
+ score: float
205
+ logprobs: List[float] = Field(default_factory=list)
206
+ messages: List[ChatMessage] = Field(default_factory=list)
207
+
208
+
209
+ class AtroposScoredGroup(BaseModel):
210
+ """Group of scored items for Atropos GRPO training"""
211
+ model_config = camel_case_config
212
+ tokens: List[List[int]]
213
+ masks: List[List[int]]
214
+ scores: List[float]
215
+ inference_logprobs: List[List[float]] = Field(default_factory=list)
216
+ messages: List[List[ChatMessage]] = Field(default_factory=list)
217
+ env_id: int | None = None
218
+
219
+ @property
220
+ def group_size(self) -> int:
221
+ return len(self.tokens)
222
+
223
+
224
+ class TrajectoryGroup(BaseModel):
225
+ """Group of trajectories for relative comparison"""
226
+ model_config = camel_case_config
227
+ group_key: str
228
+ window_id: str
229
+ scenario_id: str | None = None
230
+ trajectories: List[TrainingTrajectory]
231
+
232
+ @property
233
+ def size(self) -> int:
234
+ return len(self.trajectories)
235
+
236
+ def get_pnl_stats(self) -> dict:
237
+ """Get P&L statistics for the group"""
238
+ pnls = [t.final_pnl for t in self.trajectories]
239
+ return {
240
+ "min": min(pnls) if pnls else 0,
241
+ "max": max(pnls) if pnls else 0,
242
+ "mean": sum(pnls) / len(pnls) if pnls else 0,
243
+ }
244
+
245
+
246
+ class JudgeScore(BaseModel):
247
+ """Score from LLM judge for a trajectory"""
248
+ model_config = camel_case_config
249
+ trajectory_id: str
250
+ score: float = Field(ge=0.0, le=1.0)
251
+ explanation: str
252
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
253
+
254
+
255
+ class JudgeResponse(BaseModel):
256
+ """Response from LLM judge for a group of trajectories"""
257
+ model_config = camel_case_config
258
+ reasoning: str
259
+ scores: List[JudgeScore]
260
+
261
+ def get_score_for(self, trajectory_id: str) -> float | None:
262
+ """Get score for a specific trajectory"""
263
+ for score in self.scores:
264
+ if score.trajectory_id == trajectory_id:
265
+ return score.score
266
+ return None
267
+
268
+
269
+ # Backward compatibility alias while downstream code migrates.
270
+ BabylonTrajectory = TrainingTrajectory
271
+
272
+
273
+ class TrainingMetrics(BaseModel):
274
+ """Metrics from a training step"""
275
+ model_config = camel_case_config
276
+ step: int
277
+ loss: float
278
+ grad_norm: float
279
+ learning_rate: float
280
+ pos_logp: float = 0.0
281
+ neg_logp: float = 0.0
282
+ num_samples: int = 0
283
+ timestamp: datetime = Field(default_factory=datetime.now)