@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,481 @@
1
+ """
2
+ Trajectory Schema Definitions
3
+
4
+ Provides strict schema validation for trajectories to ensure data integrity
5
+ between TypeScript trajectory generation and Python training pipeline.
6
+
7
+ This module catches schema drift early and provides clear error messages
8
+ when data doesn't match expectations.
9
+ """
10
+
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+ import json
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ # ============================================================================
20
+ # Step Schemas
21
+ # ============================================================================
22
+
23
@dataclass
class EnvironmentStateSchema:
    """Environment snapshot attached to a single trajectory step."""

    # Core account figures
    agent_balance: float = 0.0
    agent_pnl: float = 0.0
    agent_points: int = 0
    open_positions: int = 0
    timestamp: Optional[int] = None

    # Reputation / influence figures; None when the step did not record them
    reputation_delta: Optional[int] = None
    followers_gained: Optional[int] = None
    positive_reactions: Optional[int] = None
    information_spread: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EnvironmentStateSchema":
        """Build an instance from a raw mapping.

        Multi-word fields are read from their camelCase key when that key is
        present, otherwise from the snake_case key, otherwise the field
        default is used.
        """

        def pick(camel: str, snake: str, fallback: Any = None) -> Any:
            # A present camelCase key wins even when its value is None.
            if camel in data:
                return data[camel]
            return data.get(snake, fallback)

        values = {
            "agent_balance": pick("agentBalance", "agent_balance", 0.0),
            "agent_pnl": pick("agentPnL", "agent_pnl", 0.0),
            "agent_points": pick("agentPoints", "agent_points", 0),
            "open_positions": pick("openPositions", "open_positions", 0),
            "timestamp": data.get("timestamp"),
            "reputation_delta": pick("reputationDelta", "reputation_delta"),
            "followers_gained": pick("followersGained", "followers_gained"),
            "positive_reactions": pick("positiveReactions", "positive_reactions"),
            "information_spread": pick("informationSpread", "information_spread"),
        }
        return cls(**values)
52
+
53
+
54
@dataclass
class ActionParametersSchema:
    """Schema for action parameters"""

    # Trading parameters
    ticker: Optional[str] = None
    amount: Optional[float] = None
    leverage: Optional[float] = None
    confidence: Optional[float] = None
    market_id: Optional[str] = None

    # Social parameters
    target_user_id: Optional[str] = None
    recipient_id: Optional[str] = None
    message: Optional[str] = None

    # Archetype (for batch recording mode)
    archetype: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionParametersSchema":
        """Create from dictionary with field name normalization.

        Each field is taken from the first key that is present in ``data``:

        * ``amount``: ``amount``, then ``size``, then ``quantity``
        * ``market_id``: ``marketId``, then ``market_id``, then ``market``
        * ``target_user_id``: ``targetUserId``, then ``target_user_id``
        * ``recipient_id``: ``recipientId``, then ``recipient_id``

        The snake_case spellings are accepted so this schema normalizes keys
        the same way the other ``from_dict`` constructors in this module do.
        Fields with no matching key are left as ``None``.
        """

        def first_present(*keys: str) -> Any:
            # A key that is present wins even when its value is None.
            for key in keys:
                if key in data:
                    return data[key]
            return None

        return cls(
            ticker=data.get("ticker"),
            amount=first_present("amount", "size", "quantity"),
            leverage=data.get("leverage"),
            confidence=data.get("confidence"),
            market_id=first_present("marketId", "market_id", "market"),
            target_user_id=first_present("targetUserId", "target_user_id"),
            recipient_id=first_present("recipientId", "recipient_id"),
            message=data.get("message"),
            archetype=data.get("archetype"),
        )
86
+
87
+
88
@dataclass
class ActionResultSchema:
    """Schema for action result"""

    position_id: Optional[str] = None
    pnl: Optional[float] = None
    success: bool = True
    error: Optional[str] = None
    archetype: Optional[str] = None

    # Prediction-specific
    correct: Optional[bool] = None
    prediction_correct: Optional[bool] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionResultSchema":
        """Create from dictionary with field name normalization.

        ``position_id`` and ``prediction_correct`` are read from their
        camelCase key (``positionId`` / ``predictionCorrect``) when present,
        otherwise from the snake_case key, matching how the other schemas in
        this module accept both spellings. ``success`` defaults to ``True``
        when absent; every other missing field is ``None``.
        """
        return cls(
            position_id=data.get("positionId", data.get("position_id")),
            pnl=data.get("pnl"),
            success=data.get("success", True),
            error=data.get("error"),
            archetype=data.get("archetype"),
            correct=data.get("correct"),
            prediction_correct=data.get(
                "predictionCorrect", data.get("prediction_correct")
            ),
        )
113
+
114
+
115
@dataclass
class ActionSchema:
    """Schema for trajectory action"""

    action_type: str
    parameters: ActionParametersSchema = field(default_factory=ActionParametersSchema)
    success: bool = True
    result: ActionResultSchema = field(default_factory=ActionResultSchema)
    reasoning: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ActionSchema":
        """Create from dictionary with field name normalization.

        ``actionType`` is preferred over ``action_type``; when neither is
        present the action type is ``"unknown"``. ``parameters`` and
        ``result`` are parsed by their own schemas. A missing, null or
        non-object value for either is treated as an empty mapping so the
        nested schema falls back to its defaults instead of raising.
        """
        parameters_raw = data.get("parameters")
        result_raw = data.get("result")

        # JSON null (or any non-object) would otherwise crash the nested
        # from_dict calls, which expect a mapping.
        if not isinstance(parameters_raw, dict):
            parameters_raw = {}
        if not isinstance(result_raw, dict):
            result_raw = {}

        return cls(
            action_type=data.get("actionType", data.get("action_type", "unknown")),
            parameters=ActionParametersSchema.from_dict(parameters_raw),
            success=data.get("success", True),
            result=ActionResultSchema.from_dict(result_raw),
            reasoning=data.get("reasoning"),
        )
134
+
135
+
136
@dataclass
class LLMCallSchema:
    """One LLM invocation recorded inside a trajectory step."""

    model: str
    purpose: str = "action"
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    response: Optional[str] = None
    reasoning: Optional[str] = None
    temperature: float = 0.7
    max_tokens: int = 1000
    latency_ms: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMCallSchema":
        """Build an instance from a raw mapping.

        Multi-word fields are read from the camelCase key when it is present
        and from the snake_case key otherwise. Absent fields take the same
        defaults as the dataclass, except ``model`` which becomes "unknown".
        """

        def pick(camel: str, snake: str, fallback: Any = None) -> Any:
            # A present camelCase key wins even when its value is None.
            if camel in data:
                return data[camel]
            return data.get(snake, fallback)

        values = {
            "model": data.get("model", "unknown"),
            "purpose": data.get("purpose", "action"),
            "system_prompt": pick("systemPrompt", "system_prompt"),
            "user_prompt": pick("userPrompt", "user_prompt"),
            "response": data.get("response"),
            "reasoning": data.get("reasoning"),
            "temperature": data.get("temperature", 0.7),
            "max_tokens": pick("maxTokens", "max_tokens", 1000),
            "latency_ms": pick("latencyMs", "latency_ms"),
        }
        return cls(**values)
163
+
164
+
165
@dataclass
class StepSchema:
    """Schema for a single trajectory step"""

    step_number: int
    timestamp: Optional[int] = None
    environment_state: EnvironmentStateSchema = field(default_factory=EnvironmentStateSchema)
    action: ActionSchema = field(default_factory=lambda: ActionSchema(action_type="unknown"))
    llm_calls: List[LLMCallSchema] = field(default_factory=list)
    reward: float = 0.0
    observation: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepSchema":
        """Create from dictionary with field name normalization.

        camelCase keys (``stepNumber``, ``environmentState``, ``llmCalls``)
        are preferred over their snake_case equivalents. Nested values are
        parsed by their own schemas. A null or non-object
        ``environmentState`` / ``action`` is treated as an empty mapping, a
        null or non-list ``llmCalls`` as an empty list, and non-object
        entries inside ``llmCalls`` are skipped, so a partially populated
        step still parses instead of raising.
        """
        llm_calls_raw = data.get("llmCalls", data.get("llm_calls", []))
        if not isinstance(llm_calls_raw, list):
            llm_calls_raw = []
        llm_calls = [
            LLMCallSchema.from_dict(call)
            for call in llm_calls_raw
            if isinstance(call, dict)
        ]

        env_raw = data.get("environmentState", data.get("environment_state", {}))
        if not isinstance(env_raw, dict):
            env_raw = {}

        action_raw = data.get("action", {})
        if not isinstance(action_raw, dict):
            action_raw = {}

        return cls(
            step_number=data.get("stepNumber", data.get("step_number", 0)),
            timestamp=data.get("timestamp"),
            environment_state=EnvironmentStateSchema.from_dict(env_raw),
            action=ActionSchema.from_dict(action_raw),
            llm_calls=llm_calls,
            reward=data.get("reward", 0.0),
            observation=data.get("observation"),
        )
193
+
194
+
195
+ # ============================================================================
196
+ # Trajectory Schema
197
+ # ============================================================================
198
+
199
@dataclass
class TrajectorySchema:
    """Schema for a complete trajectory"""

    trajectory_id: str
    agent_id: str
    window_id: str
    scenario_id: Optional[str] = None
    archetype: Optional[str] = None
    steps_json: str = "[]"  # JSON-encoded array of step objects
    final_pnl: float = 0.0
    final_balance: Optional[float] = None
    episode_length: int = 0
    total_reward: float = 0.0
    trades_executed: int = 0
    is_training_data: bool = True

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrajectorySchema":
        """Create from dictionary with field name normalization.

        camelCase keys are preferred over snake_case ones. Missing id fields
        become empty strings. ``finalPnL`` and ``totalReward`` are coerced
        with ``float()``; an explicit null for either is read as ``0.0``
        rather than raising ``TypeError``.

        Raises:
            ValueError: if ``finalPnL`` / ``totalReward`` is a string that
                ``float()`` cannot parse.
        """

        def as_float(value: Any) -> float:
            # JSON null arrives as None, which float() rejects.
            return 0.0 if value is None else float(value)

        return cls(
            trajectory_id=data.get("trajectoryId", data.get("trajectory_id", "")),
            agent_id=data.get("agentId", data.get("agent_id", "")),
            window_id=data.get("windowId", data.get("window_id", "")),
            scenario_id=data.get("scenarioId", data.get("scenario_id")),
            archetype=data.get("archetype"),
            steps_json=data.get("stepsJson", data.get("steps_json", "[]")),
            final_pnl=as_float(data.get("finalPnL", data.get("final_pnl", 0.0))),
            final_balance=data.get("finalBalance", data.get("final_balance")),
            episode_length=data.get("episodeLength", data.get("episode_length", 0)),
            total_reward=as_float(data.get("totalReward", data.get("total_reward", 0.0))),
            trades_executed=data.get("tradesExecuted", data.get("trades_executed", 0)),
            is_training_data=data.get("isTrainingData", data.get("is_training_data", True)),
        )

    def get_steps(self) -> "List[StepSchema]":
        """Parse and return steps as StepSchema objects.

        Returns an empty list when ``steps_json`` is not valid JSON, is not a
        string/bytes value at all, or does not decode to an array. Array
        entries that are not objects are skipped.
        """
        try:
            steps_raw = json.loads(self.steps_json)
        except (json.JSONDecodeError, TypeError):
            # TypeError: steps_json was not str/bytes (e.g. None or a list).
            return []

        if not isinstance(steps_raw, list):
            return []

        return [StepSchema.from_dict(step) for step in steps_raw if isinstance(step, dict)]

    def extract_archetype_from_steps(self) -> Optional[str]:
        """Extract archetype from steps if not set at trajectory level.

        Returns the trajectory-level archetype when it is set; otherwise the
        first archetype found while scanning steps in order, checking each
        step's action parameters before its action result. Returns ``None``
        when no archetype is found.
        """
        if self.archetype:
            return self.archetype

        for step in self.get_steps():
            if step.action.parameters.archetype:
                return step.action.parameters.archetype
            if step.action.result.archetype:
                return step.action.result.archetype

        return None
254
+
255
+
256
+ # ============================================================================
257
+ # Validation Functions
258
+ # ============================================================================
259
+
260
def validate_trajectory(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Validate trajectory data against schema.

    Checks, in order:
      * ``trajectoryId``, ``agentId`` and ``windowId`` are present (camelCase
        or snake_case).
      * ``stepsJson`` / ``steps_json`` (default ``"[]"``), when truthy, is a
        JSON string that decodes to a non-empty array of step objects, each
        of which passes the per-step checks.
      * ``finalPnL``, when not None, is accepted by ``float()``.
      * ``episodeLength``, when not None, is a non-negative ``int``
        (``bool`` is rejected even though it subclasses ``int``).

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    errors: List[str] = []

    # Required fields
    required_fields = ["trajectoryId", "agentId", "windowId"]
    for field_name in required_fields:
        snake_case = _camel_to_snake(field_name)
        if field_name not in data and snake_case not in data:
            errors.append(f"Missing required field: {field_name}")

    # Validate stepsJson if present
    steps_json = data.get("stepsJson", data.get("steps_json", "[]"))
    if steps_json:
        try:
            steps = json.loads(steps_json)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError: the value was not str/bytes (e.g. an already-parsed
            # list); report it instead of letting validation itself crash.
            errors.append(f"Invalid JSON in stepsJson: {e}")
        else:
            if not isinstance(steps, list):
                errors.append(f"stepsJson must be an array, got {type(steps).__name__}")
            elif len(steps) == 0:
                errors.append("stepsJson is empty - trajectory has no steps")
            else:
                for i, step in enumerate(steps):
                    if not isinstance(step, dict):
                        # The per-step checks need a mapping.
                        errors.append(
                            f"Step {i}: must be an object, got {type(step).__name__}"
                        )
                        continue
                    errors.extend(_validate_step(step, i))

    # Validate numeric fields
    pnl = data.get("finalPnL", data.get("final_pnl"))
    if pnl is not None:
        try:
            float(pnl)
        except (TypeError, ValueError):
            errors.append(f"finalPnL must be a number, got {type(pnl).__name__}")

    episode_length = data.get("episodeLength", data.get("episode_length"))
    if episode_length is not None:
        is_plain_int = isinstance(episode_length, int) and not isinstance(episode_length, bool)
        if not is_plain_int or episode_length < 0:
            errors.append("episodeLength must be a non-negative integer")

    return len(errors) == 0, errors
306
+
307
+
308
+ def _validate_step(step: Dict[str, Any], index: int) -> List[str]:
309
+ """Validate a single step"""
310
+ errors = []
311
+ prefix = f"Step {index}"
312
+
313
+ # stepNumber should exist
314
+ if "stepNumber" not in step and "step_number" not in step:
315
+ errors.append(f"{prefix}: missing stepNumber")
316
+
317
+ # action should exist and have actionType
318
+ action = step.get("action", {})
319
+ if not action:
320
+ errors.append(f"{prefix}: missing action")
321
+ elif "actionType" not in action and "action_type" not in action:
322
+ errors.append(f"{prefix}: action missing actionType")
323
+
324
+ # environmentState should exist
325
+ env_state = step.get("environmentState", step.get("environment_state"))
326
+ if not env_state:
327
+ errors.append(f"{prefix}: missing environmentState")
328
+
329
+ return errors
330
+
331
+
332
def validate_step(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Check a standalone step against the step schema.

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    # A lone step has no position in a trajectory, so messages use index 0.
    problems = _validate_step(data, 0)
    return (not problems, problems)
341
+
342
+
343
def validate_llm_call(data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Check an LLM call record against the schema.

    The only requirement enforced is that a "model" key is present.

    Returns:
        Tuple of (is_valid, list of error messages)
    """
    problems: List[str] = []
    has_model = "model" in data
    if not has_model:
        problems.append("LLM call missing 'model' field")
    return (not problems, problems)
356
+
357
+
358
+ def _camel_to_snake(name: str) -> str:
359
+ """Convert camelCase to snake_case"""
360
+ result = []
361
+ for i, char in enumerate(name):
362
+ if char.isupper() and i > 0:
363
+ result.append("_")
364
+ result.append(char.lower())
365
+ return "".join(result)
366
+
367
+
368
+ # ============================================================================
369
+ # Schema Comparison
370
+ # ============================================================================
371
+
372
def compare_trajectory_formats(
    json_data: Dict[str, Any],
    db_data: Dict[str, Any],
) -> Tuple[bool, List[str]]:
    """
    Report field-level differences between a JSON-sourced trajectory and a
    database-sourced one.

    Two numeric values count as equal when they differ by less than 0.001.
    A field missing on one side is read as None.

    Returns:
        Tuple of (are_equivalent, list of difference descriptions)
    """
    # (JSON key, DB key) pairs, checked and reported in this order
    paired_keys = (
        ("trajectoryId", "trajectoryId"),
        ("agentId", "agentId"),
        ("windowId", "windowId"),
        ("scenarioId", "scenarioId"),
        ("archetype", "archetype"),
        ("stepsJson", "stepsJson"),
        ("finalPnL", "finalPnL"),
        ("episodeLength", "episodeLength"),
    )
    numeric_types = (int, float)
    mismatches: List[str] = []

    for json_key, db_key in paired_keys:
        left = json_data.get(json_key)
        right = db_data.get(db_key)

        if left == right:
            continue

        # Tolerate tiny numeric drift between the two sources
        both_numeric = isinstance(left, numeric_types) and isinstance(right, numeric_types)
        if both_numeric and abs(float(left) - float(right)) < 0.001:
            continue

        mismatches.append(f"{json_key}: JSON={left!r}, DB={right!r}")

    return (not mismatches, mismatches)
409
+
410
+
411
+ # ============================================================================
412
+ # Export Validation Results
413
+ # ============================================================================
414
+
415
@dataclass
class ValidationResult:
    """Outcome of a schema validation run.

    Truth-testing an instance yields its ``is_valid`` flag, so a result can
    be used directly in an ``if``.
    """

    is_valid: bool
    errors: List[str]
    # Non-fatal findings; each instance gets its own fresh list
    warnings: List[str] = field(default_factory=list)

    def __bool__(self) -> bool:
        valid = self.is_valid
        return valid
424
+
425
+
426
def _step_declares_archetype(step: Any) -> bool:
    """Return True if a decoded step carries an archetype in its action parameters or result."""
    if not isinstance(step, dict):
        return False
    action = step.get("action")
    if not isinstance(action, dict):
        return False
    for section_key in ("parameters", "result"):
        section = action.get(section_key)
        if isinstance(section, dict) and section.get("archetype"):
            return True
    return False


def validate_trajectory_file(file_path: str) -> ValidationResult:
    """
    Validate a trajectory JSON file.

    The file is read as UTF-8. Its root must be a JSON object, either the
    trajectory itself or a wrapper holding it under a ``"trajectory"`` key
    (a missing wrapper only produces a warning). A warning is also added when
    no archetype is set on the trajectory or on any step's action parameters
    or result.

    Args:
        file_path: Path to JSON file

    Returns:
        ValidationResult with validation status and any errors/warnings.
        An unreadable-as-JSON file, a missing file, or a non-object root is
        reported as an invalid result rather than raised.
    """
    errors: List[str] = []
    warnings: List[str] = []

    try:
        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        return ValidationResult(False, [f"Invalid JSON: {e}"])
    except FileNotFoundError:
        return ValidationResult(False, [f"File not found: {file_path}"])

    if not isinstance(data, dict):
        return ValidationResult(
            False,
            [f"Trajectory file root must be an object, got {type(data).__name__}"],
        )

    # Check for trajectory wrapper
    if "trajectory" not in data:
        warnings.append("Missing 'trajectory' wrapper - treating root as trajectory data")
        traj_data = data
    else:
        traj_data = data["trajectory"]
        if not isinstance(traj_data, dict):
            return ValidationResult(
                False,
                [f"'trajectory' must be an object, got {type(traj_data).__name__}"],
                warnings,
            )

    # Validate trajectory
    _, traj_errors = validate_trajectory(traj_data)
    errors.extend(traj_errors)

    # Check for archetype
    if not traj_data.get("archetype"):
        # Try to find in steps
        steps_json = traj_data.get("stepsJson", traj_data.get("steps_json", "[]"))
        try:
            steps = json.loads(steps_json)
        except (json.JSONDecodeError, TypeError):
            steps = None  # Malformed steps are already reported above

        if steps is not None:
            candidates = steps if isinstance(steps, list) else []
            if not any(_step_declares_archetype(step) for step in candidates):
                warnings.append(
                    "No archetype found at trajectory or step level - will use 'default'"
                )

    return ValidationResult(
        is_valid=len(errors) == 0,
        errors=errors,
        warnings=warnings,
    )
481
+