@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
|
|
2
|
+
import mlx.core as mx
|
|
3
|
+
from mlx_lm import load, generate
|
|
4
|
+
import argparse
|
|
5
|
+
|
|
6
|
+
def main():
    """Smoke-test generation from a base model plus a LoRA adapter.

    Loads the model/adapter pair, then samples several generations from a
    fixed shouldRespond-style prompt so run-to-run variation at the chosen
    temperature can be eyeballed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="mlx-community/Qwen2.5-1.5B-Instruct-4bit")
    parser.add_argument("--adapter-path", type=str, default="trained_models/should_respond_sft/adapters")
    parser.add_argument("--temp", type=float, default=1.0)
    # Backward-compatible generalization: default of 2 reproduces the old
    # hard-coded "Gen 1" / "Gen 2" behavior.
    parser.add_argument("--num-gens", type=int, default=2)
    args = parser.parse_args()

    print(f"Loading {args.model} with {args.adapter_path}")
    model, tokenizer = load(args.model, adapter_path=args.adapter_path)

    prompt = "<task>Decide on behalf of Eliza whether they should respond to the message, ignore it or stop the conversation.</task>\n\n<providers>\n[RECENT_MESSAGES]\nUser: I heard Eliza is helping\n</providers>\n\n<instructions>Decide if Eliza should respond to or interact with the conversation.\n\nIMPORTANT RULES FOR RESPONDING:\n- If YOUR name (Eliza) is directly mentioned → RESPOND\n- If someone uses a DIFFERENT name (not Eliza) → IGNORE (they're talking to someone else)\n- If you're actively participating in a conversation and the message continues that thread → RESPOND\n- If someone tells you to stop or be quiet → STOP\n- Otherwise → IGNORE\n\nThe key distinction is:\n- \"Talking TO Eliza\" (your name mentioned, replies to you, continuing your conversation) → RESPOND\n- \"Talking ABOUT Eliza\" or to someone else → IGNORE\n</instructions>\n\n<output>\nDo NOT include any thinking, reasoning, or <think> sections in your response.\nGo directly to the XML response format without any preamble or explanation.\n\nRespond using XML format like this:\n<response>\n <name>Eliza</name>\n <reasoning>Your reasoning here</reasoning>\n <action>RESPOND | IGNORE | STOP</action>\n</response>\n\nIMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>.\n</output>"

    from mlx_lm.sample_utils import make_sampler
    # One sampler is enough: the previous code rebuilt an identical sampler
    # for the second generation, which was redundant.
    sampler = make_sampler(temp=args.temp)

    for i in range(1, args.num_gens + 1):
        print("\n--- Gen {} (temp={}) ---".format(i, args.temp))
        print(generate(model, tokenizer, prompt=prompt, max_tokens=50, verbose=True, sampler=sampler))


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Test The Judge (PR #4)
|
|
4
|
+
|
|
5
|
+
Loads trajectories and evaluates them using the new reward functions.
|
|
6
|
+
Verifies:
|
|
7
|
+
1. Financial Rewards (PnL, Risk)
|
|
8
|
+
2. Format Rewards (XML validation)
|
|
9
|
+
3. Reasoning Alignment (Financial Literacy)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import sys
|
|
13
|
+
import logging
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
# Add python directory to path
|
|
17
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
18
|
+
|
|
19
|
+
from src.data_bridge.reader import JsonTrajectoryReader
|
|
20
|
+
from src.models import BabylonTrajectory
|
|
21
|
+
from src.training.rewards import (
|
|
22
|
+
TrajectoryRewardInputs,
|
|
23
|
+
composite_reward,
|
|
24
|
+
calculate_pnl_reward,
|
|
25
|
+
calculate_risk_reward
|
|
26
|
+
)
|
|
27
|
+
from src.training.quality_utils import (
|
|
28
|
+
calculate_detailed_tick_quality,
|
|
29
|
+
validate_xml_structure
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
|
33
|
+
logger = logging.getLogger("TheJudge")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def evaluate_trajectory(traj: BabylonTrajectory):
    """Score one trajectory and print a human-readable judgment.

    Reports the PnL-based financial score, per-step XML/reasoning quality,
    risk penalties, and the final composite score.
    """
    print(f"\n--- Judging Trajectory: {traj.trajectory_id} ---")

    # 1. Financials — start balance is assumed; the end balance is derived
    # from final PnL (the JSON may not carry balances at the top level).
    start_bal = 10000.0
    end_bal = start_bal + traj.final_pnl

    pnl_score = calculate_pnl_reward(start_bal, end_bal)
    print(f"💰 Financials: PnL ${traj.final_pnl:.2f} -> Score: {pnl_score:.2f}")

    # 2. Step-by-step analysis, restricted to steps that made LLM calls.
    format_total = 0.0
    reasoning_total = 0.0
    risky_count = 0
    scored_steps = 0

    for idx, step in enumerate(traj.steps):
        if not step.llm_calls:
            continue

        scored_steps += 1

        fmt, rsn = calculate_detailed_tick_quality(
            step.llm_calls,
            step.action,
            None,  # Feedback
            "default",
        )

        # Approximate exposure from open-position count (mock for this test).
        exposure = min(1.0, step.environment_state.open_positions * 0.1)
        kind = step.action.action_type if step.action else "wait"
        if calculate_risk_reward(exposure, kind) < 0:
            risky_count += 1

        format_total += fmt
        reasoning_total += rsn

        # Surface notable steps: broken XML or particularly good reasoning.
        if fmt < 0:
            print(f" ⚠️ Step {idx} Bad XML: {fmt}")
        if rsn > 0.6:
            print(f" ✨ Step {idx} Good Reasoning: {rsn:.2f}")

    # Averages (guarded against zero scored steps).
    avg_format = format_total / max(1, scored_steps)
    avg_reasoning = reasoning_total / max(1, scored_steps)

    print(
        f"📝 Quality: Avg XML {avg_format:.2f} | Avg Reasoning {avg_reasoning:.2f}")
    if risky_count > 0:
        print(f"🚨 Risk: {risky_count} dangerous actions detected")

    # 3. Final composite score over all reward components.
    inputs = TrajectoryRewardInputs(
        final_pnl=traj.final_pnl,
        starting_balance=start_bal,
        end_balance=end_bal,
        format_score=avg_format,
        reasoning_score=avg_reasoning,
        risky_actions_count=risky_count
    )

    final_score = composite_reward(inputs)

    verdict = "✅ PASSED" if final_score > 0 else "❌ FAILED"
    print(f"⚖️ FINAL SCORE: {final_score:.4f} ({verdict})")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def main():
    """Load up to five trajectories from disk and run the judge on each.

    Searches the training package output first, then falls back to the
    engine output; exits with status 1 if neither directory exists.
    """
    import json  # hoisted out of the per-trajectory loop below

    # Look for trajectory data in the training package output directory
    source_dir = Path(__file__).parent.parent.parent / "training-data-output" / "trajectories"
    if not source_dir.exists():
        # Fallback to engine output if training output doesn't exist
        source_dir = Path(__file__).parent.parent.parent.parent / "engine" / "training-data-output" / "trajectories"

    # Validate that at least one path exists
    if not source_dir.exists():
        logger.error("No trajectory data found. Checked paths:")
        logger.error(f" - {Path(__file__).parent.parent.parent / 'training-data-output' / 'trajectories'}")
        logger.error(f" - {source_dir}")
        logger.error("Run 'make tier4-generate' or 'bun run packages/engine/examples/generate-training-data.ts' first.")
        sys.exit(1)

    source_dir = str(source_dir)
    try:
        reader = JsonTrajectoryReader(source_dir)
        window_ids = reader.get_window_ids()

        count = 0
        for window_id in window_ids:
            raw_trajs = reader.get_trajectories_by_window(window_id)
            for raw in raw_trajs:
                # Unwrap the optional envelope and decode steps that were
                # persisted as a JSON string.
                if 'trajectory' in raw:
                    raw = raw['trajectory']
                if isinstance(raw.get('stepsJson'), str):
                    raw['steps'] = json.loads(raw['stepsJson'])

                try:
                    traj = BabylonTrajectory.model_validate(raw)
                    evaluate_trajectory(traj)
                    count += 1
                    if count >= 5:
                        return  # Just test 5 for now
                except Exception as e:
                    # Best-effort: a malformed trajectory shouldn't abort the run.
                    print(f"Skipping invalid: {e}")

    except Exception as e:
        logger.error(f"Error: {e}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
ElizaOS Training Pipeline - End-to-End Test
|
|
4
|
+
|
|
5
|
+
This script validates the complete training pipeline:
|
|
6
|
+
1. Database connectivity
|
|
7
|
+
2. Real trajectory data loading
|
|
8
|
+
3. Data conversion to training format
|
|
9
|
+
4. Backend availability (MLX/CUDA/CPU)
|
|
10
|
+
|
|
11
|
+
Run this BEFORE training to verify everything is set up correctly.
|
|
12
|
+
|
|
13
|
+
Usage:
|
|
14
|
+
python scripts/test_pipeline.py
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import sys
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
# Add src to path
|
|
24
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
25
|
+
|
|
26
|
+
from dotenv import load_dotenv
|
|
27
|
+
|
|
28
|
+
# Load environment
|
|
29
|
+
env_path = Path(__file__).parent.parent.parent.parent.parent / ".env"
|
|
30
|
+
if env_path.exists():
|
|
31
|
+
load_dotenv(env_path)
|
|
32
|
+
|
|
33
|
+
logging.basicConfig(
|
|
34
|
+
level=logging.INFO,
|
|
35
|
+
format='%(asctime)s [%(levelname)s] %(message)s'
|
|
36
|
+
)
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TestResult:
    """Outcome of a single pipeline check.

    Attributes:
        name: Human-readable check name.
        passed: Whether the check succeeded (defaults to False).
        message: One-line summary shown in the report.
        details: Optional structured data about the check.
    """

    def __init__(self, name: str):
        self.name = name
        self.passed = False
        self.message = ""
        self.details: dict = {}

    def __repr__(self) -> str:
        # Debug-friendly representation; the report printer does not rely on it.
        return f"{type(self).__name__}(name={self.name!r}, passed={self.passed})"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
async def test_database_connection() -> TestResult:
    """Verify the database is reachable and count stored trajectories."""
    result = TestResult("Database Connection")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        import asyncpg
        pool = await asyncpg.create_pool(database_url, min_size=1, max_size=2)

        # Sanity query against the trajectories table.
        async with pool.acquire() as conn:
            count = await conn.fetchval("SELECT COUNT(*) FROM trajectories")

        await pool.close()

        result.passed = True
        result.message = f"Connected. Found {count} trajectories"
        result.details["trajectory_count"] = count

    except Exception as e:
        result.message = f"Connection failed: {e}"

    return result
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
async def test_trajectory_data() -> TestResult:
    """Check that real trajectory data (with LLM calls) exists in the DB."""
    result = TestResult("Real Trajectory Data")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        from src.data_bridge import PostgresTrajectoryReader

        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(min_agents=1, lookback_hours=168)

            if not windows:
                result.message = "No trajectory windows found"
                return result

            # Only the first window is inspected; that is enough to prove
            # the pipeline has usable data.
            trajectories = await reader.get_trajectories_by_window(
                windows[0], min_actions=1
            )

            # Tally trajectories containing at least one LLM call.
            with_llm_calls = 0
            total_llm_calls = 0

            for traj in trajectories:
                calls_here = sum(
                    len(step.llm_calls) for step in traj.steps if step.llm_calls
                )
                if calls_here:
                    with_llm_calls += 1
                    total_llm_calls += calls_here

            result.passed = with_llm_calls > 0
            result.message = (
                f"Found {len(windows)} windows, "
                f"{len(trajectories)} trajectories in first window, "
                f"{with_llm_calls} have LLM calls ({total_llm_calls} total calls)"
            )
            result.details = {
                "windows": len(windows),
                "trajectories": len(trajectories),
                "with_llm_calls": with_llm_calls,
                "total_llm_calls": total_llm_calls,
            }

    except Exception as e:
        result.message = f"Failed: {e}"
        import traceback
        traceback.print_exc()

    return result
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def test_data_conversion() -> TestResult:
    """Convert DB trajectories into chat-format training samples."""
    result = TestResult("Data Conversion")

    database_url = os.getenv("DATABASE_URL", "")
    if not database_url:
        result.message = "DATABASE_URL not set"
        return result

    try:
        from src.data_bridge import PostgresTrajectoryReader

        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(min_agents=1, lookback_hours=168)

            if not windows:
                result.message = "No windows found"
                return result

            trajectories = await reader.get_trajectories_by_window(
                windows[0], min_actions=1
            )

            # Build chat-style samples: optional system + user, then assistant.
            samples = []
            for traj in trajectories:
                for step in traj.steps:
                    if not step.llm_calls:
                        continue

                    for call in step.llm_calls:
                        # Skip empty or trivially short responses.
                        if not call.response or len(call.response) < 20:
                            continue

                        messages = []
                        if call.system_prompt:
                            messages.append({"role": "system", "content": call.system_prompt})
                        if call.user_prompt:
                            messages.append({"role": "user", "content": call.user_prompt})
                        messages.append({"role": "assistant", "content": call.response})

                        # Require at least a prompt/response pair.
                        if len(messages) >= 2:
                            samples.append({"messages": messages})

            result.passed = len(samples) >= 10
            result.message = f"Created {len(samples)} training samples"
            result.details["samples"] = len(samples)

            if len(samples) > 0:
                # Attach a preview of the first sample's shape.
                head = samples[0]
                result.details["sample_preview"] = {
                    "roles": [m["role"] for m in head["messages"]],
                    "lengths": [len(m["content"]) for m in head["messages"]],
                }

    except Exception as e:
        result.message = f"Failed: {e}"
        import traceback
        traceback.print_exc()

    return result
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def test_mlx_backend() -> TestResult:
    """Report whether the MLX backend is importable."""
    result = TestResult("MLX Backend")

    try:
        import mlx.core as mx  # noqa: F401 - import is the availability probe
        import mlx_lm
    except ImportError as e:
        result.message = f"MLX not available: {e}"
    else:
        result.passed = True
        result.message = f"MLX available (mlx-lm version: {mlx_lm.__version__})"

    return result
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_cuda_backend() -> TestResult:
    """Report whether PyTorch sees a CUDA device, plus its name and VRAM."""
    result = TestResult("CUDA Backend")

    try:
        import torch
    except ImportError as e:
        result.message = f"PyTorch not installed: {e}"
        return result

    if not torch.cuda.is_available():
        result.message = "PyTorch installed but CUDA not available"
        return result

    device_name = torch.cuda.get_device_name(0)
    vram = torch.cuda.get_device_properties(0).total_memory / 1e9

    result.passed = True
    result.message = f"CUDA available: {device_name} ({vram:.1f} GB)"
    result.details = {
        "device": device_name,
        "vram_gb": vram,
    }
    return result
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def test_transformers() -> TestResult:
    """Report whether the transformers library is importable."""
    result = TestResult("Transformers Library")

    try:
        import transformers
    except ImportError as e:
        result.message = f"Not installed: {e}"
        return result

    result.passed = True
    result.message = f"transformers {transformers.__version__}"
    return result
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def test_environment_variables() -> TestResult:
    """Check required (and report optional) environment variables."""
    result = TestResult("Environment Variables")

    required = ["DATABASE_URL"]
    optional = ["OPENAI_API_KEY", "TINKER_API_KEY"]

    # True for each variable that is set and non-empty.
    checks = {name: bool(os.getenv(name)) for name in required + optional}

    missing_required = [k for k in required if not checks[k]]
    missing_optional = [k for k in optional if not checks[k]]

    result.passed = not missing_required

    if result.passed:
        result.message = f"Required vars set. Optional missing: {', '.join(missing_optional) or 'none'}"
    else:
        result.message = f"Missing required: {', '.join(missing_required)}"

    result.details = checks
    return result
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
async def main():
    """Run every pipeline check and print a summary report.

    Returns 0 when all required checks pass, 1 otherwise (used as the
    process exit code).
    """
    banner = "=" * 70
    print(banner)
    print(" ELIZAOS TRAINING PIPELINE - END-TO-END TEST")
    print(banner)
    print()

    # Build (name, result) pairs. Async checks are awaited here; sync
    # checks return their TestResult directly.
    tests = [
        ("Environment Variables", test_environment_variables()),
        ("Database Connection", await test_database_connection()),
        ("Real Trajectory Data", await test_trajectory_data()),
        ("Data Conversion", await test_data_conversion()),
        ("Transformers Library", test_transformers()),
        ("MLX Backend", test_mlx_backend()),
        ("CUDA Backend", test_cuda_backend()),
    ]

    passed = 0
    failed = 0

    for _, outcome in tests:
        marker = "✅" if outcome.passed else "❌"
        print(f"{marker} {outcome.name}")
        print(f" {outcome.message}")
        if outcome.details:
            for k, v in outcome.details.items():
                if k != "sample_preview":
                    print(f" - {k}: {v}")
        print()

        if outcome.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    print(banner)
    print(f" RESULTS: {passed} passed, {failed} failed")
    print(banner)

    # Only these checks gate training readiness; backend checks are advisory.
    required_tests = [
        "Environment Variables",
        "Database Connection",
        "Real Trajectory Data",
        "Data Conversion",
    ]

    required_passed = all(
        outcome.passed for _, outcome in tests
        if outcome.name in required_tests
    )

    if required_passed:
        print()
        print("✅ All required checks passed!")
        print()
        print("Ready to train. Run:")
        print(" python scripts/train_local.py")
        print()
        return 0
    else:
        print()
        print("❌ Some required checks failed. Fix issues before training.")
        print()
        return 1


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
|
|
356
|
+
|