@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,29 @@
1
+
2
import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
import argparse


def main():
    """Debug helper: load an MLX model (plus optional LoRA adapter) and print
    two independent generations for a fixed shouldRespond prompt, so output
    variance at the chosen temperature can be eyeballed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
    parser.add_argument("--adapter-path", type=str, default="trained_models/should_respond_sft/adapters")
    parser.add_argument("--temp", type=float, default=1.0)
    args = parser.parse_args()

    print(f"Loading {args.model} with {args.adapter_path}")
    model, tokenizer = load(args.model, adapter_path=args.adapter_path)

    prompt = "<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>\n\n<providers>\n[RECENT_MESSAGES]\nUser: I heard Eliza is helping\n</providers>\n\n<instructions>Decide if Eliza should respond to or interact with the conversation.\n\nIMPORTANT RULES FOR RESPONDING:\n- If YOUR name (Eliza) is directly mentioned → RESPOND\n- If someone uses a DIFFERENT name (not Eliza) → IGNORE (they're talking to someone else)\n- If you're actively participating in a conversation and the message continues that thread → RESPOND\n- If someone tells you to stop or be quiet → STOP\n- Otherwise → IGNORE\n\nThe key distinction is:\n- \"Talking TO Eliza\" (your name mentioned, replies to you, continuing your conversation) → RESPOND\n- \"Talking ABOUT Eliza\" or to someone else → IGNORE\n</instructions>\n\n<output>\nDo NOT include any thinking, reasoning, or <think> sections in your response.\nGo directly to the XML response format without any preamble or explanation.\n\nRespond using XML format like this:\n<response>\n <name>Eliza</name>\n <reasoning>Your reasoning here</reasoning>\n <action>RESPOND | IGNORE | STOP</action>\n</response>\n\nIMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.\n</output>"

    # The two generations were previously duplicated inline; a loop keeps the
    # behavior (fresh sampler per generation, identical headers) without the
    # copy-paste.
    for gen_idx in (1, 2):
        print("\n--- Gen {} (temp={}) ---".format(gen_idx, args.temp))
        sampler = make_sampler(temp=args.temp)
        print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler))


if __name__ == "__main__":
    main()
@@ -0,0 +1,155 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test The Judge (PR #4)
4
+
5
+ Loads trajectories and evaluates them using the new reward functions.
6
+ Verifies:
7
+ 1. Financial Rewards (PnL, Risk)
8
+ 2. Format Rewards (XML validation)
9
+ 3. Reasoning Alignment (Financial Literacy)
10
+ """
11
+
12
+ import sys
13
+ import logging
14
+ from pathlib import Path
15
+
16
+ # Add python directory to path
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+
19
+ from src.data_bridge.reader import JsonTrajectoryReader
20
+ from src.models import BabylonTrajectory
21
+ from src.training.rewards import (
22
+ TrajectoryRewardInputs,
23
+ composite_reward,
24
+ calculate_pnl_reward,
25
+ calculate_risk_reward
26
+ )
27
+ from src.training.quality_utils import (
28
+ calculate_detailed_tick_quality,
29
+ validate_xml_structure
30
+ )
31
+
32
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
33
+ logger = logging.getLogger("TheJudge")
34
+
35
+
36
def evaluate_trajectory(traj: BabylonTrajectory):
    """Judge one trajectory: financial, format, reasoning and risk rewards,
    then print a composite pass/fail verdict."""
    print(f"\n--- Judging Trajectory: {traj.trajectory_id} ---")

    # Financials — starting balance is assumed fixed at $10k for this test;
    # the ending balance is derived from the trajectory's final PnL.
    starting_balance = 10000.0
    ending_balance = starting_balance + traj.final_pnl

    pnl_score = calculate_pnl_reward(starting_balance, ending_balance)
    print(f"💰 Financials: PnL ${traj.final_pnl:.2f} -> Score: {pnl_score:.2f}")

    # Per-step accumulators for quality scores and risky-action strikes.
    fmt_sum = 0.0
    reasoning_sum = 0.0
    dangerous_count = 0
    scored_steps = 0

    for idx, step in enumerate(traj.steps):
        # Steps with no LLM calls carry nothing to judge.
        if not step.llm_calls:
            continue

        scored_steps += 1

        # Format / reasoning quality for this tick.
        fmt, rsn = calculate_detailed_tick_quality(
            step.llm_calls,
            step.action,
            None,  # Feedback
            "default"
        )

        # Rough exposure proxy: 10% per open position, capped at 1.0.
        exposure = min(1.0, step.environment_state.open_positions * 0.1)
        action_type = step.action.action_type if step.action else "wait"
        if calculate_risk_reward(exposure, action_type) < 0:
            dangerous_count += 1

        fmt_sum += fmt
        reasoning_sum += rsn

        # Surface notable steps: invalid XML or strong reasoning.
        if fmt < 0:
            print(f" ⚠️ Step {idx} Bad XML: {fmt}")
        if rsn > 0.6:
            print(f" ✨ Step {idx} Good Reasoning: {rsn:.2f}")

    # Averages (denominator clamped to 1 so empty trajectories score 0).
    denom = max(1, scored_steps)
    avg_format = fmt_sum / denom
    avg_reasoning = reasoning_sum / denom

    print(
        f"📝 Quality: Avg XML {avg_format:.2f} | Avg Reasoning {avg_reasoning:.2f}")
    if dangerous_count > 0:
        print(f"🚨 Risk: {dangerous_count} dangerous actions detected")

    # Composite verdict combining every reward channel.
    final_score = composite_reward(TrajectoryRewardInputs(
        final_pnl=traj.final_pnl,
        starting_balance=starting_balance,
        end_balance=ending_balance,
        format_score=avg_format,
        reasoning_score=avg_reasoning,
        risky_actions_count=dangerous_count
    ))

    verdict = "✅ PASSED" if final_score > 0 else "❌ FAILED"
    print(f"⚖️ FINAL SCORE: {final_score:.4f} ({verdict})")
+
110
+
111
def main():
    """Locate exported trajectory JSON files and judge up to five of them.

    Searches the training package output directory first, then falls back
    to the engine output directory; exits with status 1 when neither
    location exists.
    """
    import json  # hoisted out of the inner loop: importing per record is wasteful

    # Look for trajectory data in the training package output directory.
    # The primary path is named once so the error message below cannot
    # drift out of sync with the lookup.
    primary_dir = Path(__file__).parent.parent.parent / "training-data-output" / "trajectories"
    source_dir = primary_dir
    if not source_dir.exists():
        # Fallback to engine output if training output doesn't exist
        source_dir = Path(__file__).parent.parent.parent.parent / "engine" / "training-data-output" / "trajectories"

    # Validate that at least one path exists
    if not source_dir.exists():
        logger.error("No trajectory data found. Checked paths:")
        logger.error(f" - {primary_dir}")
        logger.error(f" - {source_dir}")
        logger.error("Run 'make tier4-generate' or 'bun run packages/engine/examples/generate-training-data.ts' first.")
        sys.exit(1)

    source_dir = str(source_dir)
    try:
        reader = JsonTrajectoryReader(source_dir)
        window_ids = reader.get_window_ids()

        count = 0
        for window_id in window_ids:
            raw_trajs = reader.get_trajectories_by_window(window_id)
            for raw in raw_trajs:
                # Some exports wrap the record in a {"trajectory": ...} envelope.
                if 'trajectory' in raw:
                    raw = raw['trajectory']
                # stepsJson may arrive as a serialized string; decode it in place.
                if isinstance(raw.get('stepsJson'), str):
                    raw['steps'] = json.loads(raw['stepsJson'])

                try:
                    traj = BabylonTrajectory.model_validate(raw)
                    evaluate_trajectory(traj)
                    count += 1
                    if count >= 5:
                        return  # Just test 5 for now
                except Exception as e:
                    # Malformed records are expected in raw exports; skip them.
                    print(f"Skipping invalid: {e}")

    except Exception as e:
        logger.error(f"Error: {e}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,356 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ ElizaOS Training Pipeline - End-to-End Test
4
+
5
+ This script validates the complete training pipeline:
6
+ 1. Database connectivity
7
+ 2. Real trajectory data loading
8
+ 3. Data conversion to training format
9
+ 4. Backend availability (MLX/CUDA/CPU)
10
+
11
+ Run this BEFORE training to verify everything is set up correctly.
12
+
13
+ Usage:
14
+ python scripts/test_pipeline.py
15
+ """
16
+
17
+ import asyncio
18
+ import logging
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ # Add src to path
24
+ sys.path.insert(0, str(Path(__file__).parent.parent))
25
+
26
+ from dotenv import load_dotenv
27
+
28
+ # Load environment
29
+ env_path = Path(__file__).parent.parent.parent.parent.parent / ".env"
30
+ if env_path.exists():
31
+ load_dotenv(env_path)
32
+
33
+ logging.basicConfig(
34
+ level=logging.INFO,
35
+ format='%(asctime)s [%(levelname)s] %(message)s'
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class TestResult:
    """Outcome of a single pipeline check.

    Attributes:
        name: Human-readable check name shown in the report.
        passed: Whether the check succeeded (defaults to False).
        message: Short status/diagnostic text.
        details: Optional structured data about the check.
    """

    def __init__(self, name: str):
        self.name = name
        self.passed = False
        self.message = ""
        self.details: dict = {}

    def __repr__(self) -> str:
        # Useful when dumping raw results while debugging the runner.
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"passed={self.passed}, message={self.message!r})")
+
47
+
48
async def test_database_connection() -> TestResult:
    """Test database connectivity.

    Opens a small asyncpg pool against DATABASE_URL, runs one COUNT query
    on the trajectories table, and reports the row count.

    Returns:
        TestResult with passed=True and the trajectory count on success.
    """
    result = TestResult("Database Connection")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        import asyncpg
        pool = await asyncpg.create_pool(database_url, min_size=1, max_size=2)
        try:
            # Test query
            async with pool.acquire() as conn:
                count = await conn.fetchval("SELECT COUNT(*) FROM trajectories")
        finally:
            # Close the pool even if the query raises, so a failed check
            # does not leak open connections.
            await pool.close()

        result.passed = True
        result.message = f"Connected. Found {count} trajectories"
        result.details["trajectory_count"] = count

    except Exception as e:
        result.message = f"Connection failed: {e}"

    return result
+
76
+
77
async def test_trajectory_data() -> TestResult:
    """Test that real trajectory data exists."""
    result = TestResult("Real Trajectory Data")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        from src.data_bridge import PostgresTrajectoryReader

        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(min_agents=1, lookback_hours=168)

            if not windows:
                result.message = "No trajectory windows found"
                return result

            # Sampling the first window is enough to prove data exists.
            trajectories = await reader.get_trajectories_by_window(
                windows[0], min_actions=1
            )

            # Tally trajectories carrying at least one LLM call, plus the
            # grand total of calls across all steps.
            with_llm_calls = 0
            total_llm_calls = 0
            for traj in trajectories:
                calls_in_traj = sum(
                    len(step.llm_calls) for step in traj.steps if step.llm_calls
                )
                total_llm_calls += calls_in_traj
                if calls_in_traj:
                    with_llm_calls += 1

            result.passed = with_llm_calls > 0
            result.message = (
                f"Found {len(windows)} windows, "
                f"{len(trajectories)} trajectories in first window, "
                f"{with_llm_calls} have LLM calls ({total_llm_calls} total calls)"
            )
            result.details = {
                "windows": len(windows),
                "trajectories": len(trajectories),
                "with_llm_calls": with_llm_calls,
                "total_llm_calls": total_llm_calls,
            }

    except Exception as e:
        result.message = f"Failed: {e}"
        import traceback
        traceback.print_exc()

    return result
+
134
+
135
async def test_data_conversion() -> TestResult:
    """Test conversion of trajectories to training samples."""
    result = TestResult("Data Conversion")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        from src.data_bridge import PostgresTrajectoryReader

        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(min_agents=1, lookback_hours=168)

            if not windows:
                result.message = "No windows found"
                return result

            trajectories = await reader.get_trajectories_by_window(
                windows[0], min_actions=1
            )

            # Flatten every usable LLM call into a chat-format sample.
            samples = []
            for traj in trajectories:
                for step in traj.steps:
                    for llm_call in (step.llm_calls or []):
                        response = llm_call.response
                        # Skip empty or trivially short completions.
                        if not response or len(response) < 20:
                            continue

                        messages = []
                        if llm_call.system_prompt:
                            messages.append(
                                {"role": "system", "content": llm_call.system_prompt})
                        if llm_call.user_prompt:
                            messages.append(
                                {"role": "user", "content": llm_call.user_prompt})
                        messages.append({"role": "assistant", "content": response})

                        # Require at least one prompt message plus the reply.
                        if len(messages) >= 2:
                            samples.append({"messages": messages})

            result.passed = len(samples) >= 10
            result.message = f"Created {len(samples)} training samples"
            result.details["samples"] = len(samples)

            if samples:
                # Record a shape preview of the first sample for the report.
                first = samples[0]["messages"]
                result.details["sample_preview"] = {
                    "roles": [m["role"] for m in first],
                    "lengths": [len(m["content"]) for m in first],
                }

    except Exception as e:
        result.message = f"Failed: {e}"
        import traceback
        traceback.print_exc()

    return result
+
198
+
199
def test_mlx_backend() -> TestResult:
    """Test MLX backend availability."""
    result = TestResult("MLX Backend")

    try:
        import mlx.core as mx
        import mlx_lm
    except ImportError as e:
        # MLX is Apple-silicon only; absence is a normal outcome elsewhere.
        result.message = f"MLX not available: {e}"
    else:
        result.passed = True
        result.message = f"MLX available (mlx-lm version: {mlx_lm.__version__})"

    return result
+
215
+
216
def test_cuda_backend() -> TestResult:
    """Test CUDA backend availability."""
    result = TestResult("CUDA Backend")

    # Guard clauses: missing torch, then torch without CUDA, then success.
    try:
        import torch
    except ImportError as e:
        result.message = f"PyTorch not installed: {e}"
        return result

    if not torch.cuda.is_available():
        result.message = "PyTorch installed but CUDA not available"
        return result

    device_name = torch.cuda.get_device_name(0)
    vram = torch.cuda.get_device_properties(0).total_memory / 1e9

    result.passed = True
    result.message = f"CUDA available: {device_name} ({vram:.1f} GB)"
    result.details = {
        "device": device_name,
        "vram_gb": vram,
    }
    return result
+
241
+
242
def test_transformers() -> TestResult:
    """Test transformers library."""
    result = TestResult("Transformers Library")

    try:
        import transformers
    except ImportError as e:
        result.message = f"Not installed: {e}"
    else:
        # Report the installed version so mismatches are easy to spot.
        result.passed = True
        result.message = f"transformers {transformers.__version__}"

    return result
+
257
+
258
def test_environment_variables() -> TestResult:
    """Test required environment variables."""
    result = TestResult("Environment Variables")

    required = ["DATABASE_URL"]
    optional = ["OPENAI_API_KEY", "TINKER_API_KEY"]

    # Presence map for every variable we care about.
    checks = {name: bool(os.getenv(name)) for name in required + optional}

    missing_required = [name for name in required if not checks[name]]
    missing_optional = [name for name in optional if not checks[name]]

    # Only the required list gates the result; optional gaps are reported.
    result.passed = not missing_required
    if missing_required:
        result.message = f"Missing required: {', '.join(missing_required)}"
    else:
        result.message = f"Required vars set. Optional missing: {', '.join(missing_optional) or 'none'}"

    result.details = checks
    return result
+
284
+
285
async def main():
    """Run all pipeline checks and print a pass/fail report.

    Returns:
        0 when every required check passed, 1 otherwise (used as the
        process exit code by the __main__ guard).
    """
    print("=" * 70)
    print(" ELIZAOS TRAINING PIPELINE - END-TO-END TEST")
    print("=" * 70)
    print()

    # Run tests. Each TestResult already carries its display name, so a flat
    # list suffices — the previous (name, result) tuples duplicated data that
    # was never read.
    results = [
        test_environment_variables(),
        await test_database_connection(),
        await test_trajectory_data(),
        await test_data_conversion(),
        test_transformers(),
        test_mlx_backend(),
        test_cuda_backend(),
    ]

    passed = 0
    failed = 0

    for result in results:
        status = "✅" if result.passed else "❌"
        print(f"{status} {result.name}")
        print(f" {result.message}")
        if result.details:
            for k, v in result.details.items():
                # sample_preview is nested and noisy; omit it from the list.
                if k != "sample_preview":
                    print(f" - {k}: {v}")
        print()

        if result.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    print("=" * 70)
    print(f" RESULTS: {passed} passed, {failed} failed")
    print("=" * 70)

    # Only these checks gate training; the backend checks are informational.
    required_tests = [
        "Environment Variables",
        "Database Connection",
        "Real Trajectory Data",
        "Data Conversion",
    ]

    required_passed = all(
        result.passed for result in results
        if result.name in required_tests
    )

    if required_passed:
        print()
        print("✅ All required checks passed!")
        print()
        print("Ready to train. Run:")
        print(" python scripts/train_local.py")
        print()
        return 0
    else:
        print()
        print("❌ Some required checks failed. Fix issues before training.")
        print()
        return 1


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))