@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,552 @@
1
+ """
2
+ Service Manager for Local Training Infrastructure
3
+
4
+ Manages the lifecycle of background services required for local GRPO training:
5
+ - Atropos API Server: Handles batch collection and distribution
6
+ - vLLM Server: Provides inference during rollouts
7
+
8
+ Features:
9
+ - Automatic startup with health checks
10
+ - Graceful shutdown with kill fallback
11
+ - Context manager interface for automatic cleanup
12
+ - Configurable ports and timeouts
13
+ - Process output logging to files
14
+
15
+ Usage:
16
+ config = ServiceConfig(
17
+ atropos_port=8000,
18
+ vllm_port=9001,
19
+ model_name="Qwen/Qwen2.5-3B-Instruct",
20
+ )
21
+
22
+ with ServiceManager(config) as services:
23
+ if not services.wait_for_ready():
24
+ raise RuntimeError("Services failed to start")
25
+ # Run training...
26
+ """
27
+
28
+ import logging
29
+ import os
30
+ import shutil
31
+ import signal
32
+ import socket
33
+ import subprocess
34
+ import sys
35
+ import time
36
+ from dataclasses import dataclass
37
+ from enum import Enum
38
+ from pathlib import Path
39
+ from typing import IO, Optional
40
+
41
+ import requests
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
class ServiceStatus(Enum):
    """Lifecycle state of a single managed background service."""

    STOPPED = "stopped"    # not running: never started, or fully shut down
    STARTING = "starting"  # process launched; health check not yet passing
    RUNNING = "running"    # health endpoint answered HTTP 200
    FAILED = "failed"      # startup or runtime failure was detected
    STOPPING = "stopping"  # graceful shutdown in progress
53
+
54
+
55
@dataclass
class ServiceConfig:
    """Tunable settings for the services that ServiceManager supervises."""

    # --- Atropos API server ---
    atropos_port: int = 8000
    atropos_host: str = "localhost"

    # --- vLLM inference server ---
    vllm_port: int = 9001
    vllm_host: str = "localhost"
    model_name: str = "Qwen/Qwen2.5-3B-Instruct"
    vllm_gpu_memory_utilization: float = 0.85
    vllm_dtype: str = "auto"
    vllm_max_model_len: int = 4096

    # --- Multi-GPU settings (Phase 4) ---
    tensor_parallel_size: int = 1      # number of GPUs for tensor parallelism
    use_flash_attention: bool = False  # enable flash attention for performance

    # --- GPU assignment ---
    # Keeping vLLM and the training model on separate GPUs avoids OOM
    # conflicts. Each value is a comma-separated CUDA device list
    # (e.g. "0", or "0,1" for tensor parallel); None falls back to
    # auto-assignment.
    vllm_gpu: Optional[str] = None
    training_gpu: Optional[str] = None

    # --- Timeouts ---
    startup_timeout: int = 180           # seconds; vLLM model load is slow
    health_check_interval: float = 2.0   # seconds between health polls
    shutdown_timeout: int = 10           # seconds before SIGTERM escalates to SIGKILL

    # --- Logging ---
    log_dir: str = "./logs/services"     # per-service stdout/stderr capture

    # --- Skip flags (for testing, or when a service already runs) ---
    skip_atropos: bool = False
    skip_vllm: bool = False
92
+
93
+
94
@dataclass
class ManagedProcess:
    """Bookkeeping record for one supervised subprocess."""

    name: str
    process: Optional[subprocess.Popen] = None
    status: ServiceStatus = ServiceStatus.STOPPED
    log_file: Optional[Path] = None
    log_handle: Optional[IO] = None
    health_url: Optional[str] = None

    @property
    def pid(self) -> Optional[int]:
        """PID of the underlying process, or None when nothing was spawned."""
        return None if self.process is None else self.process.pid

    def close_log(self) -> None:
        """Close the captured-output file handle, if one is open."""
        handle = self.log_handle
        if handle:
            handle.close()
            self.log_handle = None
112
+
113
+
114
+ class ServiceManager:
115
+ """
116
+ Manages background services for local training.
117
+
118
+ Provides automatic startup, health checking, and cleanup of:
119
+ - Atropos API server (for GRPO batch distribution)
120
+ - vLLM inference server (for model rollouts)
121
+ """
122
+
123
def __init__(self, config: ServiceConfig):
    """Prepare the manager without starting anything.

    Creates the service log directory and records the currently
    installed SIGINT/SIGTERM handlers so __exit__ can restore them
    after our own handlers (installed in __enter__) are removed.
    """
    self.config = config
    self._processes: dict[str, ManagedProcess] = {}
    self._shutdown_requested = False

    # All service stdout/stderr gets captured under this directory.
    self._log_dir = Path(config.log_dir)
    self._log_dir.mkdir(parents=True, exist_ok=True)

    # Saved for restoration on context-manager exit.
    self._original_sigint = signal.getsignal(signal.SIGINT)
    self._original_sigterm = signal.getsignal(signal.SIGTERM)
135
+
136
def __enter__(self) -> "ServiceManager":
    """Context entry: install shutdown handlers, then start all services."""
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, self._signal_handler)
    self.start_all()
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Context exit: stop services, then restore the saved signal handlers."""
    try:
        self.stop_all()
    finally:
        # Restore no matter what stop_all() did.
        signal.signal(signal.SIGINT, self._original_sigint)
        signal.signal(signal.SIGTERM, self._original_sigterm)
151
+
152
def _signal_handler(self, signum: int, frame) -> None:
    """First signal starts a graceful shutdown; a second one force-exits."""
    if self._shutdown_requested:
        # Operator pressed Ctrl-C twice: stop waiting, bail immediately.
        logger.warning("Forced shutdown requested")
        sys.exit(1)

    logger.info(f"Received signal {signum}, initiating graceful shutdown...")
    self._shutdown_requested = True
    self.stop_all()
    sys.exit(0)
163
+
164
def start_all(self) -> bool:
    """Launch every service that is not configured to be skipped.

    vLLM is only attempted if Atropos came up (or was skipped).

    Returns:
        True when every attempted launch succeeded.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("STARTING TRAINING SERVICES")
    logger.info(banner)

    success = True

    # Atropos API first — vLLM depends on it being reachable.
    if self.config.skip_atropos:
        logger.info("Skipping Atropos API (configured to skip)")
    elif not self._start_atropos():
        logger.error("Failed to start Atropos API server")
        success = False

    # vLLM second; silently omitted when an earlier launch failed.
    if self.config.skip_vllm:
        logger.info("Skipping vLLM server (configured to skip)")
    elif success and not self._start_vllm():
        logger.error("Failed to start vLLM server")
        success = False

    return success
190
+
191
def stop_all(self) -> None:
    """Shut down every managed service, most recently started first."""
    logger.info("Stopping all services...")

    # Reverse of start order: vLLM goes down before Atropos.
    for service_name in reversed(list(self._processes)):
        self._stop_process(service_name)

    logger.info("All services stopped")
200
+
201
def wait_for_ready(self, timeout: Optional[int] = None) -> bool:
    """Block until every non-skipped service answers its health check.

    Args:
        timeout: Seconds to wait; defaults to config.startup_timeout.

    Returns:
        True once all services are healthy. False on timeout, when a
        shutdown was requested mid-wait, or when a service process
        dies before becoming healthy.
    """
    timeout = timeout or self.config.startup_timeout
    started_at = time.time()

    pending = []
    if not self.config.skip_atropos:
        pending.append("atropos")
    if not self.config.skip_vllm:
        pending.append("vllm")

    if not pending:
        logger.info("No services to wait for")
        return True

    logger.info(f"Waiting for services to be ready (timeout: {timeout}s)...")

    ready = dict.fromkeys(pending, False)

    while time.time() - started_at < timeout:
        if self._shutdown_requested:
            return False

        all_ready = True
        for name in pending:
            if ready[name]:
                continue  # already confirmed healthy; don't re-probe

            if self._check_health(name):
                ready[name] = True
                logger.info(f" ✓ {name} is ready")
            else:
                all_ready = False

                # A dead process will never become healthy — fail fast.
                proc = self._processes.get(name)
                if proc and proc.process and proc.process.poll() is not None:
                    logger.error(f" ✗ {name} process died (exit code: {proc.process.returncode})")
                    return False

        if all_ready:
            elapsed = time.time() - started_at
            logger.info(f"All services ready in {elapsed:.1f}s")
            return True

        time.sleep(self.config.health_check_interval)

    # Timed out — name every service that never came up.
    for name, is_ready in ready.items():
        if not is_ready:
            logger.error(f" ✗ {name} failed to become ready")

    return False
258
+
259
+ def is_healthy(self, service: str) -> bool:
260
+ """Check if a specific service is healthy"""
261
+ return self._check_health(service)
262
+
263
+ def get_status(self, service: str) -> ServiceStatus:
264
+ """Get the status of a specific service"""
265
+ proc = self._processes.get(service)
266
+ if not proc:
267
+ return ServiceStatus.STOPPED
268
+ return proc.status
269
+
270
+ def get_atropos_url(self) -> str:
271
+ """Get the Atropos API URL"""
272
+ return f"http://{self.config.atropos_host}:{self.config.atropos_port}"
273
+
274
+ def get_vllm_url(self) -> str:
275
+ """Get the vLLM server URL"""
276
+ return f"http://{self.config.vllm_host}:{self.config.vllm_port}"
277
+
278
def _start_atropos(self) -> bool:
    """Start the Atropos API server.

    If something already listens on the configured port, an existing
    Atropos instance is assumed and adopted: it is registered as
    RUNNING with no process handle, so stop_all() will not touch it.
    Otherwise the `run-api` CLI is spawned with combined stdout/stderr
    redirected to logs/services/atropos.log.

    Returns:
        True when a server was started or adopted.

    Raises:
        Exception: re-raised from subprocess.Popen when the spawn fails
            (e.g. `run-api` missing from PATH).
    """
    host, port = self.config.atropos_host, self.config.atropos_port
    # Atropos doesn't have /health endpoint, use / which returns 200
    health_url = f"http://{host}:{port}/"

    logger.info(f"Starting Atropos API server on port {port}...")

    if self._port_in_use(host, port):
        logger.warning(f"Port {port} already in use, assuming Atropos is running")
        # No process handle: we did not start it, so we won't stop it.
        self._processes["atropos"] = ManagedProcess(
            name="atropos", status=ServiceStatus.RUNNING, health_url=health_url
        )
        return True

    log_file = self._log_dir / "atropos.log"
    log_handle = open(log_file, "w")

    try:
        process = subprocess.Popen(
            ["run-api", "--port", str(port)],
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=os.environ.copy(),
        )
    except Exception as e:
        # Close the log handle on spawn failure so the fd doesn't leak.
        log_handle.close()
        logger.error(f"Failed to start Atropos: {e}")
        raise

    self._processes["atropos"] = ManagedProcess(
        name="atropos",
        process=process,
        status=ServiceStatus.STARTING,  # _check_health() promotes to RUNNING
        log_file=log_file,
        log_handle=log_handle,
        health_url=health_url,
    )

    logger.info(f" Atropos started with PID {process.pid}, logs: {log_file}")
    return True
319
+
320
def _start_vllm(self) -> bool:
    """Start the vLLM OpenAI-compatible inference server.

    If something already listens on the configured port, an existing
    vLLM instance is assumed and adopted (registered RUNNING with no
    process handle). Otherwise the server is spawned via
    `python -m vllm.entrypoints.openai.api_server` with combined
    stdout/stderr redirected to logs/services/vllm.log.

    GPU assignment precedence for CUDA_VISIBLE_DEVICES:
      1. cfg.vllm_gpu (explicit profile setting),
      2. auto "0,1,...,N-1" when tensor_parallel_size > 1,
      3. otherwise default to "0" — unless the caller's environment
         already pins CUDA_VISIBLE_DEVICES, which is respected.

    Returns:
        True when a server was started or adopted.

    Raises:
        Exception: re-raised from subprocess.Popen when the spawn fails.
    """
    host, port = self.config.vllm_host, self.config.vllm_port
    health_url = f"http://{host}:{port}/health"
    cfg = self.config

    logger.info(f"Starting vLLM server on port {port}...")
    logger.info(f"  Model: {cfg.model_name}")
    logger.info(f"  GPU Memory: {cfg.vllm_gpu_memory_utilization * 100:.0f}%")
    if cfg.tensor_parallel_size > 1:
        logger.info(f"  Tensor Parallel: {cfg.tensor_parallel_size} GPUs")
    if cfg.use_flash_attention:
        logger.info("  Flash Attention: enabled")

    if self._port_in_use(host, port):
        logger.warning(f"Port {port} already in use, assuming vLLM is running")
        self._processes["vllm"] = ManagedProcess(
            name="vllm", status=ServiceStatus.RUNNING, health_url=health_url
        )
        return True

    log_file = self._log_dir / "vllm.log"
    log_handle = open(log_file, "w")

    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", cfg.model_name,
        "--port", str(port),
        "--dtype", cfg.vllm_dtype,
        "--gpu-memory-utilization", str(cfg.vllm_gpu_memory_utilization),
        "--max-model-len", str(cfg.vllm_max_model_len),
        "--disable-log-requests",
        "--served-model-name", cfg.model_name,
    ]

    # Multi-GPU tensor parallelism (Phase 4)
    if cfg.tensor_parallel_size > 1:
        cmd.extend(["--tensor-parallel-size", str(cfg.tensor_parallel_size)])

    env = os.environ.copy()

    # Set attention backend if flash attention is configured
    if cfg.use_flash_attention:
        env["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

    # Set CUDA devices for vLLM based on explicit configuration or tensor parallel size
    if cfg.vllm_gpu:
        # Explicit GPU assignment from profile
        env["CUDA_VISIBLE_DEVICES"] = cfg.vllm_gpu
        logger.info(f"  vLLM GPUs (explicit): {cfg.vllm_gpu}")
    elif cfg.tensor_parallel_size > 1:
        # Auto-assign GPUs for tensor parallelism
        gpu_ids = ",".join(str(i) for i in range(cfg.tensor_parallel_size))
        env["CUDA_VISIBLE_DEVICES"] = gpu_ids
        logger.info(f"  vLLM GPUs (auto tensor parallel): {gpu_ids}")
    else:
        env.setdefault("CUDA_VISIBLE_DEVICES", "0")
        # BUGFIX: report the device list actually in effect. setdefault()
        # keeps a pre-existing CUDA_VISIBLE_DEVICES from the environment,
        # but the old message unconditionally claimed GPU 0.
        logger.info(f"  vLLM GPU (default): {env['CUDA_VISIBLE_DEVICES']}")

    try:
        process = subprocess.Popen(cmd, stdout=log_handle, stderr=subprocess.STDOUT, env=env)
    except Exception as e:
        # Don't leak the log handle when the spawn itself fails.
        log_handle.close()
        logger.error(f"Failed to start vLLM: {e}")
        raise

    self._processes["vllm"] = ManagedProcess(
        name="vllm",
        process=process,
        status=ServiceStatus.STARTING,  # _check_health() promotes to RUNNING
        log_file=log_file,
        log_handle=log_handle,
        health_url=health_url,
    )

    logger.info(f" vLLM started with PID {process.pid}, logs: {log_file}")
    return True
397
+
398
def _stop_process(self, name: str) -> None:
    """Terminate one managed process: SIGTERM, bounded wait, then SIGKILL."""
    proc = self._processes.get(name)
    if proc is None:
        return

    # Release our side of the log file regardless of process state; the
    # child keeps writing through its own duplicated file descriptor.
    proc.close_log()

    child = proc.process
    if child is None or child.poll() is not None:
        # Adopted (never spawned by us) or already exited.
        proc.status = ServiceStatus.STOPPED
        return

    logger.info(f"Stopping {name} (PID: {proc.pid})...")
    proc.status = ServiceStatus.STOPPING
    child.terminate()

    # Give it config.shutdown_timeout seconds to exit on its own.
    deadline = time.time() + self.config.shutdown_timeout
    while time.time() < deadline and child.poll() is None:
        time.sleep(0.5)

    if child.poll() is None:
        logger.warning(f"  {name} did not stop gracefully, sending SIGKILL")
        child.kill()
        child.wait()  # reap to avoid a zombie
    else:
        logger.info(f"  {name} stopped gracefully")

    proc.status = ServiceStatus.STOPPED
428
+
429
def _check_health(self, name: str) -> bool:
    """Probe *name*'s health URL; True (and status→RUNNING) on HTTP 200.

    Any transport-level failure is treated as "not healthy yet" so
    callers can keep polling.
    """
    proc = self._processes.get(name)
    if not proc or not proc.health_url:
        return False

    try:
        response = requests.get(proc.health_url, timeout=5)
        if response.status_code == 200:
            proc.status = ServiceStatus.RUNNING
            return True
    except requests.exceptions.RequestException:
        # BUGFIX: previously only ConnectionError and Timeout were caught.
        # A server mid-startup can also reset the connection or send a
        # malformed response (e.g. ChunkedEncodingError), which escaped
        # the handler and crashed wait_for_ready(). RequestException is
        # the common base of all of these.
        pass

    return False
446
+
447
+ def _port_in_use(self, host: str, port: int) -> bool:
448
+ """
449
+ Check if a port is already in use.
450
+
451
+ Note: There is an inherent TOCTOU (time-of-check to time-of-use) race condition
452
+ between this check and actually starting the process. If another process grabs
453
+ the port between check and Popen, startup will fail. This is acceptable for our
454
+ use case since we primarily use this to detect already-running services.
455
+ """
456
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
457
+ sock.settimeout(1)
458
+ return sock.connect_ex((host, port)) == 0
459
+
460
def restart_vllm(self, model_path: Optional[str] = None) -> bool:
    """Restart vLLM, optionally pointing it at updated model weights.

    Used during training to sync new weights to the inference server:
    stop the old server, free CUDA memory, start a fresh one, and poll
    its health check for up to config.startup_timeout seconds.

    Args:
        model_path: When given, becomes the new config.model_name that
            the restarted server loads.

    Returns:
        True once the new server reports healthy; False when it fails
        to start, dies while starting, or the wait times out.
    """
    logger.info("Restarting vLLM server...")

    # Stop existing vLLM
    self._stop_process("vllm")

    # Clear CUDA cache so the new server can reclaim GPU memory
    self._clear_cuda_cache()

    # Update model path if provided
    if model_path:
        self.config.model_name = model_path

    # Start new vLLM
    if not self._start_vllm():
        return False

    # Wait for it to be ready (mirrors wait_for_ready for one service)
    start_time = time.time()
    timeout = self.config.startup_timeout

    while time.time() - start_time < timeout:
        if self._check_health("vllm"):
            elapsed = time.time() - start_time
            logger.info(f"vLLM restarted successfully in {elapsed:.1f}s")
            return True

        # Check if process died — it will never become healthy
        proc = self._processes.get("vllm")
        if proc and proc.process and proc.process.poll() is not None:
            logger.error(f"vLLM died during restart (exit code: {proc.process.returncode})")
            return False

        time.sleep(self.config.health_check_interval)

    logger.error("vLLM restart timed out")
    return False
502
+
503
+ def _clear_cuda_cache(self) -> None:
504
+ """Clear CUDA memory cache if available"""
505
+ try:
506
+ import torch
507
+ if torch.cuda.is_available():
508
+ torch.cuda.empty_cache()
509
+ torch.cuda.synchronize()
510
+ logger.debug("CUDA cache cleared")
511
+ except ImportError:
512
+ pass
513
+
514
+
515
def check_prerequisites() -> list[str]:
    """Verify that everything local training needs is available.

    Checks, in order: the Atropos `run-api` CLI on PATH, the vLLM
    package, PyTorch with a usable CUDA device, and the DATABASE_URL
    environment variable.

    Returns:
        Human-readable error messages, one per missing requirement.
        An empty list means the host is ready for training.
    """
    problems: list[str] = []

    if shutil.which("run-api") is None:
        problems.append("Atropos API not found. Install with: pip install atroposlib")

    try:
        import vllm  # noqa: F401
    except ImportError:
        problems.append("vLLM not installed. Install with: pip install vllm")

    try:
        import torch
    except ImportError:
        problems.append("PyTorch not installed. Install with: pip install torch")
    else:
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"GPU detected: {gpu_name} ({gpu_mem:.1f} GB)")
        else:
            problems.append(
                "CUDA not available. GPU is required for vLLM inference. "
                "For CPU-only training, use --skip-vllm and provide external inference."
            )

    if not os.getenv("DATABASE_URL"):
        problems.append(
            "DATABASE_URL not set. Required for loading training trajectories. "
            "Set with: export DATABASE_URL=postgresql://user:pass@host:5432/dbname"
        )

    return problems
552
+