@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,528 @@
|
|
|
1
|
+
#!/usr/bin/env python3
"""
Babylon Local Training Script - Unified Mac (MLX) + GTX (CUDA) Support

This script provides training using REAL data from the database OR local JSON files.
Only trajectories with actual LLM calls are used.

Supports:
- Apple Silicon (MLX) - LoRA fine-tuning
- NVIDIA GPU (PyTorch/CUDA) - Full or LoRA fine-tuning
- CPU fallback (slow but works)

Usage:
    # Mac with MLX from Postgres Database
    python scripts/train_local.py --backend mlx --model mlx-community/Qwen2.5-1.5B-Instruct-4bit

    # Mac with MLX from local JSON files
    python scripts/train_local.py --backend mlx --model mlx-community/Qwen2.5-1.5B-Instruct-4bit --source-dir ../engine/training-data-output/trajectories

    # GTX/CUDA machine from Postgres Database
    python scripts/train_local.py --backend cuda --model Qwen/Qwen2.5-1.5B-Instruct

    # GTX/CUDA machine from local JSON files
    python scripts/train_local.py --backend cuda --model Qwen/Qwen2.5-1.5B-Instruct --source-dir ../engine/training-data-output/trajectories

Small model recommendations for consumer hardware:
    Mac M1/M2 (8GB): mlx-community/Qwen2.5-0.5B-Instruct-4bit
    Mac M1/M2 (16GB): mlx-community/Qwen2.5-1.5B-Instruct-4bit
    GTX 3060 (12GB): Qwen/Qwen2.5-1.5B-Instruct
    GTX 3080 (10GB): Qwen/Qwen2.5-1.5B-Instruct
    GTX 4090 (24GB): Qwen/Qwen2.5-3B-Instruct
"""

import os
import sys
from pathlib import Path

# Add src to path so `from src...` imports below resolve when this file is
# run as a script from the `python/` directory.
sys.path.insert(0, str(Path(__file__).parent.parent))


import argparse
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Literal, List
from dotenv import load_dotenv

from src.models import BabylonTrajectory
from src.data_bridge.reader import JsonTrajectoryReader, PostgresTrajectoryReader, validate_llm_calls

# Load environment from the repository root .env (five levels up from this
# script), if present. NOTE(review): the relative depth assumes this file
# lives at package/python/scripts/ inside the monorepo — confirm if moved.
env_path = Path(__file__).parent.parent.parent.parent.parent / ".env"
if env_path.exists():
    load_dotenv(env_path)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# =============================================================================
|
|
67
|
+
# Backend Detection
|
|
68
|
+
# =============================================================================
|
|
69
|
+
|
|
70
|
+
def detect_backend() -> Literal["mlx", "cuda", "cpu"]:
    """Auto-detect the best available backend.

    Preference order: MLX (Apple Silicon), then CUDA (NVIDIA), then CPU.
    """
    # MLX installs only on Apple Silicon, so a successful import is enough.
    try:
        import mlx.core  # type: ignore  # noqa: F401
    except ImportError:
        pass
    else:
        logger.info("MLX backend available (Apple Silicon)")
        return "mlx"

    # Fall back to CUDA via PyTorch, if torch is installed and sees a GPU.
    try:
        import torch
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            logger.info(
                f"CUDA backend available: {device_name}")
            return "cuda"
    except ImportError:
        pass

    # Last resort: CPU training works but is very slow.
    logger.warning("No GPU backend available, falling back to CPU (slow)")
    return "cpu"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# =============================================================================
|
|
95
|
+
# Data Loading
|
|
96
|
+
# =============================================================================
|
|
97
|
+
|
|
98
|
+
async def load_postgres_training_data(
    database_url: str,
    min_actions: int,
    lookback_hours: int,
    max_trajectories: int,
) -> List[BabylonTrajectory]:
    """Load REAL training data from the database and parse into Pydantic models.

    Args:
        database_url: Postgres connection string for PostgresTrajectoryReader.
        min_actions: Minimum action count filter passed to the reader.
        lookback_hours: Only windows within this many hours are considered.
        max_trajectories: Stop loading once this many trajectories are parsed.

    Returns:
        Validated BabylonTrajectory models (at least 10).

    Raises:
        ValueError: If fewer than 10 valid trajectories were loaded.

    Side effects:
        Calls sys.exit(1) on any database/read failure instead of raising.
    """
    logger.info("Loading real training data from database...")

    trajectories: List[BabylonTrajectory] = []

    try:
        async with PostgresTrajectoryReader(database_url) as reader:
            windows = await reader.get_window_ids(lookback_hours=lookback_hours)
            if not windows:
                # NOTE(review): this ValueError is raised inside the try and
                # is therefore caught by the broad `except Exception` below,
                # which logs a misleading "ensure the database is running"
                # hint and exits — consider raising outside the try.
                raise ValueError(
                    "No trajectory windows found in database. Generate data first.")

            logger.info(f"Found {len(windows)} trajectory windows")

            for window_id in windows:
                # Cap total work once enough trajectories are collected.
                if len(trajectories) >= max_trajectories:
                    break

                window_trajectories = await reader.get_trajectories_by_window(
                    window_id, min_actions=min_actions, validate=True
                )
                for traj_row in window_trajectories:
                    try:
                        # steps are stored as a JSON string column on the row.
                        steps = json.loads(traj_row.steps_json)
                        # Convert TrajectoryRow object to a dict for Pydantic validation
                        traj_data = {
                            "id": traj_row.trajectory_id,
                            "trajectory_id": traj_row.trajectory_id,
                            "agent_id": traj_row.agent_id,
                            "window_id": traj_row.window_id,
                            "steps": steps,
                            "total_reward": traj_row.total_reward,
                            "episode_length": traj_row.episode_length,
                            "final_status": traj_row.final_status,
                            "final_pnl": traj_row.final_pnl,
                            "trades_executed": traj_row.trades_executed,
                            "archetype": traj_row.archetype,
                        }
                        traj_model = BabylonTrajectory.model_validate(
                            traj_data)
                        trajectories.append(traj_model)
                    except Exception as e:
                        # One bad row must not abort the whole load.
                        logger.warning(
                            f"Skipping DB trajectory {traj_row.trajectory_id} due to parsing error: {e}")

    except Exception as e:
        # Any connection/read failure is fatal for this script.
        logger.error(f"Failed to load from database: {e}")
        logger.error(
            "Please ensure the database is running and DATABASE_URL is correct.")
        sys.exit(1)

    # Below 10 trajectories, fine-tuning is not meaningful; caller handles this.
    if len(trajectories) < 10:
        raise ValueError(
            f"Insufficient training data: only {len(trajectories)} valid trajectories found.")

    logger.info(f"Loaded {len(trajectories)} real trajectories from DB")
    return trajectories
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def load_json_training_data(source_dir: str, max_trajectories: int) -> List[BabylonTrajectory]:
    """Loads training data from a directory of JSON files.

    Args:
        source_dir: Directory containing trajectory JSON files (as exported
            by the TypeScript simulation engine).
        max_trajectories: Stop loading once this many trajectories are parsed.

    Returns:
        Validated BabylonTrajectory models (at least one).

    Side effects:
        Calls sys.exit(1) if the directory is missing or zero valid
        trajectories were found.
    """
    logger.info(f"Loading training data from local directory: {source_dir}")
    try:
        reader = JsonTrajectoryReader(source_dir)
        all_trajectories: List[BabylonTrajectory] = []
        for window_id in reader.get_window_ids():
            if len(all_trajectories) >= max_trajectories:
                break
            for traj_data in reader.get_trajectories_by_window(window_id):
                try:
                    # Handle the nested `trajectory` key and `stepsJson` string format
                    # from the TypeScript simulation engine.
                    if 'trajectory' in traj_data:
                        traj_data = traj_data['trajectory']
                    if 'stepsJson' in traj_data and isinstance(traj_data['stepsJson'], str):
                        traj_data['steps'] = json.loads(traj_data['stepsJson'])

                    # Drop trajectories whose steps carry no usable LLM calls.
                    is_valid, issues = validate_llm_calls(
                        traj_data.get('steps', []))
                    if not is_valid:
                        logger.debug(
                            f"Skipping invalid JSON trajectory {traj_data.get('trajectoryId')}: {issues}")
                        continue

                    # Ensure 'id' field is present for Pydantic model validation
                    if 'id' not in traj_data:
                        traj_data['id'] = traj_data.get(
                            'trajectory_id', 'id_missing')

                    all_trajectories.append(
                        BabylonTrajectory.model_validate(traj_data))
                except Exception as e:
                    # One malformed file/record must not abort the whole load.
                    logger.warning(
                        f"Skipping invalid JSON trajectory {traj_data.get('trajectoryId')}: {e}")

        if len(all_trajectories) == 0:
            raise ValueError(
                "Insufficient training data: 0 valid trajectories were loaded. Check validation logs with DEBUG level.")
        elif len(all_trajectories) < 10:
            # Unlike the DB path, a small-but-nonzero count is allowed here.
            logger.warning(
                f"Low training data: only {len(all_trajectories)} valid trajectories found.")

        logger.info(
            f"Loaded {len(all_trajectories)} valid trajectories from JSON files.")
        return all_trajectories
    except (FileNotFoundError, ValueError) as e:
        # Includes the zero-trajectory ValueError raised above.
        logger.error(f"Error loading JSON data: {e}")
        sys.exit(1)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def trajectories_to_training_samples(trajectories: List[BabylonTrajectory]) -> list[dict]:
    """
    Convert a list of BabylonTrajectory objects to the training sample format.

    Each LLM call within a trajectory is extracted into a separate sample
    containing a list of messages (system, user, assistant).
    """
    samples: list[dict] = []
    for trajectory in trajectories:
        for step in trajectory.steps:
            for call in (step.llm_calls or []):
                # Basic quality filter: skip empty or trivially short responses.
                response = call.response
                if not response or len(response) < 20:
                    continue

                # Prepend whichever of system/user prompts are present.
                conversation = [
                    {"role": role, "content": text}
                    for role, text in (
                        ("system", call.system_prompt),
                        ("user", call.user_prompt),
                    )
                    if text
                ]
                conversation.append(
                    {"role": "assistant", "content": response})

                # Keep only samples with at least one non-assistant turn.
                if len(conversation) >= 2:
                    samples.append({"messages": conversation})

    logger.info(
        f"Converted {len(trajectories)} trajectories to {len(samples)} training samples")
    return samples
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# =============================================================================
|
|
250
|
+
# Training Backends
|
|
251
|
+
# =============================================================================
|
|
252
|
+
|
|
253
|
+
def train_mlx(
    samples: list[dict], model_name: str, output_dir: str,
    num_iters: int, batch_size: int, learning_rate: float
) -> str:
    """Train using MLX LoRA on Apple Silicon.

    Writes a shuffled 90/10 train/valid JSONL split into
    ``output_dir/training_data`` and shells out to ``python -m mlx_lm lora``.

    Args:
        samples: Chat-format samples ({"messages": [...]}). Not mutated.
        model_name: MLX model id (e.g. mlx-community/Qwen2.5-1.5B-Instruct-4bit).
        output_dir: Directory for the JSONL split and trained adapters.
        num_iters: Number of LoRA training iterations.
        batch_size: Training batch size passed to mlx_lm.
        learning_rate: Learning rate passed to mlx_lm.

    Returns:
        Path to the trained adapter directory.

    Raises:
        ImportError: If mlx_lm is not installed.
        subprocess.CalledProcessError: If the training subprocess fails.
    """
    import subprocess
    import random

    logger.info("=" * 60 + "\nMLX LORA TRAINING\n" + "=" * 60)
    data_dir = os.path.join(output_dir, "training_data")
    os.makedirs(data_dir, exist_ok=True)

    # Shuffle a copy so the caller's list is not reordered in place
    # (the original shuffled `samples` directly, a side effect on the caller).
    shuffled = list(samples)
    random.shuffle(shuffled)
    # 90/10 split; when more than one sample exists, guarantee the
    # validation file is non-empty so the eval steps have data to read.
    split_idx = int(len(shuffled) * 0.9)
    if len(shuffled) > 1:
        split_idx = min(split_idx, len(shuffled) - 1)
    train_samples, valid_samples = shuffled[:split_idx], shuffled[split_idx:]

    with open(os.path.join(data_dir, "train.jsonl"), 'w', encoding="utf-8") as f:
        for s in train_samples:
            f.write(json.dumps(s) + "\n")
    with open(os.path.join(data_dir, "valid.jsonl"), 'w', encoding="utf-8") as f:
        for s in valid_samples:
            f.write(json.dumps(s) + "\n")

    adapter_path = os.path.join(output_dir, "adapters")
    # Availability guard only: fail fast with ImportError here rather than
    # with a confusing subprocess failure below.
    import mlx_lm  # type: ignore  # noqa: F401
    cmd = [
        sys.executable, "-m", "mlx_lm", "lora", "--model", model_name, "--train",
        "--data", data_dir, "--adapter-path", adapter_path, "--batch-size", str(
            batch_size),
        "--iters", str(num_iters), "--learning-rate", str(learning_rate),
        "--steps-per-report", "10", "--steps-per-eval", "25", "--val-batches", "5",
        "--max-seq-length", "1024", "--num-layers", "8", "--mask-prompt",
    ]
    logger.info(f"Command: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)
    return adapter_path
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def train_cuda(
    samples: list[dict], model_name: str, output_dir: str,
    epochs: int, batch_size: int, learning_rate: float, use_lora: bool
) -> str:
    """Train using PyTorch/CUDA on NVIDIA GPU.

    Renders chat samples through the model's chat template, tokenizes them,
    and runs a Hugging Face Trainer (optionally with a LoRA adapter).
    Returns the output directory containing the saved model.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
    from datasets import Dataset

    logger.info("=" * 60 + "\nCUDA/PYTORCH TRAINING\n" + "=" * 60)
    total_vram = torch.cuda.get_device_properties(0).total_memory
    logger.info(
        f"VRAM: {total_vram / 1e9:.1f} GB")

    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Render each chat sample into one flat training string.
    rows = []
    for sample in samples:
        if not sample.get("messages"):
            continue
        rendered = tokenizer.apply_chat_template(
            sample['messages'], tokenize=False, add_generation_prompt=False)
        rows.append({"text": rendered})
    dataset = Dataset.from_list(rows)

    def tokenize_fn(examples):
        # Shorter sequences prevent CUDA out-of-memory errors on consumer
        # GPUs; memory usage scales quadratically with sequence length.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=1024,  # Reduced from 2048 to fit in ~12GB VRAM
            padding="max_length",
        )

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto")

    if use_lora:
        from peft import LoraConfig, get_peft_model, TaskType
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Optimized for consumer GPUs (~12GB VRAM): the smallest per-device
    # batch, with gradient accumulation compensating for it.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        remove_unused_columns=False
    )

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
    return output_dir
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def train_cpu(samples: list[dict], model_name: str, output_dir: str, epochs: int, batch_size: int, learning_rate: float) -> str:
    """Train using CPU (slow fallback)."""
    logger.warning("=" * 60 + "\nCPU TRAINING (VERY SLOW)\n" + "=" * 60)
    # The CUDA path works unchanged here: transformers falls back to CPU
    # when no GPU is found. A small model is forced to keep this feasible.
    return train_cuda(
        samples,
        "Qwen/Qwen2.5-0.5B-Instruct",
        output_dir,
        epochs,
        batch_size,
        learning_rate,
        use_lora=False,
    )
|
|
366
|
+
|
|
367
|
+
# =============================================================================
|
|
368
|
+
# Validation
|
|
369
|
+
# =============================================================================
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def validate_trained_model(model_path: str, backend: Literal["mlx", "cuda", "cpu"], base_model: str | None = None) -> bool:
    """Validate the trained model by generating a test response.

    Args:
        model_path: For MLX, the adapter directory; otherwise the directory
            of the saved (possibly merged) model.
        backend: Which runtime to load the model with.
        base_model: Base model id the MLX adapter was trained on.
            NOTE(review): required when backend == "mlx" — confirm callers
            always pass it in that case.

    Returns:
        True if a sufficiently long response was generated, False otherwise.
    """
    logger.info("=" * 60 + "\nVALIDATING TRAINED MODEL\n" + "=" * 60)
    test_prompt = """You are a trading agent in Babylon prediction markets.

Current State:
- Balance: $10,000
- P&L: $250
- Positions: 2 open

Market Update:
- BTC prediction market at 68% probability
- Recent news: Fed announces rate cut consideration

Analyze this market update and explain your trading decision."""

    try:
        if backend == "mlx":
            from mlx_lm import load, generate  # type: ignore
            # Load base model with the trained LoRA adapter applied.
            model, tokenizer = load(base_model, adapter_path=model_path)
            messages = [{"role": "user", "content": test_prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            response = generate(model, tokenizer, prompt=prompt,
                                max_tokens=200, verbose=False)
        else:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True)
            # fp16 + device_map only make sense when a CUDA device exists.
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if backend == "cuda" else torch.float32,
                device_map="auto" if backend == "cuda" else None,
                trust_remote_code=True,
            )
            messages = [{"role": "user", "content": test_prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(prompt, return_tensors="pt")
            if backend == "cuda":
                inputs = {k: v.cuda() for k, v in inputs.items()}
            outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.7,
                                     do_sample=True, pad_token_id=tokenizer.eos_token_id)
            # Decode only the newly generated tokens, not the echoed prompt.
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

        logger.info("Test Response:\n" + "-" * 40 +
                    f"\n{response[:500]}...\n" + "-" * 40)

        # Heuristic sanity check on output length, not output quality.
        if len(response) < 50:
            logger.error("Response too short - model may not be working")
            return False

        logger.info("✅ Model validation passed!")
        return True

    except Exception as e:
        # Any load/generation failure is reported as a validation failure.
        logger.error(f"Model validation failed: {e}", exc_info=True)
        return False
|
|
432
|
+
|
|
433
|
+
# =============================================================================
|
|
434
|
+
# Main
|
|
435
|
+
# =============================================================================
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
async def main_async(args):
    """Run the end-to-end local training pipeline.

    Resolves the backend and base model, loads trajectories (local JSON
    directory or Postgres), converts them into training samples, dispatches
    to the backend-specific training routine, and optionally validates the
    trained model.

    Args:
        args: parsed CLI namespace from ``main()``.

    Returns:
        int: process exit code — 0 on success, 1 on any failure.
    """
    backend = args.backend or detect_backend()
    if args.model:
        model_name = args.model
    elif backend == "mlx":
        model_name = "mlx-community/Qwen2.5-1.5B-Instruct-4bit"
    else:
        model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    logger.info(f"Using backend: {backend}, Model: {model_name}")
    os.makedirs(args.output, exist_ok=True)

    # Data source selection: a local JSON directory takes precedence over
    # the database; without either we cannot proceed.
    try:
        if args.source_dir:
            trajectories = load_json_training_data(
                args.source_dir, args.max_trajectories)
        else:
            database_url = args.database_url or os.getenv("DATABASE_URL")
            if not database_url:
                logger.error(
                    "DATABASE_URL not set and --source-dir not provided. Exiting.")
                return 1
            trajectories = await load_postgres_training_data(
                database_url, args.min_actions, args.lookback_hours,
                args.max_trajectories)
    except (ValueError, FileNotFoundError) as e:
        logger.error(f"Failed to load data: {e}")
        return 1

    samples = trajectories_to_training_samples(trajectories)
    if len(samples) < 10:
        logger.error(
            f"Not enough valid training samples found: {len(samples)}")
        return 1

    # Training dispatch. base_model is only known for the MLX path, where
    # validation later needs the base model to load the adapter against.
    model_path = ""
    base_model = None
    try:
        if backend == "mlx":
            model_path = train_mlx(
                samples, model_name, args.output, args.iters,
                args.batch_size, args.lr)
            base_model = model_name
        elif backend == "cuda":
            model_path = train_cuda(
                samples, model_name, args.output, args.epochs,
                args.batch_size, args.lr, args.lora)
        else:  # cpu
            model_path = train_cpu(
                samples, model_name, args.output, args.epochs,
                args.batch_size, args.lr)
    except Exception as e:
        logger.error(f"Training process failed: {e}", exc_info=True)
        return 1

    # NOTE(review): the validation result does not affect the exit code —
    # presumably a deliberate soft check; confirm before tightening.
    if args.validate and model_path:
        validate_trained_model(model_path, backend, base_model)

    logger.info("\n" + "="*60 + "\nTRAINING COMPLETE\n" +
                f" Model/adapter saved to: {model_path}\n" + "="*60)
    return 0
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def main():
    """Build the CLI parser and hand off to the async training entrypoint.

    Returns:
        int: exit code propagated from ``main_async``.
    """
    parser = argparse.ArgumentParser(
        description="Babylon Local Training",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Data source options (JSON directory wins over the database).
    parser.add_argument(
        "--source-dir",
        help="Directory with local JSON trajectory files for offline training.")
    parser.add_argument(
        "--database-url",
        help="Database URL (used if --source-dir is not provided).")

    # Backend / model selection.
    parser.add_argument(
        "--backend", choices=["mlx", "cuda", "cpu"],
        help="Training backend (auto-detected if not specified)")
    parser.add_argument(
        "--model",
        help="Model to train (default depends on backend)")

    # Trajectory filtering (applies to the database source).
    parser.add_argument(
        "--min-actions", type=int, default=3,
        help="Minimum actions per trajectory (DB source)")
    parser.add_argument(
        "--lookback-hours", type=int, default=168,
        help="Hours to look back for trajectories (DB source)")
    parser.add_argument(
        "--max-trajectories", type=int, default=500,
        help="Maximum trajectories to load")

    # Output location and training hyperparameters.
    parser.add_argument(
        "--output", default="./trained_models/local",
        help="Output directory")
    parser.add_argument(
        "--iters", type=int, default=100,
        help="Training iterations (MLX)")
    parser.add_argument(
        "--epochs", type=int, default=3,
        help="Training epochs (CUDA/CPU)")
    parser.add_argument(
        "--batch-size", type=int, default=2,
        help="Batch size (Note: CUDA uses a fixed batch size of 1 for memory optimization)")
    parser.add_argument(
        "--lr", type=float, default=1e-5,
        help="Learning rate")
    parser.add_argument(
        "--lora", action=argparse.BooleanOptionalAction, default=True,
        help="Use LoRA (CUDA only)")
    parser.add_argument(
        "--validate", action=argparse.BooleanOptionalAction, default=True,
        help="Validate trained model")

    args = parser.parse_args()
    return asyncio.run(main_async(args))
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit code.
    sys.exit(main())
|
package/python/setup.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Setup file for ElizaOS RL training with Atropos."""
|
|
2
|
+
|
|
3
|
+
from setuptools import setup, find_packages
|
|
4
|
+
|
|
5
|
+
setup(
|
|
6
|
+
name="elizaos-training",
|
|
7
|
+
version="1.0.0",
|
|
8
|
+
packages=find_packages(where="src"),
|
|
9
|
+
package_dir={"": "src"},
|
|
10
|
+
python_requires=">=3.10",
|
|
11
|
+
install_requires=[
|
|
12
|
+
"atroposlib>=0.3.0",
|
|
13
|
+
"asyncpg>=0.29.0",
|
|
14
|
+
"python-dotenv>=1.0.0",
|
|
15
|
+
"pydantic>=2.0.0",
|
|
16
|
+
"openai>=1.0.0",
|
|
17
|
+
"torch>=2.1.0",
|
|
18
|
+
"transformers>=4.36.0",
|
|
19
|
+
],
|
|
20
|
+
)
|