@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,914 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
ElizaOS RL Training - Full Pipeline Runner
|
|
4
|
+
|
|
5
|
+
This script orchestrates the complete RLAIF training pipeline:
|
|
6
|
+
1. Validates environment and prerequisites
|
|
7
|
+
2. Starts background services (Atropos API, vLLM)
|
|
8
|
+
3. Starts the RLAIF environment
|
|
9
|
+
4. Runs the GRPO trainer with optional W&B logging
|
|
10
|
+
|
|
11
|
+
Usage:
|
|
12
|
+
# Use a GPU profile (recommended - auto-configures for your hardware)
|
|
13
|
+
python scripts/run_training.py --profile 12gb --steps 100
|
|
14
|
+
python scripts/run_training.py --profile 24gb --steps 100
|
|
15
|
+
|
|
16
|
+
# List available profiles
|
|
17
|
+
python scripts/run_training.py --list-profiles
|
|
18
|
+
|
|
19
|
+
# Manual configuration (override profile or use without profile)
|
|
20
|
+
python scripts/run_training.py --model Qwen/Qwen2.5-0.5B-Instruct --vllm-gpu-memory 0.25 --steps 100
|
|
21
|
+
|
|
22
|
+
# Resume from checkpoint
|
|
23
|
+
python scripts/run_training.py --profile 12gb --resume ./trained_models/step_50
|
|
24
|
+
|
|
25
|
+
# Disable W&B
|
|
26
|
+
python scripts/run_training.py --profile 12gb --steps 100 --no-wandb
|
|
27
|
+
|
|
28
|
+
GPU Profiles (config/profiles/*.json):
|
|
29
|
+
12gb - RTX 3060/4070 (0.5B model, 25% vLLM memory)
|
|
30
|
+
16gb - RTX 4080/A4000 (1.5B model, 35% vLLM memory)
|
|
31
|
+
24gb - RTX 4090/A5000 (3B model, 40% vLLM memory)
|
|
32
|
+
48gb - A40/A6000 (7B model, 45% vLLM memory)
|
|
33
|
+
|
|
34
|
+
Or run components separately:
|
|
35
|
+
Terminal 1: run-api
|
|
36
|
+
Terminal 2: python -m src.training.rlaif_env serve --slurm false
|
|
37
|
+
Terminal 3: python -m src.training.atropos_trainer --steps 100
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
import argparse
|
|
41
|
+
import json
|
|
42
|
+
import logging
|
|
43
|
+
import os
|
|
44
|
+
import signal
|
|
45
|
+
import subprocess
|
|
46
|
+
import sys
|
|
47
|
+
import time
|
|
48
|
+
from pathlib import Path
|
|
49
|
+
from typing import Optional
|
|
50
|
+
|
|
51
|
+
# Put the package root (the parent of this scripts/ directory) at the front of
# sys.path so the `src.training...` imports used later in this file resolve
# against this checkout.
sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv

# Load variables from a .env file (if python-dotenv finds one) into the process
# environment before anything reads them -- e.g. DATABASE_URL and
# OPENAI_API_KEY, which validate_environment() checks via os.getenv().
load_dotenv()

# Root logging configuration for this CLI: INFO and above, each line carrying a
# timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s'
)
logger = logging.getLogger(__name__)

# GPU profiles are JSON files in <package root>/config/profiles; the file stem
# is the name accepted by --profile.
PROFILES_DIR = Path(__file__).parent.parent / "config" / "profiles"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_available_profiles(profiles_dir: Optional[Path] = None) -> list[str]:
    """Return the names of the GPU profiles that can be passed to ``--profile``.

    A profile is any regular ``*.json`` file in the profiles directory; its
    name is the file stem (``12gb.json`` -> ``"12gb"``).

    Args:
        profiles_dir: Directory to scan. Defaults to ``PROFILES_DIR``.

    Returns:
        Profile names in sorted order, or an empty list when the directory
        does not exist.
    """
    directory = PROFILES_DIR if profiles_dir is None else Path(profiles_dir)
    if not directory.is_dir():
        return []
    # glob() yields entries in filesystem order, which differs between
    # platforms; sort so listings and "Available: ..." errors are stable.
    # Skip directories that merely happen to end in ".json".
    return sorted(p.stem for p in directory.glob("*.json") if p.is_file())
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def load_profile(profile_name: str) -> dict:
    """Load a GPU profile from ``PROFILES_DIR`` by name.

    Args:
        profile_name: Profile file stem, e.g. ``"12gb"`` for ``12gb.json``.

    Returns:
        The parsed profile dictionary.

    Raises:
        ValueError: If no such profile file exists (the message lists the
            available profiles), if the file is not valid JSON, or if its
            top-level value is not a JSON object.
    """
    profile_path = PROFILES_DIR / f"{profile_name}.json"
    if not profile_path.is_file():
        available = get_available_profiles()
        raise ValueError(
            f"Profile '{profile_name}' not found. "
            f"Available: {', '.join(available) or 'none'}"
        )

    try:
        with open(profile_path, encoding="utf-8") as f:
            profile = json.load(f)
    except json.JSONDecodeError as e:
        # JSONDecodeError is already a ValueError; re-raise with the file path
        # so the user knows which profile is broken.
        raise ValueError(
            f"Profile '{profile_name}' is not valid JSON ({profile_path}): {e}"
        ) from e

    # Callers use dict access (.get / ['notes']); reject lists/scalars here
    # instead of failing later with an AttributeError.
    if not isinstance(profile, dict):
        raise ValueError(
            f"Profile '{profile_name}' must contain a JSON object, "
            f"got {type(profile).__name__} ({profile_path})"
        )

    logger.info(f"Loaded profile: {profile.get('name', profile_name)}")
    if profile.get('notes'):
        logger.info(f" Note: {profile['notes']}")

    return profile
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def list_profiles() -> None:
    """Print a summary of every GPU profile found in ``PROFILES_DIR``.

    Each profile is shown with its ``--profile`` flag, display name, model and
    vLLM memory fraction. A profile that fails to load or format is still
    listed, with the error in place of its details, so one broken file does
    not hide the others. This function only prints; it does not exit.
    """
    print("\nAvailable GPU Profiles:")
    print("=" * 60)

    profile_names = sorted(get_available_profiles())
    if not profile_names:
        print(f"\n No profiles found in {PROFILES_DIR}")

    for profile_name in profile_names:
        # Header printed exactly once, even if the profile turns out broken.
        print(f"\n --profile {profile_name}")
        # Best-effort listing: build every detail line before printing any,
        # so a bad field (e.g. non-numeric vllm_gpu_memory) yields one clean
        # error line instead of partial output.
        try:
            profile = load_profile(profile_name)
            details = [
                f" {profile.get('name', 'Unnamed')}",
                f" Model: {profile.get('model', 'default')}",
                f" vLLM Memory: {profile.get('vllm_gpu_memory', 0.45) * 100:.0f}%",
            ]
            if profile.get('notes'):
                details.append(f" Note: {profile['notes']}")
        except Exception as e:
            print(f" Error loading: {e}")
            continue

        for line in details:
            print(line)

    print()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def validate_environment() -> list[str]:
    """Check the prerequisites for a training run.

    Verifies the required environment variables, the Atropos ``run-api``
    executable on PATH, and a PyTorch install with CUDA. When a GPU is found,
    its name and memory are logged.

    Returns:
        Human-readable problem descriptions in check order; an empty list
        means every check passed.
    """
    import shutil

    problems: list[str] = []

    # (variable, message to report when it is unset or empty)
    required_env_vars = (
        (
            "DATABASE_URL",
            "DATABASE_URL not set. Required for loading training trajectories.\n"
            " Set in .env or export DATABASE_URL=postgresql://...",
        ),
        (
            "OPENAI_API_KEY",
            "OPENAI_API_KEY not set. Required for RLAIF judge scoring.\n"
            " Set in .env or export OPENAI_API_KEY=sk-...",
        ),
    )
    problems.extend(
        message for var_name, message in required_env_vars if not os.getenv(var_name)
    )

    # Atropos ships the `run-api` console script.
    if shutil.which("run-api") is None:
        problems.append("Atropos API not found. Install with: pip install atroposlib")

    try:
        import torch

        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"GPU: {device_name} ({total_gb:.1f} GB)")
        else:
            problems.append(
                "CUDA not available. GPU is recommended for training.\n"
                " For CPU-only (slow), use --skip-vllm and provide external inference."
            )
    except ImportError:
        problems.append("PyTorch not installed. Install with: pip install torch")

    return problems
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TrainingOrchestrator:
|
|
165
|
+
"""
|
|
166
|
+
Orchestrates the complete training pipeline.
|
|
167
|
+
|
|
168
|
+
Manages:
|
|
169
|
+
- Service lifecycle (Atropos API, vLLM)
|
|
170
|
+
- Environment server
|
|
171
|
+
- GRPO trainer
|
|
172
|
+
"""
|
|
173
|
+
|
|
174
|
+
def __init__(
    self,
    model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    base_model: Optional[str] = None,
    dataset_input: Optional[str] = None,
    scoring_mode: str = "deterministic",
    training_steps: int = 100,
    batch_size: int = 4,
    learning_rate: float = 1e-5,
    min_learning_rate: float = 1e-7,
    lr_scheduler: str = "cosine",
    warmup_steps: int = 10,
    api_port: int = 8000,
    vllm_host: str = "127.0.0.1",
    vllm_port: int = 9001,
    vllm_gpu_memory: float = 0.45,
    save_path: str = "./trained_models",
    save_every: int = 5,
    keep_checkpoints: int = 3,
    resume_from: Optional[str] = None,
    use_wandb: bool = True,
    wandb_project: str = "eliza-training",
    wandb_entity: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    skip_services: bool = False,
    log_dir: str = "./logs",
    # Phase 3: Online training parameters
    mode: str = "offline",
    bridge_url: str = "http://localhost:3001",
    hybrid_online_ratio: float = 0.2,
    # Phase 4: Cloud/Multi-GPU parameters
    tensor_parallel_size: int = 1,
    use_flash_attention: bool = False,
    vllm_gpu: Optional[str] = None,  # Explicit GPU assignment for vLLM
    training_gpu: Optional[str] = None,  # Explicit GPU assignment for training
):
    """Store the run configuration and prepare process bookkeeping.

    Every argument is stored on ``self`` under the same name (``log_dir`` as a
    ``Path``). ``log_dir`` is created (with parents). When constructed on the
    main thread, ``_signal_handler`` is installed for SIGINT and SIGTERM so an
    interrupted run tears down its subprocesses via ``cleanup()``.

    Args:
        model_name: Model served by vLLM and used by the environment/trainer.
        base_model, dataset_input, resume_from: Optional trainer inputs,
            forwarded only when set.
        scoring_mode, training_steps, batch_size, learning_rate,
            min_learning_rate, lr_scheduler, warmup_steps: Trainer settings.
        api_port: Port of the Atropos API.
        vllm_host, vllm_port, vllm_gpu_memory: vLLM endpoint and the GPU
            memory fraction handed to the service manager.
        save_path, save_every, keep_checkpoints: Checkpointing settings.
        use_wandb, wandb_project, wandb_entity, wandb_run_name: W&B logging.
        skip_services: Do not start Atropos/vLLM in ``start_services()``.
        log_dir: Directory for subprocess logs.
        mode, bridge_url, hybrid_online_ratio: Online/hybrid training via the
            simulation bridge.
        tensor_parallel_size, use_flash_attention, vllm_gpu, training_gpu:
            Multi-GPU settings.
    """
    import threading

    self.model_name = model_name
    self.base_model = base_model
    self.dataset_input = dataset_input
    self.scoring_mode = scoring_mode
    self.training_steps = training_steps
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.min_learning_rate = min_learning_rate
    self.lr_scheduler = lr_scheduler
    self.warmup_steps = warmup_steps
    self.api_port = api_port
    self.vllm_host = vllm_host
    self.vllm_port = vllm_port
    self.vllm_gpu_memory = vllm_gpu_memory
    self.save_path = save_path
    self.save_every = save_every
    self.keep_checkpoints = keep_checkpoints
    self.resume_from = resume_from
    self.use_wandb = use_wandb
    self.wandb_project = wandb_project
    self.wandb_entity = wandb_entity
    self.wandb_run_name = wandb_run_name
    self.skip_services = skip_services
    self.log_dir = Path(log_dir)
    # Phase 3: Online training
    self.mode = mode
    self.bridge_url = bridge_url
    self.hybrid_online_ratio = hybrid_online_ratio
    # Phase 4: Cloud/Multi-GPU
    self.tensor_parallel_size = tensor_parallel_size
    self.use_flash_attention = use_flash_attention
    self.vllm_gpu = vllm_gpu
    self.training_gpu = training_gpu

    self.env_process: Optional[subprocess.Popen] = None
    self.trainer_process: Optional[subprocess.Popen] = None
    self._service_manager = None
    self._shutdown_requested = False
    self._log_handles: list = []  # Open log files; closed by cleanup()

    self.log_dir.mkdir(parents=True, exist_ok=True)

    # signal.signal() may only be called from the main thread (it raises
    # ValueError elsewhere), so constructing an orchestrator from a worker
    # thread used to crash. Install handlers only where it is allowed.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
    else:
        logging.getLogger(__name__).warning(
            "TrainingOrchestrator created outside the main thread; "
            "SIGINT/SIGTERM handlers not installed - call cleanup() explicitly"
        )
|
|
254
|
+
|
|
255
|
+
def _signal_handler(self, signum, frame):
    """React to SIGINT/SIGTERM.

    The first signal marks shutdown as requested, runs ``cleanup()`` and exits
    with status 0. Any signal received after that flag is set exits at once
    with status 1, without attempting cleanup again.
    """
    if not self._shutdown_requested:
        logger.info("Received shutdown signal, cleaning up...")
        self._shutdown_requested = True
        self.cleanup()
        sys.exit(0)

    logger.warning("Forced shutdown, exiting immediately")
    sys.exit(1)
|
|
265
|
+
|
|
266
|
+
def cleanup(self):
    """Stop the trainer, the environment and background services; close logs.

    Every teardown step is attempted even if an earlier one fails, so a
    failure stopping the trainer cannot leave the environment or the
    services (e.g. vLLM) running, and log file handles are always closed.
    After all steps have run, the first ``Exception`` encountered is
    re-raised. ``BaseException``s such as ``SystemExit`` (e.g. a forced exit
    from a second signal) propagate immediately.
    """
    errors: list = []

    def _attempt(action, *args):
        # Run one teardown step, remembering (not raising) ordinary errors.
        try:
            action(*args)
        except Exception as exc:
            errors.append(exc)

    _attempt(self._stop_process, self.trainer_process, "trainer")
    _attempt(self._stop_process, self.env_process, "environment")

    if self._service_manager:
        _attempt(self._service_manager.stop_all)

    for handle in self._log_handles:
        _attempt(handle.close)
    self._log_handles.clear()

    if errors:
        raise errors[0]
|
|
277
|
+
|
|
278
|
+
def _stop_process(self, proc: Optional[subprocess.Popen], name: str, timeout: int = 10) -> None:
    """Terminate ``proc``, escalating to kill() if it outlives ``timeout`` seconds.

    Args:
        proc: Process to stop; a falsy value (``None``) is ignored.
        name: Label used in the log message.
        timeout: Seconds to wait, polling every 0.5 s, after terminate()
            before falling back to kill() and a blocking wait().
    """
    if not proc:
        return

    logger.info(f"Stopping {name}...")
    proc.terminate()

    give_up_at = time.time() + timeout
    finished = proc.poll() is not None
    while not finished and time.time() < give_up_at:
        time.sleep(0.5)
        finished = proc.poll() is not None

    # Still alive after the grace period: force it down and reap it.
    if proc.poll() is None:
        proc.kill()
        proc.wait()
|
|
293
|
+
|
|
294
|
+
def start_services(self) -> bool:
    """Start the Atropos API and vLLM through ServiceManager, then wait for them.

    Returns True immediately when ``skip_services`` is set. Otherwise the
    manager is stored on ``self._service_manager`` (so ``cleanup()`` can stop
    it) and True is returned only if both ``start_all()`` and
    ``wait_for_ready()`` succeed.
    """
    if self.skip_services:
        logger.info("Skipping service startup (--skip-services)")
        return True

    from src.training.service_manager import ServiceManager, ServiceConfig

    self._service_manager = ServiceManager(
        ServiceConfig(
            atropos_port=self.api_port,
            vllm_port=self.vllm_port,
            model_name=self.model_name,
            vllm_gpu_memory_utilization=self.vllm_gpu_memory,
            log_dir=str(self.log_dir / "services"),
            # Phase 4: Multi-GPU support
            tensor_parallel_size=self.tensor_parallel_size,
            use_flash_attention=self.use_flash_attention,
            vllm_gpu=self.vllm_gpu,
            training_gpu=self.training_gpu,
        )
    )
    manager = self._service_manager

    if not manager.start_all():
        return False
    if manager.wait_for_ready():
        return True

    logger.error("Services failed to become ready")
    return False
|
|
325
|
+
|
|
326
|
+
def check_bridge_health(self, attempts: int = 3, retry_delay: float = 2.0, timeout: float = 5.0) -> bool:
    """Check that the simulation bridge answers ``GET {bridge_url}/health``.

    Args:
        attempts: Number of tries (at least one is always made).
        retry_delay: Seconds to sleep between failed tries.
        timeout: Per-request timeout in seconds passed to ``urlopen``.

    Returns:
        True as soon as one request returns HTTP 200. False if every attempt
        fails, or immediately on an unexpected (non-network) error.
    """
    import urllib.request

    logger.info(f"Checking simulation bridge at {self.bridge_url}...")

    health_url = f"{self.bridge_url}/health"
    attempts = max(1, attempts)

    for attempt in range(1, attempts + 1):
        try:
            req = urllib.request.Request(health_url, method='GET')
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                status = resp.status
        except OSError as e:
            # URLError/HTTPError, refused connections and socket timeouts are
            # all OSError subclasses: the bridge may still be starting up, so
            # these are retried (timeouts previously aborted immediately).
            failure = str(e)
        except Exception as e:
            logger.error(f"Bridge health check failed: {e}")
            return False
        else:
            if status == 200:
                logger.info("Simulation bridge is healthy ✓")
                return True
            # Previously a non-200 reply was retried instantly and silently.
            failure = f"unexpected HTTP status {status}"

        if attempt < attempts:
            logger.warning(f"Bridge not ready (attempt {attempt}/{attempts}): {failure}")
            time.sleep(retry_delay)

    logger.error(f"Simulation bridge not available at {self.bridge_url}")
    logger.error("Start it with: make bridge-server")
    return False
|
|
354
|
+
|
|
355
|
+
def start_environment(self) -> bool:
    """Launch the offline RLAIF environment server as a background subprocess.

    Runs ``python -m src.training.rlaif_env serve`` from the package root,
    pointed at the Atropos API on ``api_port`` and the vLLM endpoint on
    ``vllm_host:vllm_port``, inheriting this process's environment. Combined
    stdout/stderr goes to ``<log_dir>/environment.log``; the open handle is
    tracked in ``_log_handles`` so ``cleanup()`` can close it.

    Returns:
        True if the process is still running after a 5 second grace period,
        False if it has already exited.
    """
    logger.info("Starting RLAIF environment (offline mode)...")

    rollout_server = f"http://localhost:{self.api_port}"
    inference_base = f"http://{self.vllm_host}:{self.vllm_port}/v1"

    command = [sys.executable, "-m", "src.training.rlaif_env", "serve", "--slurm", "false"]
    command += [
        "--env.tokenizer_name", self.model_name,
        "--env.scoring_mode", self.scoring_mode,
        "--env.rollout_server_url", rollout_server,
        "--openai.model_name", self.model_name,
        "--openai.base_url", inference_base,
    ]
    if not self.use_wandb:
        command += ["--env.use_wandb", "false"]

    log_path = self.log_dir / "environment.log"
    sink = open(log_path, "w")
    self._log_handles.append(sink)

    # Pass the parent environment through (DATABASE_URL etc.).
    child_env = os.environ.copy()
    self.env_process = subprocess.Popen(
        command,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    time.sleep(5)  # Grace period: an environment that fails fast will have exited

    if self.env_process.poll() is None:
        logger.info(f"Environment started (PID: {self.env_process.pid}), logs: {log_path}")
        return True

    logger.error(f"Environment failed to start (exit code: {self.env_process.returncode})")
    logger.error(f"Check logs at: {log_path}")
    return False
|
|
393
|
+
|
|
394
|
+
def start_online_environment(self) -> bool:
    """Launch the online environment server (simulation-bridge backed).

    Runs ``python -m src.training.online_env serve`` from the package root
    with the bridge enabled both via CLI flags and via the
    ``USE_SIMULATION_BRIDGE`` / ``SIMULATION_BRIDGE_URL`` environment
    variables. Combined stdout/stderr goes to
    ``<log_dir>/online_environment.log``; the handle is tracked in
    ``_log_handles``.

    Returns:
        True if the process is still running after a 5 second grace period,
        False if it has already exited.
    """
    logger.info("Starting online environment (online mode)...")

    rollout_server = f"http://localhost:{self.api_port}"
    inference_base = f"http://{self.vllm_host}:{self.vllm_port}/v1"

    command = [sys.executable, "-m", "src.training.online_env", "serve", "--slurm", "false"]
    command += [
        "--env.tokenizer_name", self.model_name,
        "--env.rollout_server_url", rollout_server,
        "--openai.model_name", self.model_name,
        "--openai.base_url", inference_base,
        # Online-specific settings
        "--env.use_simulation_bridge", "true",
        "--env.simulation_bridge_url", self.bridge_url,
    ]
    if not self.use_wandb:
        command += ["--env.use_wandb", "false"]

    log_path = self.log_dir / "online_environment.log"
    sink = open(log_path, "w")
    self._log_handles.append(sink)

    # Parent environment plus the bridge settings.
    child_env = {
        **os.environ.copy(),
        "USE_SIMULATION_BRIDGE": "1",
        "SIMULATION_BRIDGE_URL": self.bridge_url,
    }
    self.env_process = subprocess.Popen(
        command,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    time.sleep(5)  # Grace period: an environment that fails fast will have exited

    if self.env_process.poll() is None:
        logger.info(f"Online environment started (PID: {self.env_process.pid}), logs: {log_path}")
        return True

    logger.error(f"Online environment failed to start (exit code: {self.env_process.returncode})")
    logger.error(f"Check logs at: {log_path}")
    return False
|
|
439
|
+
|
|
440
|
+
def start_hybrid_environment(self) -> bool:
    """Launch the hybrid environment server (offline data mixed with online rollouts).

    Runs ``python -m src.training.hybrid_env serve`` from the package root.
    The bridge URL and ``hybrid_online_ratio`` are passed both as CLI flags
    and as the ``USE_SIMULATION_BRIDGE`` / ``SIMULATION_BRIDGE_URL`` /
    ``HYBRID_ONLINE_RATIO`` environment variables. Combined stdout/stderr goes
    to ``<log_dir>/hybrid_environment.log``; the handle is tracked in
    ``_log_handles``.

    Returns:
        True if the process is still running after a 5 second grace period,
        False if it has already exited.
    """
    logger.info(f"Starting hybrid environment (online ratio: {self.hybrid_online_ratio:.0%})...")

    rollout_server = f"http://localhost:{self.api_port}"
    inference_base = f"http://{self.vllm_host}:{self.vllm_port}/v1"
    ratio_text = str(self.hybrid_online_ratio)

    command = [sys.executable, "-m", "src.training.hybrid_env", "serve", "--slurm", "false"]
    command += [
        "--env.tokenizer_name", self.model_name,
        "--env.rollout_server_url", rollout_server,
        "--openai.model_name", self.model_name,
        "--openai.base_url", inference_base,
        # Hybrid-specific settings
        "--env.use_simulation_bridge", "true",
        "--env.simulation_bridge_url", self.bridge_url,
        "--env.online_ratio", ratio_text,
    ]
    if not self.use_wandb:
        command += ["--env.use_wandb", "false"]

    log_path = self.log_dir / "hybrid_environment.log"
    sink = open(log_path, "w")
    self._log_handles.append(sink)

    # Parent environment plus the bridge and mixing-ratio settings.
    child_env = {
        **os.environ.copy(),
        "USE_SIMULATION_BRIDGE": "1",
        "SIMULATION_BRIDGE_URL": self.bridge_url,
        "HYBRID_ONLINE_RATIO": ratio_text,
    }
    self.env_process = subprocess.Popen(
        command,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    time.sleep(5)  # Grace period: an environment that fails fast will have exited

    if self.env_process.poll() is None:
        logger.info(f"Hybrid environment started (PID: {self.env_process.pid}), logs: {log_path}")
        return True

    logger.error(f"Hybrid environment failed to start (exit code: {self.env_process.returncode})")
    logger.error(f"Check logs at: {log_path}")
    return False
|
|
487
|
+
|
|
488
|
+
def start_trainer(self) -> bool:
    """Launch the GRPO trainer (``src.training.atropos_trainer``) as a subprocess.

    The trainer's stdout and stderr are merged into a single pipe so that
    ``_stream_trainer_output`` can relay them to the console and a log file.
    ``--skip-vllm`` is always passed because vLLM is started separately.

    Returns:
        True if the trainer process was spawned, False if it could not be
        launched (e.g. an OS-level spawn error).
    """
    logger.info("Starting GRPO trainer...")

    trainer_cmd = [
        sys.executable, "-m", "src.training.atropos_trainer",
        "--model", self.model_name,
        "--scoring-mode", self.scoring_mode,
        "--steps", str(self.training_steps),
        "--batch-size", str(self.batch_size),
        "--lr", str(self.learning_rate),
        "--min-lr", str(self.min_learning_rate),
        "--lr-scheduler", self.lr_scheduler,
        "--warmup-steps", str(self.warmup_steps),
        "--api-url", f"http://localhost:{self.api_port}",
        "--vllm-host", self.vllm_host,
        "--vllm-port", str(self.vllm_port),
        "--vllm-gpu-utilization", str(self.vllm_gpu_memory),
        "--save-path", self.save_path,
        "--save-every", str(self.save_every),
        "--keep-checkpoints", str(self.keep_checkpoints),
        "--log-file", str(self.log_dir / "training_metrics.jsonl"),
        "--wandb-project", self.wandb_project,
        "--skip-vllm",  # vLLM already started by ServiceManager
    ]

    # Optional flags are only forwarded when set.
    if self.base_model:
        trainer_cmd.extend(["--base-model", self.base_model])
    if self.dataset_input:
        trainer_cmd.extend(["--dataset-input", self.dataset_input])
    if self.resume_from:
        trainer_cmd.extend(["--resume", self.resume_from])
    if not self.use_wandb:
        trainer_cmd.append("--no-wandb")
    if self.wandb_entity:
        trainer_cmd.extend(["--wandb-entity", self.wandb_entity])
    if self.wandb_run_name:
        trainer_cmd.extend(["--wandb-run-name", self.wandb_run_name])

    # Set up environment with GPU assignment for training
    env = os.environ.copy()
    gpu_ids = self.training_gpu
    # Profiles are JSON, so the GPU assignment may arrive as an int (e.g. 0)
    # or a list. CUDA_VISIBLE_DEVICES must be a comma-separated string:
    # non-str env values make Popen raise TypeError, and an int 0 would be
    # falsy and silently skipped.
    if isinstance(gpu_ids, (list, tuple)):
        gpu_ids = ",".join(str(gpu) for gpu in gpu_ids)
    elif isinstance(gpu_ids, int) and not isinstance(gpu_ids, bool):
        gpu_ids = str(gpu_ids)
    if gpu_ids:
        env["CUDA_VISIBLE_DEVICES"] = gpu_ids
        logger.info(f"Training GPU (explicit): {gpu_ids}")

    # Pipe stdout for streaming to console
    try:
        self.trainer_process = subprocess.Popen(
            trainer_cmd,
            cwd=str(Path(__file__).parent.parent),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
        )
    except OSError as exc:
        # Report through the bool contract so run() logs and returns 1.
        logger.error(f"Failed to launch trainer: {exc}")
        return False

    logger.info(f"Trainer started (PID: {self.trainer_process.pid})")
    return True
|
|
544
|
+
|
|
545
|
+
def run(self) -> int:
    """Drive the whole pipeline end to end.

    Stages, in order:
    1. start services;
    2. verify the simulation bridge (online/hybrid modes only);
    3. start the mode-specific environment;
    4. start the trainer and stream its output.

    Returns:
        The trainer's exit code, or 1 if any setup stage fails.
        ``cleanup()`` runs in every case, including when an exception escapes.
    """
    self._log_config()
    t0 = time.time()

    try:
        if not self.start_services():
            logger.error("Failed to start services")
            return 1

        needs_bridge = self.mode in ("online", "hybrid")
        if needs_bridge and not self.check_bridge_health():
            logger.error("Simulation bridge not available")
            logger.error("Start it with: make bridge-server")
            return 1

        # Unknown modes fall back to the offline environment starter.
        starters = {
            "offline": self.start_environment,
            "online": self.start_online_environment,
            "hybrid": self.start_hybrid_environment,
        }
        launch_env = starters.get(self.mode, self.start_environment)
        if not launch_env():
            logger.error(f"Failed to start {self.mode} environment")
            return 1

        if not self.start_trainer():
            logger.error("Failed to start trainer")
            return 1

        exit_status = self._stream_trainer_output()
        duration = time.time() - t0

        if exit_status != 0:
            logger.error(f"Training failed with return code: {exit_status}")
            logger.error(f"Check logs at: {self.log_dir}")
            return exit_status

        banner = "=" * 70
        summary = (
            "\n" + banner,
            "TRAINING COMPLETED SUCCESSFULLY",
            f"Mode: {self.mode.upper()}",
            f"Total time: {duration:.1f}s ({duration/60:.1f} minutes)",
            f"Model saved to: {self.save_path}",
            banner,
        )
        for message in summary:
            logger.info(message)
        return exit_status
    finally:
        self.cleanup()
|
|
596
|
+
|
|
597
|
+
def _log_config(self):
    """Log a banner that summarises the effective training configuration.

    The bridge URL is printed only for online/hybrid modes, the online ratio
    only for hybrid mode, and the resume path only when one was given.
    """
    emit = logger.info
    rule = "=" * 70

    emit(rule)
    emit("ELIZAOS RL TRAINING PIPELINE")
    emit(rule)
    emit(f"Mode: {self.mode.upper()}")

    uses_bridge = self.mode in ("online", "hybrid")
    if uses_bridge:
        emit(f"Bridge URL: {self.bridge_url}")
        if self.mode == "hybrid":
            emit(f"Online ratio: {self.hybrid_online_ratio:.0%}")

    wandb_state = "enabled" if self.use_wandb else "disabled"
    emit(f"Model: {self.model_name}")
    emit(f"Steps: {self.training_steps}")
    emit(f"Batch size: {self.batch_size}")
    emit(f"Learning rate: {self.learning_rate} (scheduler: {self.lr_scheduler})")
    emit(f"Save path: {self.save_path}")
    emit(f"W&B: {wandb_state}")

    if self.resume_from:
        emit(f"Resuming from: {self.resume_from}")
    emit(rule)
|
|
616
|
+
|
|
617
|
+
def _stream_trainer_output(self) -> int:
    """Relay the trainer's merged stdout/stderr to the console and ``trainer.log``.

    Reads the trainer pipe line by line until EOF, then waits for the process
    to exit.

    Returns:
        The trainer process's exit code.

    Raises:
        RuntimeError: if the trainer has not been started or its stdout is
            not piped.
    """
    logger.info("\n" + "-" * 70)
    logger.info("TRAINING IN PROGRESS")
    logger.info("-" * 70 + "\n")

    log_file = self.log_dir / "trainer.log"

    proc = self.trainer_process
    # Explicit checks instead of ``assert`` so they still fire under ``python -O``.
    if proc is None:
        raise RuntimeError("Trainer process has not been started")
    if proc.stdout is None:
        raise RuntimeError("Trainer process stdout is not piped")

    # Write UTF-8 explicitly. Decoding with errors="replace" can yield U+FFFD
    # (and the trainer may print other non-ASCII text), which a non-UTF-8
    # locale default encoding cannot encode. The resulting UnicodeEncodeError
    # would abort streaming mid-run.
    with open(log_file, "w", encoding="utf-8") as log_handle:
        for raw_line in proc.stdout:
            decoded = raw_line.decode("utf-8", errors="replace")
            # Flush so progress appears immediately even when stdout is
            # redirected to a file (block-buffered) rather than a terminal.
            print(decoded, end="", flush=True)
            log_handle.write(decoded)
            log_handle.flush()

    return proc.wait()
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def _build_training_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training pipeline.

    ``--model``, ``--batch-size`` and ``--vllm-gpu-memory`` default to None so
    an explicitly passed value can be told apart from "not given". Their
    effective defaults are resolved later from the GPU profile or fallbacks.
    """
    parser = argparse.ArgumentParser(
        description="ElizaOS RL Training Pipeline",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Profile settings (applied first, can be overridden by explicit args)
    parser.add_argument(
        "--profile",
        choices=get_available_profiles() or None,
        help="GPU profile to use (e.g., 12gb, 24gb). See --list-profiles"
    )
    parser.add_argument(
        "--list-profiles",
        action="store_true",
        help="List available GPU profiles and exit"
    )

    # Model settings
    parser.add_argument(
        "--model",
        default=None,  # Will use profile default or fallback
        help="Model to train (default: from profile or Qwen2.5-3B-Instruct)"
    )
    parser.add_argument(
        "--base-model",
        default=None,
        help="Optional base model alias passed to trainer"
    )
    parser.add_argument(
        "--dataset-input",
        default=None,
        help="Optional dataset input path passed to trainer"
    )
    parser.add_argument(
        "--scoring-mode",
        choices=["deterministic", "llm_judge"],
        default="deterministic",
        help="Scoring mode used by environment/trainer pipeline"
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of training steps"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,  # Will use profile default or fallback (4)
        help="Batch size (default: from profile or 4)"
    )

    # Learning rate settings
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-5,
        help="Initial learning rate"
    )
    parser.add_argument(
        "--min-lr",
        type=float,
        default=1e-7,
        help="Minimum learning rate"
    )
    parser.add_argument(
        "--lr-scheduler",
        choices=["constant", "linear", "cosine"],
        default="cosine",
        help="Learning rate scheduler"
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="LR warmup steps"
    )

    # Service settings
    parser.add_argument(
        "--api-port",
        type=int,
        default=8000,
        help="Atropos API server port"
    )
    parser.add_argument(
        "--vllm-port",
        type=int,
        default=9001,
        help="vLLM inference server port"
    )
    parser.add_argument(
        "--vllm-host",
        default="127.0.0.1",
        help="vLLM inference host"
    )
    parser.add_argument(
        "--vllm-gpu-memory",
        type=float,
        default=None,  # Will use profile default or fallback (0.45)
        help="GPU memory fraction for vLLM (default: from profile or 0.45)"
    )
    parser.add_argument(
        "--skip-services",
        action="store_true",
        help="Skip starting services (assume already running)"
    )

    # Checkpoint settings
    parser.add_argument(
        "--save-path",
        default="./trained_models",
        help="Directory to save checkpoints"
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=5,
        help="Save checkpoint every N steps"
    )
    parser.add_argument(
        "--keep-checkpoints",
        type=int,
        default=3,
        help="Number of checkpoints to keep"
    )
    parser.add_argument(
        "--resume",
        help="Resume from checkpoint path"
    )

    # W&B settings
    parser.add_argument(
        "--wandb-project",
        default="eliza-training",
        help="W&B project name"
    )
    parser.add_argument(
        "--wandb-entity",
        help="W&B entity/team"
    )
    parser.add_argument(
        "--wandb-run-name",
        help="W&B run name"
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable W&B logging"
    )

    # Logging
    parser.add_argument(
        "--log-dir",
        default="./logs",
        help="Directory for log files"
    )

    # Validation
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip environment validation"
    )

    # Training Mode (Phase 3)
    parser.add_argument(
        "--mode",
        choices=["offline", "online", "hybrid"],
        default="offline",
        help="Training mode: offline (DB trajectories), online (simulation bridge), hybrid (mix)"
    )
    parser.add_argument(
        "--bridge-url",
        default="http://localhost:3001",
        help="Simulation bridge URL (for online/hybrid modes)"
    )
    parser.add_argument(
        "--hybrid-online-ratio",
        type=float,
        default=0.2,
        help="Ratio of online rollouts in hybrid mode (0.0-1.0)"
    )
    parser.add_argument(
        "--online",
        action="store_true",
        help="Shorthand for --mode online"
    )
    return parser


def _apply_profile_defaults(args: argparse.Namespace, profile: dict) -> None:
    """Resolve profile-dependent settings on *args* in place.

    Explicitly passed CLI values always win. The profile fills only the
    settings the user left unset (None), and hard-coded fallbacks apply when
    the profile does not provide a value. The multi-GPU settings below have
    no CLI flags and come only from the profile.
    """
    if args.model is None:
        args.model = profile.get("model", "Qwen/Qwen2.5-3B-Instruct")
    if args.batch_size is None:
        args.batch_size = profile.get("batch_size", 4)
    if args.vllm_gpu_memory is None:
        args.vllm_gpu_memory = profile.get("vllm_gpu_memory", 0.45)

    # Phase 4: Read multi-GPU settings from profile
    args.tensor_parallel_size = profile.get("tensor_parallel_size", 1)
    args.use_flash_attention = profile.get("use_flash_attention", False)
    args.vllm_gpu = profile.get("vllm_gpu")  # Explicit GPU assignment for vLLM
    args.training_gpu = profile.get("training_gpu")  # Explicit GPU assignment for training


def main():
    """CLI entry point: parse arguments, apply the GPU profile, validate, and train.

    Exits the process with the orchestrator's return code. Exits with 0 after
    ``--list-profiles`` and with 1 when environment validation fails.
    """
    parser = _build_training_arg_parser()
    args = parser.parse_args()

    # Handle --online shorthand (it takes precedence over --mode)
    if args.online:
        args.mode = "online"

    # Handle --list-profiles
    if args.list_profiles:
        list_profiles()
        sys.exit(0)

    # The ratio is a fraction of rollouts; reject out-of-range values (and NaN)
    # up front instead of letting the environment misbehave later.
    if not 0.0 <= args.hybrid_online_ratio <= 1.0:
        parser.error(
            "--hybrid-online-ratio must be between 0.0 and 1.0, "
            f"got {args.hybrid_online_ratio}"
        )

    # Apply profile defaults (explicit args take precedence)
    profile = load_profile(args.profile) if args.profile else {}
    _apply_profile_defaults(args, profile)

    # Log effective settings
    if args.profile:
        tp_info = f", tp={args.tensor_parallel_size}" if args.tensor_parallel_size > 1 else ""
        logger.info(f"Using profile '{args.profile}': model={args.model}, "
                    f"vllm_mem={args.vllm_gpu_memory:.0%}, batch={args.batch_size}{tp_info}")

    # Validate environment
    if not args.skip_validation:
        errors = validate_environment()
        if errors:
            logger.error("Environment validation failed:")
            for error in errors:
                logger.error(f"  • {error}")
            logger.error("\nFix the above issues or use --skip-validation to bypass.")
            sys.exit(1)

    orchestrator = TrainingOrchestrator(
        model_name=args.model,
        base_model=args.base_model,
        dataset_input=args.dataset_input,
        scoring_mode=args.scoring_mode,
        training_steps=args.steps,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        min_learning_rate=args.min_lr,
        lr_scheduler=args.lr_scheduler,
        warmup_steps=args.warmup_steps,
        api_port=args.api_port,
        vllm_host=args.vllm_host,
        vllm_port=args.vllm_port,
        vllm_gpu_memory=args.vllm_gpu_memory,
        save_path=args.save_path,
        save_every=args.save_every,
        keep_checkpoints=args.keep_checkpoints,
        resume_from=args.resume,
        use_wandb=not args.no_wandb,
        wandb_project=args.wandb_project,
        wandb_entity=args.wandb_entity,
        wandb_run_name=args.wandb_run_name,
        skip_services=args.skip_services,
        log_dir=args.log_dir,
        # Phase 3: Online training
        mode=args.mode,
        bridge_url=args.bridge_url,
        hybrid_online_ratio=args.hybrid_online_ratio,
        # Phase 4: Cloud/Multi-GPU
        tensor_parallel_size=args.tensor_parallel_size,
        use_flash_attention=args.use_flash_attention,
        vllm_gpu=args.vllm_gpu,
        training_gpu=args.training_gpu,
    )

    sys.exit(orchestrator.run())
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
if __name__ == "__main__":
    # Run the training pipeline CLI only when this file is executed directly,
    # not when it is imported as a module.
    main()
|