@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,633 @@
1
+ """
2
+ Shared Quality Utilities
3
+
4
+ Common quality scoring and validation functions used across the training pipeline.
5
+ Extracted to avoid duplication between rollout_generator and fast_simulator.
6
+
7
+ ENHANCED v3:
8
+ - Archetype-specific scoring weights
9
+ - Reasoning-action alignment validation with Financial Literacy
10
+ - XML Structure validation
11
+ - Coherence heuristics
12
+ - Curriculum learning support
13
+ """
14
+
15
+ import re
16
+ import json
17
+ from dataclasses import dataclass
18
+ from datetime import datetime, timezone
19
+ from typing import TYPE_CHECKING, Literal
20
+
21
+ from ..models import (
22
+ BabylonTrajectory,
23
+ TrajectoryStep,
24
+ Action,
25
+ EnvironmentState,
26
+ )
27
+
28
+ if TYPE_CHECKING:
29
+ from .rollout_generator import AgentTickData
30
+
31
# Per-archetype weighting of the four tick-quality components.
# Each entry maps component name -> weight; the four weights in every
# profile sum to 1.0 so the combined tick score stays in [0, 1].
ARCHETYPE_WEIGHTS: dict[str, dict[str, float]] = {
    # Research-heavy archetypes: reasoning quality dominates.
    "researcher": {"llm_calls": 0.3, "reasoning": 0.45, "action": 0.15, "feedback": 0.1},
    "information-trader": {"llm_calls": 0.3, "reasoning": 0.4, "action": 0.2, "feedback": 0.1},
    "super-predictor": {"llm_calls": 0.3, "reasoning": 0.4, "action": 0.2, "feedback": 0.1},
    # Action-heavy archetypes: execution dominates.
    "trader": {"llm_calls": 0.3, "reasoning": 0.2, "action": 0.4, "feedback": 0.1},
    "degen": {"llm_calls": 0.2, "reasoning": 0.15, "action": 0.55, "feedback": 0.1},
    "perps-trader": {"llm_calls": 0.25, "reasoning": 0.2, "action": 0.45, "feedback": 0.1},
    # Social archetypes: engagement / response quality dominates.
    "social-butterfly": {"llm_calls": 0.35, "reasoning": 0.25, "action": 0.25, "feedback": 0.15},
    "ass-kisser": {"llm_calls": 0.35, "reasoning": 0.3, "action": 0.2, "feedback": 0.15},
    "goody-twoshoes": {"llm_calls": 0.35, "reasoning": 0.3, "action": 0.2, "feedback": 0.15},
    # Deceptive archetypes: reasoning (planning the deception) dominates.
    "scammer": {"llm_calls": 0.25, "reasoning": 0.4, "action": 0.25, "feedback": 0.1},
    "liar": {"llm_calls": 0.25, "reasoning": 0.4, "action": 0.25, "feedback": 0.1},
    # Balanced profile.
    "infosec": {"llm_calls": 0.3, "reasoning": 0.3, "action": 0.3, "feedback": 0.1},
    # Fallback used when no archetype (or an unknown one) is supplied.
    "default": {"llm_calls": 0.4, "reasoning": 0.3, "action": 0.2, "feedback": 0.1},
}
58
+
59
+
60
+ def validate_xml_structure(response: str) -> float:
61
+ """
62
+ Validate that the response contains valid decision XML tags.
63
+
64
+ Criteria:
65
+ 1. Must contain <decisions> and </decisions> tags.
66
+ 2. Must contain at least one <decision> tag.
67
+ 3. Attributes 'amount' and 'ticker' (or 'marketId') should be present.
68
+
69
+ Returns:
70
+ +0.5 for valid syntax and attributes
71
+ -1.0 for broken XML or missing tags
72
+ -0.2 for missing attributes in otherwise valid tags
73
+ """
74
+ if not response:
75
+ return -1.0
76
+
77
+ # Check for wrapping tags
78
+ if "<decisions>" not in response or "</decisions>" not in response:
79
+ return -1.0
80
+
81
+ # Check for inner tags
82
+ if "<decision" not in response:
83
+ return -0.5 # Has wrappers but no decision?
84
+
85
+ # Check for critical attributes (simple heuristic regex to handle both quote styles)
86
+ has_ticker = re.search(
87
+ r'ticker="[^"]+"', response) or re.search(r"ticker='[^']+'", response)
88
+ has_market = re.search(
89
+ r'marketId="[^"]+"', response) or re.search(r"marketId='[^']+'", response)
90
+ has_amount = re.search(
91
+ r'amount="[^"]+"', response) or re.search(r"amount='[^']+'", response)
92
+
93
+ # Need either ticker OR marketId, AND amount
94
+ if (not has_ticker and not has_market) or not has_amount:
95
+ return -0.2 # Penalty for partial hallucination / missing args
96
+
97
+ return 0.5
98
+
99
+
100
def check_reasoning_action_alignment(
    reasoning_text: str,
    action: Action | None,
) -> float:
    """
    Score how well the written reasoning matches the action taken.

    Two components are combined:
      1. Directional alignment -- bullish language should accompany buys,
         bearish language sells, and hesitation words a wait/hold.
      2. Financial-literacy bonus -- mentioning exposure or PnL/profit/loss
         adds up to +0.3.

    Returns:
        Score in [0.0, 1.0]; 0.5 is neutral when nothing can be checked.
    """
    if not action or not reasoning_text:
        return 0.5  # Nothing to compare against.

    text = reasoning_text.lower()
    kind = action.action_type.lower()

    # --- Financial-literacy bonus: reward references to portfolio state ---
    bonus = 0.0
    if "exposure" in text:
        bonus += 0.15
    if "pnl" in text or "profit" in text or "loss" in text:
        bonus += 0.15

    # --- Directional alignment (deliberately loose substring matching) ---
    def _hits(words: tuple[str, ...]) -> int:
        return sum(w in text for w in words)

    bullish = _hits(("bullish", "buy", "long",
                     "upward", "positive", "opportunity", "moon"))
    bearish = _hits(("bearish", "sell", "short",
                     "downward", "negative", "avoid", "dump"))
    hesitant = _hits(("wait", "hold", "unclear",
                      "uncertain", "need more data", "observing"))

    base = 0.5
    if kind in ("buy", "buy_prediction", "open_perp", "long"):
        if bullish > bearish:
            base = 0.7  # Reasoning supports the trade.
        elif bearish > bullish:
            base = 0.0  # Contradiction: hallucination penalty.
        else:
            base = 0.4
    elif kind in ("sell", "sell_prediction", "close_perp", "short"):
        if bearish > bullish:
            base = 0.7  # Reasoning supports the trade.
        elif bullish > bearish:
            base = 0.0  # Contradiction: hallucination penalty.
        else:
            base = 0.4
    elif kind in ("wait", "hold"):
        base = 0.7 if hesitant > 0 else 0.5

    # Literacy bonus is additive, capped at a perfect 1.0.
    return min(1.0, base + bonus)
170
+
171
+
172
def check_reasoning_coherence(reasoning_text: str) -> float:
    """
    Heuristic coherence score for a block of reasoning text (0-1).

    Rewards list structure, explicit conclusion markers, a sensible sentence
    count, vocabulary diversity, and quantitative references; penalizes
    repetitive text. Empty or very short text scores 0.1.
    """
    if not reasoning_text or len(reasoning_text) < 20:
        return 0.1

    score = 0.0
    text = reasoning_text

    # Structure: numbered items or bullet markers at the start of a line.
    # Anchored so a hyphen/asterisk/digit in running prose (e.g. "well-known",
    # "5 units") does not count; the previous pattern matched those anywhere,
    # handing the structure bonus to nearly all text and double-counting
    # digits with the numeric check below.
    if re.search(r'(?m)^\s*(?:\d+[.):]|[-*•])\s?', text):
        score += 0.25

    # Conclusion markers signal the reasoning actually resolves to a decision.
    conclusion_markers = [
        "therefore", "conclusion", "decision", "recommend",
        "suggest", "final", "result", "action:", "execute",
    ]
    if any(marker in text.lower() for marker in conclusion_markers):
        score += 0.25

    # Sentence count: 2-10 sentences is ideal; longer is merely verbose.
    sentences = text.split('. ')
    if 2 <= len(sentences) <= 10:
        score += 0.2
    elif len(sentences) > 10:
        score += 0.1  # Too verbose

    # Vocabulary diversity: repetitive output is a bad-quality indicator.
    words = text.lower().split()
    if len(words) > 10:
        unique_ratio = len(set(words)) / len(words)
        if unique_ratio > 0.4:
            score += 0.15  # Good vocabulary diversity
        else:
            score -= 0.1  # Repetitive
    else:
        score += 0.1

    # Quantitative analysis: prices, percentages, magnitudes.
    if re.search(r'\$?\d+(?:\.\d+)?(?:%|k|K|M)?', text):
        score += 0.15

    return min(max(score, 0.0), 1.0)
217
+
218
+
219
def calculate_detailed_tick_quality(
    llm_calls: list,
    action: Action | None,
    feedback: dict | None,
    archetype: str | None = None,
) -> tuple[float, float]:
    """
    Compute detailed quality scores for one tick.

    Returns:
        (format_score, reasoning_score) -- format_score is the XML validation
        result for the final LLM response (-1.0 .. 0.5); reasoning_score
        combines directional alignment with a small coherence boost, capped
        at 1.0.

    Note: `feedback` and `archetype` are accepted for interface compatibility
    but are not used in this computation.
    """
    format_score = 0.0
    reasoning_score = 0.0

    # 1. Format score: validate only the most recent response, if any.
    if llm_calls:
        final_response = llm_calls[-1].response
        if final_response:
            format_score = validate_xml_structure(final_response)

    # 2. Reasoning score: pool every piece of reasoning text for this tick.
    fragments: list[str] = []
    for call in llm_calls:
        if call.reasoning:
            fragments.append(call.reasoning)
        if call.response:
            fragments.append(call.response)
    if action and action.reasoning:
        fragments.append(action.reasoning)

    combined = " ".join(fragments)
    if combined:
        reasoning_score = check_reasoning_action_alignment(combined, action)
        # Small additive boost for coherent structure.
        reasoning_score += check_reasoning_coherence(combined) * 0.2

    return format_score, min(1.0, reasoning_score)
258
+
259
+
260
def calculate_tick_quality_score(
    llm_calls: list,
    action: Action | None,
    feedback: dict | None,
    archetype: str | None = None,
) -> float:
    """
    Single-float quality score for one tick (0-1).

    Legacy wrapper kept for API compatibility: combines the detailed
    format/reasoning scores with action success and feedback presence,
    weighted by the archetype's profile (falling back to "default").
    """
    weights = ARCHETYPE_WEIGHTS.get(
        archetype or "default", ARCHETYPE_WEIGHTS["default"])

    # Detailed component scores.
    fmt, rsn = calculate_detailed_tick_quality(
        llm_calls, action, feedback, archetype)

    # Action component: full credit on success, partial on explicit error,
    # middling when the outcome is unknown, zero without an action.
    if action:
        if action.success:
            action_score = 1.0
        elif action.error:
            action_score = 0.25
        else:
            action_score = 0.5
    else:
        action_score = 0.0

    # Feedback component: presence is all we credit.
    feedback_score = 1.0 if feedback else 0.0

    # Map the format score's native range (-1.0 .. 0.5) onto 0 .. 1 so it can
    # participate in the legacy weighted sum (0.5 -> 1.0, -1.0 -> 0.0).
    normalized_format = (fmt + 1.0) / 1.5

    total = (
        normalized_format * weights["llm_calls"]
        + rsn * weights["reasoning"]
        + action_score * weights["action"]
        + feedback_score * weights["feedback"]
    )
    return min(1.0, max(0.0, total))
304
+
305
+
306
# Difficulty buckets used by curriculum learning.
CurriculumLevel = Literal["easy", "medium", "hard"]


@dataclass
class TrajectoryDifficulty:
    """Difficulty assessment of a trajectory for curriculum learning."""

    level: CurriculumLevel  # bucketed difficulty label
    score: float  # 0-1, higher = harder
    reasons: list[str]  # human-readable factors behind the score
315
+
316
+
317
def calculate_trajectory_quality_score(
    ticks: list["AgentTickData"],
    archetype: str | None = None,
) -> float:
    """
    Mean per-tick quality score across a whole trajectory (0-1).

    Args:
        ticks: Recorded tick data for the trajectory.
        archetype: Optional agent archetype used to pick scoring weights.

    Returns:
        Average tick score, or 0.0 for an empty trajectory.
    """
    if not ticks:
        return 0.0

    total = 0.0
    for tick in ticks:
        total += calculate_tick_quality_score(
            tick.llm_calls,
            tick.action,
            tick.feedback,
            archetype=archetype,
        )
    return total / len(ticks)
342
+
343
+
344
def assess_trajectory_difficulty(
    ticks: list["AgentTickData"],
) -> TrajectoryDifficulty:
    """
    Grade a trajectory's difficulty for curriculum learning.

    Five additive factors, each worth up to 0.2: trajectory length, action
    diversity, complex action parameters (leverage / large sizes), direction
    reversals, and average reasoning depth. The total is clamped to [0, 1]
    and bucketed: easy (< 0.3), medium (< 0.6), hard (>= 0.6).
    """
    if not ticks:
        return TrajectoryDifficulty(level="easy", score=0.0, reasons=["Empty trajectory"])

    reasons: list[str] = []
    total = 0.0
    n = len(ticks)

    # Factor 1: longer trajectories demand sustained behavior.
    if n > 20:
        total += 0.2
        reasons.append(f"Long trajectory ({n} ticks)")
    elif n > 10:
        total += 0.1

    # Factor 2: variety of action types attempted.
    kinds = {t.action.action_type for t in ticks if t.action}
    if len(kinds) >= 4:
        total += 0.2
        reasons.append(f"High action diversity ({len(kinds)} types)")
    elif len(kinds) >= 2:
        total += 0.1

    # Factor 3: leveraged or large-size actions.
    complex_actions = 0
    for tick in ticks:
        if not (tick.action and tick.action.parameters):
            continue
        params = tick.action.parameters
        # str() round-trip tolerates non-numeric parameter values
        # (and keeps static type-checkers happy).
        try:
            if float(str(params.get("leverage", 1))) > 1:
                complex_actions += 1
        except (ValueError, TypeError):
            pass
        try:
            if float(str(params.get("amount", 0))) > 1000:
                complex_actions += 1
        except (ValueError, TypeError):
            pass
    if complex_actions >= 3:
        total += 0.2
        reasons.append(f"Complex action parameters ({complex_actions})")
    elif complex_actions >= 1:
        total += 0.1

    # Factor 4: direction reversals (buy/long followed by sell/short or
    # vice versa on consecutive action ticks).
    longs = ("buy", "long")
    shorts = ("sell", "short")
    reversals = 0
    previous = None
    for tick in ticks:
        if not tick.action:
            continue
        current = tick.action.action_type
        if previous and (
            (previous in longs and current in shorts)
            or (previous in shorts and current in longs)
        ):
            reversals += 1
        previous = current
    if reversals >= 2:
        total += 0.2
        reasons.append(f"Multiple reversals ({reversals})")
    elif reversals >= 1:
        total += 0.1

    # Factor 5: how much written reasoning the trajectory carries per tick.
    reasoning_chars = sum(
        sum(len(c.reasoning or "") for c in tick.llm_calls)
        + len((tick.action.reasoning or "") if tick.action else "")
        for tick in ticks
    )
    avg_reasoning = reasoning_chars / n
    if avg_reasoning > 200:
        total += 0.2
        reasons.append(
            f"Deep reasoning required (avg {avg_reasoning:.0f} chars)")
    elif avg_reasoning > 100:
        total += 0.1

    # Clamp and bucket.
    total = min(total, 1.0)
    if total >= 0.6:
        level: CurriculumLevel = "hard"
    elif total >= 0.3:
        level = "medium"
    else:
        level = "easy"

    return TrajectoryDifficulty(
        level=level,
        score=total,
        reasons=reasons if reasons else ["Standard complexity"],
    )
455
+
456
+
457
def build_trajectory_from_ticks(
    trajectory_id: str,
    agent_id: str,
    ticks: list["AgentTickData"],
    min_steps: int = 1,
) -> BabylonTrajectory | None:
    """
    Assemble a BabylonTrajectory from per-tick records.

    Args:
        trajectory_id: Unique trajectory ID.
        agent_id: Agent ID.
        ticks: Ordered AgentTickData (timestamps in milliseconds).
        min_steps: Minimum steps required (returns None if fewer).

    Returns:
        BabylonTrajectory, or None when there is insufficient data.
    """
    # Guard empty input explicitly: with min_steps <= 0 an empty list would
    # previously pass the length check and crash on ticks[0] below.
    if not ticks or len(ticks) < min_steps:
        return None

    steps = []
    for tick in ticks:
        step = TrajectoryStep(
            step_number=tick.tick_number,
            timestamp=tick.timestamp,
            environment_state=tick.environment_state,
            provider_accesses=[],
            llm_calls=tick.llm_calls,
            # Missing actions are recorded as a successful no-op "wait".
            action=tick.action or Action(
                action_type="wait",
                parameters={},
                success=True,
            ),
            reward=tick.reward,
        )
        steps.append(step)

    # Final account metrics come from the last tick's environment snapshot
    # (ticks is guaranteed non-empty here).
    final_pnl = ticks[-1].environment_state.agent_pnl
    final_balance = ticks[-1].environment_state.agent_balance
    total_reward = sum(t.reward for t in ticks)

    # Summary counts for trades and social posts.
    trade_types = {
        "buy", "sell", "buy_prediction", "sell_prediction",
        "open_perp", "close_perp",
    }
    trades_executed = sum(
        1 for t in ticks if t.action and t.action.action_type in trade_types
    )
    posts_created = sum(
        1 for t in ticks
        if t.action and t.action.action_type in ("create_post", "post")
    )

    now = datetime.now(timezone.utc)

    return BabylonTrajectory(
        id=trajectory_id,
        trajectory_id=trajectory_id,
        agent_id=agent_id,
        # Hour-resolution window bucket for grouping trajectories.
        window_id=now.strftime("%Y-%m-%dT%H:00"),
        # Tick timestamps are epoch milliseconds -> convert to aware datetimes.
        start_time=datetime.fromtimestamp(
            ticks[0].timestamp / 1000, tz=timezone.utc),
        end_time=datetime.fromtimestamp(
            ticks[-1].timestamp / 1000, tz=timezone.utc),
        duration_ms=ticks[-1].timestamp - ticks[0].timestamp,
        steps=steps,
        total_reward=total_reward,
        final_pnl=final_pnl,
        final_balance=final_balance,
        trades_executed=trades_executed,
        posts_created=posts_created,
        episode_length=len(steps),
        final_status="completed",
    )
534
+
535
+
536
def state_to_observation(game_state: dict) -> dict:
    """
    Project a full game state down to the agent-visible observation.

    News and social feed are truncated (5 and 10 items respectively) to
    keep the observation small and fast to process.
    """
    news = game_state.get("news", [])
    feed = game_state.get("socialFeed", [])
    return {
        "tick": game_state.get("tick", 0),
        "time": game_state.get("currentTime", 0),
        "markets": game_state.get("predictionMarkets", []),
        "perpetuals": game_state.get("perpetualMarkets", []),
        "news": news[:5],
        "posts": feed[:10],
    }
546
+
547
+
548
def state_to_env_state(game_state: dict, agent_id: str) -> EnvironmentState:
    """
    Build an EnvironmentState snapshot for one agent from the game state.

    Falls back to a 10,000 starting balance and zero PnL when the agent has
    no portfolio entry.
    """
    # Locate this agent's portfolio (first match wins; empty dict if absent).
    portfolio: dict = next(
        (p for p in game_state.get("portfolios", []) if p.get("agentId") == agent_id),
        {},
    )

    # NOTE(review): "agentPnL" is camelCase while the rest of this module
    # reads `agent_pnl` -- presumably an accepted field alias on the model;
    # confirm against the EnvironmentState definition.
    return EnvironmentState(
        agent_balance=portfolio.get("balance", 10000.0),
        agentPnL=portfolio.get("pnl", 0.0),
        open_positions=portfolio.get("positionCount", portfolio.get("positions", 0)),
        active_markets=len(game_state.get("predictionMarkets", [])),
    )
564
+
565
+
566
@dataclass
class ValidationResult:
    """Outcome of rollout/trajectory validation."""

    is_valid: bool  # True only when no issues were recorded
    issues: list[str]  # human-readable problems; empty when valid
    quality_score: float  # aggregate quality score (0-1)

    @property
    def issue_count(self) -> int:
        """Number of issues detected."""
        return len(self.issues)
576
+
577
+
578
def validate_trajectory_quality(
    ticks: list["AgentTickData"],
    min_ticks: int = 5,
    min_llm_calls_per_tick: float = 0.8,
    min_quality_score: float = 0.5,
) -> ValidationResult:
    """
    Check whether a trajectory is good enough to train on.

    Args:
        ticks: Recorded tick data.
        min_ticks: Minimum number of ticks required.
        min_llm_calls_per_tick: Minimum fraction of ticks that must carry
            LLM calls (default 0.8 = 80%).
        min_quality_score: Minimum aggregate quality score threshold.

    Returns:
        ValidationResult; is_valid is True only when no issue was recorded.
    """
    issues: list[str] = []

    # Tick count requirement.
    if len(ticks) < min_ticks:
        issues.append(f"Too few ticks: {len(ticks)} < {min_ticks}")

    # Nothing else can be measured on an empty trajectory.
    if not ticks:
        return ValidationResult(is_valid=False, issues=issues, quality_score=0.0)

    # LLM coverage: fraction of ticks that produced at least one call.
    covered = sum(1 for t in ticks if t.llm_calls)
    coverage = covered / len(ticks)
    if coverage < min_llm_calls_per_tick:
        issues.append(
            f"Low LLM call coverage: {coverage:.1%} < {min_llm_calls_per_tick:.1%}")

    # Calls missing a prompt or a response are unusable for training.
    empty_calls = sum(
        1
        for tick in ticks
        for call in tick.llm_calls
        if not call.user_prompt or not call.response
    )
    if empty_calls > 0:
        issues.append(f"{empty_calls} LLM calls with empty prompt/response")

    # Aggregate quality threshold.
    quality_score = calculate_trajectory_quality_score(ticks)
    if quality_score < min_quality_score:
        issues.append(
            f"Quality score too low: {quality_score:.2f} < {min_quality_score}")

    return ValidationResult(
        is_valid=not issues,
        issues=issues,
        quality_score=quality_score,
    )