@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,552 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Service Manager for Local Training Infrastructure
|
|
3
|
+
|
|
4
|
+
Manages the lifecycle of background services required for local GRPO training:
|
|
5
|
+
- Atropos API Server: Handles batch collection and distribution
|
|
6
|
+
- vLLM Server: Provides inference during rollouts
|
|
7
|
+
|
|
8
|
+
Features:
|
|
9
|
+
- Automatic startup with health checks
|
|
10
|
+
- Graceful shutdown with kill fallback
|
|
11
|
+
- Context manager interface for automatic cleanup
|
|
12
|
+
- Configurable ports and timeouts
|
|
13
|
+
- Process output logging to files
|
|
14
|
+
|
|
15
|
+
Usage:
|
|
16
|
+
config = ServiceConfig(
|
|
17
|
+
atropos_port=8000,
|
|
18
|
+
vllm_port=9001,
|
|
19
|
+
model_name="Qwen/Qwen2.5-3B-Instruct",
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
with ServiceManager(config) as services:
|
|
23
|
+
if not services.wait_for_ready():
|
|
24
|
+
raise RuntimeError("Services failed to start")
|
|
25
|
+
# Run training...
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import logging
|
|
29
|
+
import os
|
|
30
|
+
import shutil
|
|
31
|
+
import signal
|
|
32
|
+
import socket
|
|
33
|
+
import subprocess
|
|
34
|
+
import sys
|
|
35
|
+
import time
|
|
36
|
+
from dataclasses import dataclass
|
|
37
|
+
from enum import Enum
|
|
38
|
+
from pathlib import Path
|
|
39
|
+
from typing import IO, Optional
|
|
40
|
+
|
|
41
|
+
import requests
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ServiceStatus(Enum):
    """Lifecycle status of a managed service process."""

    STOPPED = "stopped"    # not running: never started, or shut down
    STARTING = "starting"  # process launched, health check not yet passing
    RUNNING = "running"    # health endpoint answered (or port was already in use)
    FAILED = "failed"      # NOTE(review): never assigned in this module; kept for API compatibility
    STOPPING = "stopping"  # terminate requested, waiting for the process to exit
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
class ServiceConfig:
    """Configuration for managed services (Atropos API server and vLLM)."""

    # Atropos API settings
    atropos_port: int = 8000
    atropos_host: str = "localhost"

    # vLLM settings
    vllm_port: int = 9001
    vllm_host: str = "localhost"
    model_name: str = "Qwen/Qwen2.5-3B-Instruct"  # HF model id or local checkpoint path
    vllm_gpu_memory_utilization: float = 0.85  # fraction of GPU memory vLLM may claim
    vllm_dtype: str = "auto"
    vllm_max_model_len: int = 4096  # max context length served by vLLM

    # Multi-GPU settings (Phase 4)
    tensor_parallel_size: int = 1  # Number of GPUs for tensor parallelism
    use_flash_attention: bool = False  # Enable flash attention for performance

    # GPU assignment - separate vLLM and training to avoid OOM conflicts
    # vllm_gpu: Comma-separated GPU IDs for vLLM (e.g., "0" or "0,1" for tensor parallel)
    # training_gpu: GPU ID for training model (e.g., "1" for dedicated training GPU)
    vllm_gpu: Optional[str] = None  # If None, falls back to auto-assignment
    training_gpu: Optional[str] = None  # If None, falls back to auto-assignment

    # Timeouts (seconds)
    startup_timeout: int = 180  # 3 minutes for vLLM to load model
    health_check_interval: float = 2.0  # polling interval while waiting for readiness
    shutdown_timeout: int = 10  # grace period after SIGTERM before SIGKILL

    # Logging
    log_dir: str = "./logs/services"  # directory for per-service stdout/stderr log files

    # Skip services (for testing or when already running)
    skip_atropos: bool = False
    skip_vllm: bool = False
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class ManagedProcess:
    """A managed subprocess together with its bookkeeping metadata.

    Bundles the Popen handle with its current lifecycle status, the log
    file it writes to, the open handle for that log, and the HTTP URL
    used for health checks.
    """

    name: str
    process: Optional[subprocess.Popen] = None
    status: ServiceStatus = ServiceStatus.STOPPED
    log_file: Optional[Path] = None
    log_handle: Optional[IO] = None
    health_url: Optional[str] = None

    @property
    def pid(self) -> Optional[int]:
        """PID of the underlying process, or None if never started."""
        if self.process is None:
            return None
        return self.process.pid

    def close_log(self) -> None:
        """Close the log file handle (if open) and drop the reference."""
        handle = self.log_handle
        if handle is not None:
            handle.close()
            self.log_handle = None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class ServiceManager:
    """
    Manages background services for local training.

    Provides automatic startup, health checking, and cleanup of:
    - Atropos API server (for GRPO batch distribution)
    - vLLM inference server (for model rollouts)

    Designed to be used as a context manager: ``__enter__`` installs
    SIGINT/SIGTERM handlers and starts the services, ``__exit__`` stops
    everything and restores the original handlers.
    """

    def __init__(self, config: ServiceConfig):
        self.config = config
        # Service registry keyed by name ("atropos", "vllm"). Insertion
        # order matters: stop_all() shuts services down in reverse order.
        self._processes: dict[str, ManagedProcess] = {}
        self._shutdown_requested = False

        # Create log directory
        self._log_dir = Path(config.log_dir)
        self._log_dir.mkdir(parents=True, exist_ok=True)

        # Save the pre-existing signal handlers so __exit__ can restore
        # them; the custom handlers are only installed in __enter__.
        self._original_sigint = signal.getsignal(signal.SIGINT)
        self._original_sigterm = signal.getsignal(signal.SIGTERM)

    def __enter__(self) -> "ServiceManager":
        """Context manager entry - install signal handlers and start all services"""
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        self.start_all()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - stop all services"""
        try:
            self.stop_all()
        finally:
            # Always restore original signal handlers
            signal.signal(signal.SIGINT, self._original_sigint)
            signal.signal(signal.SIGTERM, self._original_sigterm)

    def _signal_handler(self, signum: int, frame) -> None:
        """Handle shutdown signals gracefully.

        The first signal triggers stop_all() followed by sys.exit(0); a
        second signal while shutdown is in progress forces sys.exit(1).
        """
        if self._shutdown_requested:
            # Force exit on second signal
            logger.warning("Forced shutdown requested")
            sys.exit(1)

        logger.info(f"Received signal {signum}, initiating graceful shutdown...")
        self._shutdown_requested = True
        self.stop_all()
        sys.exit(0)

    def start_all(self) -> bool:
        """Start all configured services.

        Returns True if every non-skipped service launched (or its port was
        already in use). vLLM is only attempted if Atropos succeeded.
        """
        logger.info("=" * 60)
        logger.info("STARTING TRAINING SERVICES")
        logger.info("=" * 60)

        success = True

        # Start Atropos API
        if not self.config.skip_atropos:
            if not self._start_atropos():
                logger.error("Failed to start Atropos API server")
                success = False
        else:
            logger.info("Skipping Atropos API (configured to skip)")

        # Start vLLM (skipped if Atropos already failed)
        if not self.config.skip_vllm and success:
            if not self._start_vllm():
                logger.error("Failed to start vLLM server")
                success = False
        else:
            if self.config.skip_vllm:
                logger.info("Skipping vLLM server (configured to skip)")

        return success

    def stop_all(self) -> None:
        """Stop all managed services gracefully"""
        logger.info("Stopping all services...")

        # Stop in reverse order (vLLM first, then Atropos)
        for name in reversed(list(self._processes.keys())):
            self._stop_process(name)

        logger.info("All services stopped")

    def wait_for_ready(self, timeout: Optional[int] = None) -> bool:
        """
        Wait for all services to be healthy.

        Polls each non-skipped service's health endpoint every
        ``health_check_interval`` seconds until all respond, a process
        dies, shutdown is requested, or ``timeout`` (default:
        ``config.startup_timeout``) elapses.

        Returns True if all services are ready, False on timeout or failure.
        """
        timeout = timeout or self.config.startup_timeout
        start_time = time.time()

        services_to_check = []
        if not self.config.skip_atropos:
            services_to_check.append("atropos")
        if not self.config.skip_vllm:
            services_to_check.append("vllm")

        if not services_to_check:
            logger.info("No services to wait for")
            return True

        logger.info(f"Waiting for services to be ready (timeout: {timeout}s)...")

        ready = {name: False for name in services_to_check}

        while time.time() - start_time < timeout:
            if self._shutdown_requested:
                return False

            all_ready = True
            for name in services_to_check:
                if ready[name]:
                    continue

                if self._check_health(name):
                    ready[name] = True
                    logger.info(f" ✓ {name} is ready")
                else:
                    all_ready = False

                    # Check if process died
                    proc = self._processes.get(name)
                    if proc and proc.process and proc.process.poll() is not None:
                        logger.error(f" ✗ {name} process died (exit code: {proc.process.returncode})")
                        return False

            if all_ready:
                elapsed = time.time() - start_time
                logger.info(f"All services ready in {elapsed:.1f}s")
                return True

            time.sleep(self.config.health_check_interval)

        # Timeout - report which services failed
        for name, is_ready in ready.items():
            if not is_ready:
                logger.error(f" ✗ {name} failed to become ready")

        return False

    def is_healthy(self, service: str) -> bool:
        """Check if a specific service is healthy"""
        return self._check_health(service)

    def get_status(self, service: str) -> ServiceStatus:
        """Get the status of a specific service (STOPPED if the name is unknown)"""
        proc = self._processes.get(service)
        if not proc:
            return ServiceStatus.STOPPED
        return proc.status

    def get_atropos_url(self) -> str:
        """Get the Atropos API URL"""
        return f"http://{self.config.atropos_host}:{self.config.atropos_port}"

    def get_vllm_url(self) -> str:
        """Get the vLLM server URL"""
        return f"http://{self.config.vllm_host}:{self.config.vllm_port}"

    def _start_atropos(self) -> bool:
        """Start the Atropos API server.

        If the configured port is already in use, assumes an external
        Atropos instance is running and registers it without spawning.

        NOTE(review): declared to return bool, but a Popen failure raises
        instead of returning False - callers see an exception in that case.
        """
        host, port = self.config.atropos_host, self.config.atropos_port
        # Atropos doesn't have /health endpoint, use / which returns 200
        health_url = f"http://{host}:{port}/"

        logger.info(f"Starting Atropos API server on port {port}...")

        if self._port_in_use(host, port):
            logger.warning(f"Port {port} already in use, assuming Atropos is running")
            self._processes["atropos"] = ManagedProcess(
                name="atropos", status=ServiceStatus.RUNNING, health_url=health_url
            )
            return True

        log_file = self._log_dir / "atropos.log"
        # Handle stays open for the child's lifetime; closed in _stop_process.
        log_handle = open(log_file, "w")

        try:
            process = subprocess.Popen(
                ["run-api", "--port", str(port)],
                stdout=log_handle,
                stderr=subprocess.STDOUT,
                env=os.environ.copy(),
            )
        except Exception as e:
            log_handle.close()
            logger.error(f"Failed to start Atropos: {e}")
            raise

        self._processes["atropos"] = ManagedProcess(
            name="atropos",
            process=process,
            status=ServiceStatus.STARTING,
            log_file=log_file,
            log_handle=log_handle,
            health_url=health_url,
        )

        logger.info(f" Atropos started with PID {process.pid}, logs: {log_file}")
        return True

    def _start_vllm(self) -> bool:
        """Start the vLLM inference server (OpenAI-compatible API).

        If the configured port is already in use, assumes an external vLLM
        instance is running and registers it without spawning.
        """
        host, port = self.config.vllm_host, self.config.vllm_port
        health_url = f"http://{host}:{port}/health"
        cfg = self.config

        logger.info(f"Starting vLLM server on port {port}...")
        logger.info(f" Model: {cfg.model_name}")
        logger.info(f" GPU Memory: {cfg.vllm_gpu_memory_utilization * 100:.0f}%")
        if cfg.tensor_parallel_size > 1:
            logger.info(f" Tensor Parallel: {cfg.tensor_parallel_size} GPUs")
        if cfg.use_flash_attention:
            logger.info(" Flash Attention: enabled")

        if self._port_in_use(host, port):
            logger.warning(f"Port {port} already in use, assuming vLLM is running")
            self._processes["vllm"] = ManagedProcess(
                name="vllm", status=ServiceStatus.RUNNING, health_url=health_url
            )
            return True

        log_file = self._log_dir / "vllm.log"
        # Handle stays open for the child's lifetime; closed in _stop_process.
        log_handle = open(log_file, "w")

        cmd = [
            sys.executable, "-m", "vllm.entrypoints.openai.api_server",
            "--model", cfg.model_name,
            "--port", str(port),
            "--dtype", cfg.vllm_dtype,
            "--gpu-memory-utilization", str(cfg.vllm_gpu_memory_utilization),
            "--max-model-len", str(cfg.vllm_max_model_len),
            "--disable-log-requests",
            "--served-model-name", cfg.model_name,
        ]

        # Multi-GPU tensor parallelism (Phase 4)
        if cfg.tensor_parallel_size > 1:
            cmd.extend(["--tensor-parallel-size", str(cfg.tensor_parallel_size)])

        env = os.environ.copy()

        # Set attention backend if flash attention is configured
        if cfg.use_flash_attention:
            env["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

        # Set CUDA devices for vLLM based on explicit configuration or tensor parallel size
        if cfg.vllm_gpu:
            # Explicit GPU assignment from profile
            env["CUDA_VISIBLE_DEVICES"] = cfg.vllm_gpu
            logger.info(f" vLLM GPUs (explicit): {cfg.vllm_gpu}")
        elif cfg.tensor_parallel_size > 1:
            # Auto-assign GPUs for tensor parallelism
            gpu_ids = ",".join(str(i) for i in range(cfg.tensor_parallel_size))
            env["CUDA_VISIBLE_DEVICES"] = gpu_ids
            logger.info(f" vLLM GPUs (auto tensor parallel): {gpu_ids}")
        else:
            # setdefault: respect a CUDA_VISIBLE_DEVICES already set by the caller
            env.setdefault("CUDA_VISIBLE_DEVICES", "0")
            logger.info(" vLLM GPU (default): 0")

        try:
            process = subprocess.Popen(cmd, stdout=log_handle, stderr=subprocess.STDOUT, env=env)
        except Exception as e:
            log_handle.close()
            logger.error(f"Failed to start vLLM: {e}")
            raise

        self._processes["vllm"] = ManagedProcess(
            name="vllm",
            process=process,
            status=ServiceStatus.STARTING,
            log_file=log_file,
            log_handle=log_handle,
            health_url=health_url,
        )

        logger.info(f" vLLM started with PID {process.pid}, logs: {log_file}")
        return True

    def _stop_process(self, name: str) -> None:
        """Stop a specific process gracefully.

        Sends SIGTERM, polls for up to ``shutdown_timeout`` seconds, then
        escalates to SIGKILL. No-op for unknown or already-exited services.
        """
        proc = self._processes.get(name)
        if not proc:
            return

        # Close our copy of the log handle first (the child keeps its own
        # file descriptor until it exits).
        proc.close_log()

        if not proc.process or proc.process.poll() is not None:
            proc.status = ServiceStatus.STOPPED
            return

        logger.info(f"Stopping {name} (PID: {proc.pid})...")
        proc.status = ServiceStatus.STOPPING
        proc.process.terminate()

        # Wait for graceful shutdown
        deadline = time.time() + self.config.shutdown_timeout
        while time.time() < deadline and proc.process.poll() is None:
            time.sleep(0.5)

        if proc.process.poll() is None:
            logger.warning(f" {name} did not stop gracefully, sending SIGKILL")
            proc.process.kill()
            proc.process.wait()
        else:
            logger.info(f" {name} stopped gracefully")

        proc.status = ServiceStatus.STOPPED

    def _check_health(self, name: str) -> bool:
        """Check if a service is healthy via its health endpoint.

        A 200 response promotes the service to RUNNING. Connection errors
        and timeouts are expected while a service is still starting, so
        they simply yield False.
        """
        proc = self._processes.get(name)
        if not proc or not proc.health_url:
            return False

        try:
            response = requests.get(proc.health_url, timeout=5)
            if response.status_code == 200:
                proc.status = ServiceStatus.RUNNING
                return True
        except requests.exceptions.ConnectionError:
            pass
        except requests.exceptions.Timeout:
            pass

        return False

    def _port_in_use(self, host: str, port: int) -> bool:
        """
        Check if a port is already in use (by attempting a TCP connect).

        Note: There is an inherent TOCTOU (time-of-check to time-of-use) race condition
        between this check and actually starting the process. If another process grabs
        the port between check and Popen, startup will fail. This is acceptable for our
        use case since we primarily use this to detect already-running services.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            return sock.connect_ex((host, port)) == 0

    def restart_vllm(self, model_path: Optional[str] = None) -> bool:
        """
        Restart vLLM with optionally updated model weights.

        Used during training to sync new weights to the inference server.
        If ``model_path`` is given it replaces ``config.model_name`` for
        this and all subsequent restarts.
        """
        logger.info("Restarting vLLM server...")

        # Stop existing vLLM
        self._stop_process("vllm")

        # Clear CUDA cache
        self._clear_cuda_cache()

        # Update model path if provided
        if model_path:
            self.config.model_name = model_path

        # Start new vLLM
        if not self._start_vllm():
            return False

        # Wait for it to be ready
        start_time = time.time()
        timeout = self.config.startup_timeout

        while time.time() - start_time < timeout:
            if self._check_health("vllm"):
                elapsed = time.time() - start_time
                logger.info(f"vLLM restarted successfully in {elapsed:.1f}s")
                return True

            # Check if process died
            proc = self._processes.get("vllm")
            if proc and proc.process and proc.process.poll() is not None:
                logger.error(f"vLLM died during restart (exit code: {proc.process.returncode})")
                return False

            time.sleep(self.config.health_check_interval)

        logger.error("vLLM restart timed out")
        return False

    def _clear_cuda_cache(self) -> None:
        """Clear CUDA memory cache if torch is importable; no-op otherwise"""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.debug("CUDA cache cleared")
        except ImportError:
            pass
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def check_prerequisites() -> list[str]:
    """
    Verify that the local-training toolchain is available.

    Checks for the Atropos ``run-api`` console script, the vLLM and torch
    packages, a CUDA-capable GPU, and the DATABASE_URL environment
    variable. Returns a list of human-readable error messages; an empty
    list means everything required was found.
    """
    problems: list[str] = []

    # Installing atroposlib puts a `run-api` console script on PATH.
    if shutil.which("run-api") is None:
        problems.append("Atropos API not found. Install with: pip install atroposlib")

    try:
        import vllm  # noqa: F401
    except ImportError:
        problems.append("vLLM not installed. Install with: pip install vllm")

    try:
        import torch

        if not torch.cuda.is_available():
            problems.append(
                "CUDA not available. GPU is required for vLLM inference. "
                "For CPU-only training, use --skip-vllm and provide external inference."
            )
        else:
            device = torch.cuda.get_device_name(0)
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"GPU detected: {device} ({total_gb:.1f} GB)")
    except ImportError:
        problems.append("PyTorch not installed. Install with: pip install torch")

    if not os.getenv("DATABASE_URL"):
        problems.append(
            "DATABASE_URL not set. Required for loading training trajectories. "
            "Set with: export DATABASE_URL=postgresql://user:pass@host:5432/dbname"
        )

    return problems
|
|
552
|
+
|