@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,914 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ ElizaOS RL Training - Full Pipeline Runner
4
+
5
+ This script orchestrates the complete RLAIF training pipeline:
6
+ 1. Validates environment and prerequisites
7
+ 2. Starts background services (Atropos API, vLLM)
8
+ 3. Starts the RLAIF environment
9
+ 4. Runs the GRPO trainer with optional W&B logging
10
+
11
+ Usage:
12
+ # Use a GPU profile (recommended - auto-configures for your hardware)
13
+ python scripts/run_training.py --profile 12gb --steps 100
14
+ python scripts/run_training.py --profile 24gb --steps 100
15
+
16
+ # List available profiles
17
+ python scripts/run_training.py --list-profiles
18
+
19
+ # Manual configuration (override profile or use without profile)
20
+ python scripts/run_training.py --model Qwen/Qwen2.5-0.5B-Instruct --vllm-gpu-memory 0.25 --steps 100
21
+
22
+ # Resume from checkpoint
23
+ python scripts/run_training.py --profile 12gb --resume ./trained_models/step_50
24
+
25
+ # Disable W&B
26
+ python scripts/run_training.py --profile 12gb --steps 100 --no-wandb
27
+
28
+ GPU Profiles (config/profiles/*.json):
29
+ 12gb - RTX 3060/4070 (0.5B model, 25% vLLM memory)
30
+ 16gb - RTX 4080/A4000 (1.5B model, 35% vLLM memory)
31
+ 24gb - RTX 4090/A5000 (3B model, 40% vLLM memory)
32
+ 48gb - A40/A6000 (7B model, 45% vLLM memory)
33
+
34
+ Or run components separately:
35
+ Terminal 1: run-api
36
+ Terminal 2: python -m src.training.rlaif_env serve --slurm false
37
+ Terminal 3: python -m src.training.atropos_trainer --steps 100
38
+ """
39
+
40
+ import argparse
41
+ import json
42
+ import logging
43
+ import os
44
+ import signal
45
+ import subprocess
46
+ import sys
47
+ import time
48
+ from pathlib import Path
49
+ from typing import Optional
50
+
51
+ # Add src to path
52
+ sys.path.insert(0, str(Path(__file__).parent.parent))
53
+
54
+ from dotenv import load_dotenv
55
+
56
# Load environment variables from .env (DATABASE_URL, OPENAI_API_KEY, etc.).
load_dotenv()

# Root logger configuration shared by the whole pipeline run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s'
)
logger = logging.getLogger(__name__)

# Profile directory: GPU profiles live at config/profiles/*.json, two levels
# up from this script (the python/ package root).
PROFILES_DIR = Path(__file__).parent.parent / "config" / "profiles"
67
+
68
+
69
def get_available_profiles() -> list[str]:
    """Return the names (file stems) of every GPU profile JSON on disk."""
    if PROFILES_DIR.exists():
        return [profile_file.stem for profile_file in PROFILES_DIR.glob("*.json")]
    return []
74
+
75
+
76
def load_profile(profile_name: str) -> dict:
    """Load a GPU profile by name.

    Args:
        profile_name: Stem of a JSON file under ``config/profiles/``.

    Returns:
        The parsed profile dictionary.

    Raises:
        ValueError: If no profile file with that name exists.
        json.JSONDecodeError: If the profile file is not valid JSON.
    """
    profile_path = PROFILES_DIR / f"{profile_name}.json"
    if not profile_path.exists():
        available = get_available_profiles()
        raise ValueError(
            f"Profile '{profile_name}' not found. "
            f"Available: {', '.join(available) or 'none'}"
        )

    # JSON is defined as UTF-8; don't depend on the platform's locale encoding.
    profile = json.loads(profile_path.read_text(encoding="utf-8"))

    logger.info(f"Loaded profile: {profile.get('name', profile_name)}")
    if profile.get('notes'):
        logger.info(f" Note: {profile['notes']}")

    return profile
94
+
95
+
96
def list_profiles() -> None:
    """Print every available GPU profile with its key settings."""
    print("\nAvailable GPU Profiles:")
    print("=" * 60)

    for name in sorted(get_available_profiles()):
        try:
            info = load_profile(name)
        except Exception as exc:
            # Show the broken profile rather than aborting the listing.
            print(f"\n --profile {name}")
            print(f" Error loading: {exc}")
        else:
            print(f"\n --profile {name}")
            print(f" {info.get('name', 'Unnamed')}")
            print(f" Model: {info.get('model', 'default')}")
            print(f" vLLM Memory: {info.get('vllm_gpu_memory', 0.45) * 100:.0f}%")
            if info.get('notes'):
                print(f" Note: {info['notes']}")

    print()
115
+
116
+
117
def validate_environment() -> list[str]:
    """
    Validate that all required environment variables and dependencies are present.

    Returns a list of error messages for missing requirements.
    """
    import shutil

    problems = []

    # Required environment variables and the message shown when each is absent.
    required_env = [
        (
            "DATABASE_URL",
            "DATABASE_URL not set. Required for loading training trajectories.\n"
            " Set in .env or export DATABASE_URL=postgresql://...",
        ),
        (
            "OPENAI_API_KEY",
            "OPENAI_API_KEY not set. Required for RLAIF judge scoring.\n"
            " Set in .env or export OPENAI_API_KEY=sk-...",
        ),
    ]
    for var_name, message in required_env:
        if not os.getenv(var_name):
            problems.append(message)

    # The Atropos rollout server installs a `run-api` console script.
    if not shutil.which("run-api"):
        problems.append(
            "Atropos API not found. Install with: pip install atroposlib"
        )

    # PyTorch / CUDA availability (training without a GPU is possible but slow).
    try:
        import torch
    except ImportError:
        problems.append("PyTorch not installed. Install with: pip install torch")
    else:
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
        else:
            problems.append(
                "CUDA not available. GPU is recommended for training.\n"
                " For CPU-only (slow), use --skip-vllm and provide external inference."
            )

    return problems
162
+
163
+
164
+ class TrainingOrchestrator:
165
+ """
166
+ Orchestrates the complete training pipeline.
167
+
168
+ Manages:
169
+ - Service lifecycle (Atropos API, vLLM)
170
+ - Environment server
171
+ - GRPO trainer
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ model_name: str = "Qwen/Qwen2.5-3B-Instruct",
177
+ base_model: Optional[str] = None,
178
+ dataset_input: Optional[str] = None,
179
+ scoring_mode: str = "deterministic",
180
+ training_steps: int = 100,
181
+ batch_size: int = 4,
182
+ learning_rate: float = 1e-5,
183
+ min_learning_rate: float = 1e-7,
184
+ lr_scheduler: str = "cosine",
185
+ warmup_steps: int = 10,
186
+ api_port: int = 8000,
187
+ vllm_host: str = "127.0.0.1",
188
+ vllm_port: int = 9001,
189
+ vllm_gpu_memory: float = 0.45,
190
+ save_path: str = "./trained_models",
191
+ save_every: int = 5,
192
+ keep_checkpoints: int = 3,
193
+ resume_from: Optional[str] = None,
194
+ use_wandb: bool = True,
195
+ wandb_project: str = "eliza-training",
196
+ wandb_entity: Optional[str] = None,
197
+ wandb_run_name: Optional[str] = None,
198
+ skip_services: bool = False,
199
+ log_dir: str = "./logs",
200
+ # Phase 3: Online training parameters
201
+ mode: str = "offline",
202
+ bridge_url: str = "http://localhost:3001",
203
+ hybrid_online_ratio: float = 0.2,
204
+ # Phase 4: Cloud/Multi-GPU parameters
205
+ tensor_parallel_size: int = 1,
206
+ use_flash_attention: bool = False,
207
+ vllm_gpu: Optional[str] = None, # Explicit GPU assignment for vLLM
208
+ training_gpu: Optional[str] = None, # Explicit GPU assignment for training
209
+ ):
210
+ self.model_name = model_name
211
+ self.base_model = base_model
212
+ self.dataset_input = dataset_input
213
+ self.scoring_mode = scoring_mode
214
+ self.training_steps = training_steps
215
+ self.batch_size = batch_size
216
+ self.learning_rate = learning_rate
217
+ self.min_learning_rate = min_learning_rate
218
+ self.lr_scheduler = lr_scheduler
219
+ self.warmup_steps = warmup_steps
220
+ self.api_port = api_port
221
+ self.vllm_host = vllm_host
222
+ self.vllm_port = vllm_port
223
+ self.vllm_gpu_memory = vllm_gpu_memory
224
+ self.save_path = save_path
225
+ self.save_every = save_every
226
+ self.keep_checkpoints = keep_checkpoints
227
+ self.resume_from = resume_from
228
+ self.use_wandb = use_wandb
229
+ self.wandb_project = wandb_project
230
+ self.wandb_entity = wandb_entity
231
+ self.wandb_run_name = wandb_run_name
232
+ self.skip_services = skip_services
233
+ self.log_dir = Path(log_dir)
234
+ # Phase 3: Online training
235
+ self.mode = mode
236
+ self.bridge_url = bridge_url
237
+ self.hybrid_online_ratio = hybrid_online_ratio
238
+ # Phase 4: Cloud/Multi-GPU
239
+ self.tensor_parallel_size = tensor_parallel_size
240
+ self.use_flash_attention = use_flash_attention
241
+ self.vllm_gpu = vllm_gpu
242
+ self.training_gpu = training_gpu
243
+
244
+ self.env_process: Optional[subprocess.Popen] = None
245
+ self.trainer_process: Optional[subprocess.Popen] = None
246
+ self._service_manager = None
247
+ self._shutdown_requested = False
248
+ self._log_handles: list = [] # Track open file handles
249
+
250
+ self.log_dir.mkdir(parents=True, exist_ok=True)
251
+
252
+ signal.signal(signal.SIGINT, self._signal_handler)
253
+ signal.signal(signal.SIGTERM, self._signal_handler)
254
+
255
def _signal_handler(self, signum, frame):
    """React to SIGINT/SIGTERM: first signal cleans up, a second force-exits."""
    if not self._shutdown_requested:
        logger.info("Received shutdown signal, cleaning up...")
        self._shutdown_requested = True
        self.cleanup()
        sys.exit(0)

    # Second signal while already shutting down: bail out hard.
    logger.warning("Forced shutdown, exiting immediately")
    sys.exit(1)
265
+
266
+ def cleanup(self):
267
+ """Clean up all subprocesses and services"""
268
+ self._stop_process(self.trainer_process, "trainer")
269
+ self._stop_process(self.env_process, "environment")
270
+
271
+ if self._service_manager:
272
+ self._service_manager.stop_all()
273
+
274
+ for handle in self._log_handles:
275
+ handle.close()
276
+ self._log_handles.clear()
277
+
278
+ def _stop_process(self, proc: Optional[subprocess.Popen], name: str, timeout: int = 10) -> None:
279
+ """Stop a subprocess gracefully"""
280
+ if not proc:
281
+ return
282
+
283
+ logger.info(f"Stopping {name}...")
284
+ proc.terminate()
285
+
286
+ deadline = time.time() + timeout
287
+ while proc.poll() is None and time.time() < deadline:
288
+ time.sleep(0.5)
289
+
290
+ if proc.poll() is None:
291
+ proc.kill()
292
+ proc.wait()
293
+
294
def start_services(self) -> bool:
    """Launch background services (Atropos API + vLLM) via ServiceManager."""
    if self.skip_services:
        logger.info("Skipping service startup (--skip-services)")
        return True

    from src.training.service_manager import ServiceManager, ServiceConfig

    self._service_manager = ServiceManager(ServiceConfig(
        atropos_port=self.api_port,
        vllm_port=self.vllm_port,
        model_name=self.model_name,
        vllm_gpu_memory_utilization=self.vllm_gpu_memory,
        log_dir=str(self.log_dir / "services"),
        # Phase 4: Multi-GPU support
        tensor_parallel_size=self.tensor_parallel_size,
        use_flash_attention=self.use_flash_attention,
        vllm_gpu=self.vllm_gpu,
        training_gpu=self.training_gpu,
    ))

    if not self._service_manager.start_all():
        return False

    if self._service_manager.wait_for_ready():
        return True

    logger.error("Services failed to become ready")
    return False
325
+
326
def check_bridge_health(self) -> bool:
    """Probe the simulation bridge's /health endpoint, retrying up to 3 times."""
    import urllib.request
    import urllib.error

    logger.info(f"Checking simulation bridge at {self.bridge_url}...")

    health_url = f"{self.bridge_url}/health"
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            request = urllib.request.Request(health_url, method='GET')
            with urllib.request.urlopen(request, timeout=5) as response:
                if response.status == 200:
                    logger.info("Simulation bridge is healthy ✓")
                    return True
        except urllib.error.URLError as err:
            if attempt == max_attempts - 1:
                logger.error(f"Simulation bridge not available at {self.bridge_url}")
                logger.error("Start it with: make bridge-server")
                return False
            logger.warning(f"Bridge not ready (attempt {attempt + 1}/3): {err}")
            time.sleep(2)
        except Exception as err:
            # Non-URL errors are not retried.
            logger.error(f"Bridge health check failed: {err}")
            return False

    return False
354
+
355
def start_environment(self) -> bool:
    """Launch the RLAIF environment server for offline training."""
    logger.info("Starting RLAIF environment (offline mode)...")

    cmd = [
        sys.executable, "-m", "src.training.rlaif_env", "serve",
        "--slurm", "false",
        "--env.tokenizer_name", self.model_name,
        "--env.scoring_mode", self.scoring_mode,
        "--env.rollout_server_url", f"http://localhost:{self.api_port}",
        "--openai.model_name", self.model_name,
        "--openai.base_url", f"http://{self.vllm_host}:{self.vllm_port}/v1",
    ]
    if not self.use_wandb:
        cmd += ["--env.use_wandb", "false"]

    log_file = self.log_dir / "environment.log"
    sink = open(log_file, "w")
    self._log_handles.append(sink)

    self.env_process = subprocess.Popen(
        cmd,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=os.environ.copy(),  # Forward DATABASE_URL and friends to the child
    )

    time.sleep(5)  # Wait for environment to initialize

    if self.env_process.poll() is not None:
        logger.error(f"Environment failed to start (exit code: {self.env_process.returncode})")
        logger.error(f"Check logs at: {log_file}")
        return False

    logger.info(f"Environment started (PID: {self.env_process.pid}), logs: {log_file}")
    return True
393
+
394
def start_online_environment(self) -> bool:
    """Launch the online environment server, driven by the simulation bridge."""
    logger.info("Starting online environment (online mode)...")

    cmd = [
        sys.executable, "-m", "src.training.online_env", "serve",
        "--slurm", "false",
        "--env.tokenizer_name", self.model_name,
        "--env.rollout_server_url", f"http://localhost:{self.api_port}",
        "--openai.model_name", self.model_name,
        "--openai.base_url", f"http://{self.vllm_host}:{self.vllm_port}/v1",
        # Online-specific settings
        "--env.use_simulation_bridge", "true",
        "--env.simulation_bridge_url", self.bridge_url,
    ]
    if not self.use_wandb:
        cmd += ["--env.use_wandb", "false"]

    log_file = self.log_dir / "online_environment.log"
    sink = open(log_file, "w")
    self._log_handles.append(sink)

    # Mirror the bridge settings into the child's environment.
    child_env = os.environ.copy()
    child_env["USE_SIMULATION_BRIDGE"] = "1"
    child_env["SIMULATION_BRIDGE_URL"] = self.bridge_url

    self.env_process = subprocess.Popen(
        cmd,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    time.sleep(5)  # Wait for environment to initialize

    if self.env_process.poll() is not None:
        logger.error(f"Online environment failed to start (exit code: {self.env_process.returncode})")
        logger.error(f"Check logs at: {log_file}")
        return False

    logger.info(f"Online environment started (PID: {self.env_process.pid}), logs: {log_file}")
    return True
439
+
440
def start_hybrid_environment(self) -> bool:
    """Launch the hybrid environment server (mix of offline and online rollouts)."""
    logger.info(f"Starting hybrid environment (online ratio: {self.hybrid_online_ratio:.0%})...")

    cmd = [
        sys.executable, "-m", "src.training.hybrid_env", "serve",
        "--slurm", "false",
        "--env.tokenizer_name", self.model_name,
        "--env.rollout_server_url", f"http://localhost:{self.api_port}",
        "--openai.model_name", self.model_name,
        "--openai.base_url", f"http://{self.vllm_host}:{self.vllm_port}/v1",
        # Hybrid-specific settings
        "--env.use_simulation_bridge", "true",
        "--env.simulation_bridge_url", self.bridge_url,
        "--env.online_ratio", str(self.hybrid_online_ratio),
    ]
    if not self.use_wandb:
        cmd += ["--env.use_wandb", "false"]

    log_file = self.log_dir / "hybrid_environment.log"
    sink = open(log_file, "w")
    self._log_handles.append(sink)

    # Mirror the bridge + ratio settings into the child's environment.
    child_env = os.environ.copy()
    child_env["USE_SIMULATION_BRIDGE"] = "1"
    child_env["SIMULATION_BRIDGE_URL"] = self.bridge_url
    child_env["HYBRID_ONLINE_RATIO"] = str(self.hybrid_online_ratio)

    self.env_process = subprocess.Popen(
        cmd,
        cwd=str(Path(__file__).parent.parent),
        stdout=sink,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    time.sleep(5)  # Wait for environment to initialize

    if self.env_process.poll() is not None:
        logger.error(f"Hybrid environment failed to start (exit code: {self.env_process.returncode})")
        logger.error(f"Check logs at: {log_file}")
        return False

    logger.info(f"Hybrid environment started (PID: {self.env_process.pid}), logs: {log_file}")
    return True
487
+
488
def start_trainer(self) -> bool:
    """Launch the GRPO trainer subprocess with stdout piped for streaming."""
    logger.info("Starting GRPO trainer...")

    cmd = [
        sys.executable, "-m", "src.training.atropos_trainer",
        "--model", self.model_name,
        "--scoring-mode", self.scoring_mode,
        "--steps", str(self.training_steps),
        "--batch-size", str(self.batch_size),
        "--lr", str(self.learning_rate),
        "--min-lr", str(self.min_learning_rate),
        "--lr-scheduler", self.lr_scheduler,
        "--warmup-steps", str(self.warmup_steps),
        "--api-url", f"http://localhost:{self.api_port}",
        "--vllm-host", self.vllm_host,
        "--vllm-port", str(self.vllm_port),
        "--vllm-gpu-utilization", str(self.vllm_gpu_memory),
        "--save-path", self.save_path,
        "--save-every", str(self.save_every),
        "--keep-checkpoints", str(self.keep_checkpoints),
        "--log-file", str(self.log_dir / "training_metrics.jsonl"),
        "--wandb-project", self.wandb_project,
        "--skip-vllm",  # vLLM already started by ServiceManager
    ]

    # Optional value-bearing flags, appended only when configured.
    for flag, value in (
        ("--base-model", self.base_model),
        ("--dataset-input", self.dataset_input),
        ("--resume", self.resume_from),
    ):
        if value:
            cmd += [flag, value]
    if not self.use_wandb:
        cmd.append("--no-wandb")
    if self.wandb_entity:
        cmd += ["--wandb-entity", self.wandb_entity]
    if self.wandb_run_name:
        cmd += ["--wandb-run-name", self.wandb_run_name]

    # GPU pinning for the training process, if explicitly requested.
    child_env = os.environ.copy()
    if self.training_gpu:
        child_env["CUDA_VISIBLE_DEVICES"] = self.training_gpu
        logger.info(f"Training GPU (explicit): {self.training_gpu}")

    # stdout is piped so the caller can stream trainer output to the console.
    self.trainer_process = subprocess.Popen(
        cmd,
        cwd=str(Path(__file__).parent.parent),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=child_env,
    )

    logger.info(f"Trainer started (PID: {self.trainer_process.pid})")
    return True
544
+
545
+ def run(self) -> int:
546
+ """Run the complete training pipeline"""
547
+ self._log_config()
548
+ start_time = time.time()
549
+
550
+ try:
551
+ # Step 1: Start services
552
+ if not self.start_services():
553
+ logger.error("Failed to start services")
554
+ return 1
555
+
556
+ # Step 2: For online/hybrid modes, check bridge health
557
+ if self.mode in ("online", "hybrid"):
558
+ if not self.check_bridge_health():
559
+ logger.error("Simulation bridge not available")
560
+ logger.error("Start it with: make bridge-server")
561
+ return 1
562
+
563
+ # Step 3: Start appropriate environment based on mode
564
+ env_starter = {
565
+ "offline": self.start_environment,
566
+ "online": self.start_online_environment,
567
+ "hybrid": self.start_hybrid_environment,
568
+ }.get(self.mode, self.start_environment)
569
+
570
+ if not env_starter():
571
+ logger.error(f"Failed to start {self.mode} environment")
572
+ return 1
573
+
574
+ # Step 4: Start trainer
575
+ if not self.start_trainer():
576
+ logger.error("Failed to start trainer")
577
+ return 1
578
+
579
+ return_code = self._stream_trainer_output()
580
+ elapsed = time.time() - start_time
581
+
582
+ if return_code == 0:
583
+ logger.info("\n" + "=" * 70)
584
+ logger.info("TRAINING COMPLETED SUCCESSFULLY")
585
+ logger.info(f"Mode: {self.mode.upper()}")
586
+ logger.info(f"Total time: {elapsed:.1f}s ({elapsed/60:.1f} minutes)")
587
+ logger.info(f"Model saved to: {self.save_path}")
588
+ logger.info("=" * 70)
589
+ else:
590
+ logger.error(f"Training failed with return code: {return_code}")
591
+ logger.error(f"Check logs at: {self.log_dir}")
592
+
593
+ return return_code
594
+ finally:
595
+ self.cleanup()
596
+
597
+ def _log_config(self):
598
+ """Log training configuration"""
599
+ logger.info("=" * 70)
600
+ logger.info("ELIZAOS RL TRAINING PIPELINE")
601
+ logger.info("=" * 70)
602
+ logger.info(f"Mode: {self.mode.upper()}")
603
+ if self.mode in ("online", "hybrid"):
604
+ logger.info(f"Bridge URL: {self.bridge_url}")
605
+ if self.mode == "hybrid":
606
+ logger.info(f"Online ratio: {self.hybrid_online_ratio:.0%}")
607
+ logger.info(f"Model: {self.model_name}")
608
+ logger.info(f"Steps: {self.training_steps}")
609
+ logger.info(f"Batch size: {self.batch_size}")
610
+ logger.info(f"Learning rate: {self.learning_rate} (scheduler: {self.lr_scheduler})")
611
+ logger.info(f"Save path: {self.save_path}")
612
+ logger.info(f"W&B: {'enabled' if self.use_wandb else 'disabled'}")
613
+ if self.resume_from:
614
+ logger.info(f"Resuming from: {self.resume_from}")
615
+ logger.info("=" * 70)
616
+
617
+ def _stream_trainer_output(self) -> int:
618
+ """Stream trainer output to console and log file"""
619
+ logger.info("\n" + "-" * 70)
620
+ logger.info("TRAINING IN PROGRESS")
621
+ logger.info("-" * 70 + "\n")
622
+
623
+ log_file = self.log_dir / "trainer.log"
624
+
625
+ assert self.trainer_process is not None
626
+ assert self.trainer_process.stdout is not None
627
+
628
+ with open(log_file, "w") as log_handle:
629
+ for line in iter(self.trainer_process.stdout.readline, b''):
630
+ decoded = line.decode('utf-8', errors='replace')
631
+ print(decoded, end='')
632
+ log_handle.write(decoded)
633
+ log_handle.flush()
634
+
635
+ return self.trainer_process.wait()
636
+
637
+
638
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the training pipeline.

    All defaults live here; profile application reads them back via
    parser.get_default() so they are never duplicated elsewhere.
    """
    parser = argparse.ArgumentParser(
        description="ElizaOS RL Training Pipeline",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Profile settings (applied first, can be overridden by explicit args)
    parser.add_argument(
        "--profile",
        choices=get_available_profiles() or None,
        help="GPU profile to use (e.g., 12gb, 24gb). See --list-profiles"
    )
    parser.add_argument(
        "--list-profiles",
        action="store_true",
        help="List available GPU profiles and exit"
    )

    # Model settings
    parser.add_argument(
        "--model",
        default=None,  # Will use profile default or fallback
        help="Model to train (default: from profile or Qwen2.5-3B-Instruct)"
    )
    parser.add_argument(
        "--base-model",
        default=None,
        help="Optional base model alias passed to trainer"
    )
    parser.add_argument(
        "--dataset-input",
        default=None,
        help="Optional dataset input path passed to trainer"
    )
    parser.add_argument(
        "--scoring-mode",
        choices=["deterministic", "llm_judge"],
        default="deterministic",
        help="Scoring mode used by environment/trainer pipeline"
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of training steps"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size"
    )

    # Learning rate settings
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-5,
        help="Initial learning rate"
    )
    parser.add_argument(
        "--min-lr",
        type=float,
        default=1e-7,
        help="Minimum learning rate"
    )
    parser.add_argument(
        "--lr-scheduler",
        choices=["constant", "linear", "cosine"],
        default="cosine",
        help="Learning rate scheduler"
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="LR warmup steps"
    )

    # Service settings
    parser.add_argument(
        "--api-port",
        type=int,
        default=8000,
        help="Atropos API server port"
    )
    parser.add_argument(
        "--vllm-port",
        type=int,
        default=9001,
        help="vLLM inference server port"
    )
    parser.add_argument(
        "--vllm-host",
        default="127.0.0.1",
        help="vLLM inference host"
    )
    parser.add_argument(
        "--vllm-gpu-memory",
        type=float,
        default=0.45,
        help="GPU memory fraction for vLLM"
    )
    parser.add_argument(
        "--skip-services",
        action="store_true",
        help="Skip starting services (assume already running)"
    )

    # Checkpoint settings
    parser.add_argument(
        "--save-path",
        default="./trained_models",
        help="Directory to save checkpoints"
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=5,
        help="Save checkpoint every N steps"
    )
    parser.add_argument(
        "--keep-checkpoints",
        type=int,
        default=3,
        help="Number of checkpoints to keep"
    )
    parser.add_argument(
        "--resume",
        help="Resume from checkpoint path"
    )

    # W&B settings
    parser.add_argument(
        "--wandb-project",
        default="eliza-training",
        help="W&B project name"
    )
    parser.add_argument(
        "--wandb-entity",
        help="W&B entity/team"
    )
    parser.add_argument(
        "--wandb-run-name",
        help="W&B run name"
    )
    parser.add_argument(
        "--no-wandb",
        action="store_true",
        help="Disable W&B logging"
    )

    # Logging
    parser.add_argument(
        "--log-dir",
        default="./logs",
        help="Directory for log files"
    )

    # Validation
    parser.add_argument(
        "--skip-validation",
        action="store_true",
        help="Skip environment validation"
    )

    # Training Mode (Phase 3)
    parser.add_argument(
        "--mode",
        choices=["offline", "online", "hybrid"],
        default="offline",
        help="Training mode: offline (DB trajectories), online (simulation bridge), hybrid (mix)"
    )
    parser.add_argument(
        "--bridge-url",
        default="http://localhost:3001",
        help="Simulation bridge URL (for online/hybrid modes)"
    )
    parser.add_argument(
        "--hybrid-online-ratio",
        type=float,
        default=0.2,
        help="Ratio of online rollouts in hybrid mode (0.0-1.0)"
    )
    parser.add_argument(
        "--online",
        action="store_true",
        help="Shorthand for --mode online"
    )

    return parser


def _apply_profile_settings(args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
    """Fill unset args from the selected GPU profile (mutates args in place).

    Profile values only override an argument when it still holds its argparse
    default (checked via parser.get_default, so the defaults are not
    duplicated here as magic numbers). Also attaches the Phase-4 multi-GPU
    settings (tensor_parallel_size, use_flash_attention, vllm_gpu,
    training_gpu) to args.
    """
    profile = load_profile(args.profile) if args.profile else {}

    # Apply profile values as defaults for unset args
    if args.model is None:
        args.model = profile.get("model", "Qwen/Qwen2.5-3B-Instruct")
    if args.batch_size == parser.get_default("batch_size") and "batch_size" in profile:
        args.batch_size = profile["batch_size"]
    if args.vllm_gpu_memory == parser.get_default("vllm_gpu_memory") and "vllm_gpu_memory" in profile:
        args.vllm_gpu_memory = profile["vllm_gpu_memory"]

    # Phase 4: Read multi-GPU settings from profile
    args.tensor_parallel_size = profile.get("tensor_parallel_size", 1)
    args.use_flash_attention = profile.get("use_flash_attention", False)
    args.vllm_gpu = profile.get("vllm_gpu")  # Explicit GPU assignment for vLLM
    args.training_gpu = profile.get("training_gpu")  # Explicit GPU assignment for training

    # Log effective settings
    if args.profile:
        tp_info = f", tp={args.tensor_parallel_size}" if args.tensor_parallel_size > 1 else ""
        logger.info(f"Using profile '{args.profile}': model={args.model}, "
                    f"vllm_mem={args.vllm_gpu_memory:.0%}, batch={args.batch_size}{tp_info}")


def main():
    """CLI entry point: parse args, apply profile, validate, run orchestrator."""
    parser = _build_parser()
    args = parser.parse_args()

    # Handle --online shorthand
    if args.online:
        args.mode = "online"

    # Handle --list-profiles
    if args.list_profiles:
        list_profiles()
        sys.exit(0)

    # Apply profile defaults (can be overridden by explicit args)
    _apply_profile_settings(args, parser)

    # Validate environment
    if not args.skip_validation:
        errors = validate_environment()
        if errors:
            logger.error("Environment validation failed:")
            for error in errors:
                logger.error(f"  • {error}")
            logger.error("\nFix the above issues or use --skip-validation to bypass.")
            sys.exit(1)

    orchestrator = TrainingOrchestrator(
        model_name=args.model,
        base_model=args.base_model,
        dataset_input=args.dataset_input,
        scoring_mode=args.scoring_mode,
        training_steps=args.steps,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        min_learning_rate=args.min_lr,
        lr_scheduler=args.lr_scheduler,
        warmup_steps=args.warmup_steps,
        api_port=args.api_port,
        vllm_host=args.vllm_host,
        vllm_port=args.vllm_port,
        vllm_gpu_memory=args.vllm_gpu_memory,
        save_path=args.save_path,
        save_every=args.save_every,
        keep_checkpoints=args.keep_checkpoints,
        resume_from=args.resume,
        use_wandb=not args.no_wandb,
        wandb_project=args.wandb_project,
        wandb_entity=args.wandb_entity,
        wandb_run_name=args.wandb_run_name,
        skip_services=args.skip_services,
        log_dir=args.log_dir,
        # Phase 3: Online training
        mode=args.mode,
        bridge_url=args.bridge_url,
        hybrid_online_ratio=args.hybrid_online_ratio,
        # Phase 4: Cloud/Multi-GPU
        tensor_parallel_size=args.tensor_parallel_size,
        use_flash_attention=args.use_flash_attention,
        vllm_gpu=args.vllm_gpu,
        training_gpu=args.training_gpu,
    )

    sys.exit(orchestrator.run())
911
+
912
+
913
# Script entry point: run the CLI when executed directly (not on import).
if __name__ == "__main__":
    main()