@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,585 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for ServiceManager - Process lifecycle, health checks, port detection.
|
|
3
|
+
|
|
4
|
+
Tests cover:
|
|
5
|
+
- ServiceConfig validation and defaults
|
|
6
|
+
- ManagedProcess state management
|
|
7
|
+
- Port-in-use detection
|
|
8
|
+
- Process start/stop lifecycle
|
|
9
|
+
- Health check behavior
|
|
10
|
+
- Resource cleanup (file handles)
|
|
11
|
+
- Signal handling
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import socket
|
|
15
|
+
import subprocess
|
|
16
|
+
import sys
|
|
17
|
+
import tempfile
|
|
18
|
+
import threading
|
|
19
|
+
import time
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from unittest.mock import MagicMock, patch
|
|
22
|
+
|
|
23
|
+
import pytest
|
|
24
|
+
|
|
25
|
+
# Add src to path
|
|
26
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
27
|
+
|
|
28
|
+
from src.training.service_manager import (
|
|
29
|
+
ManagedProcess,
|
|
30
|
+
ServiceConfig,
|
|
31
|
+
ServiceManager,
|
|
32
|
+
ServiceStatus,
|
|
33
|
+
check_prerequisites,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TestServiceConfig:
    """Unit tests covering construction of the ServiceConfig dataclass."""

    def test_default_values(self):
        """A bare ServiceConfig() must expose the expected default settings."""
        cfg = ServiceConfig()

        expected_defaults = {
            "atropos_port": 8000,
            "atropos_host": "localhost",
            "vllm_port": 9001,
            "vllm_host": "localhost",
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "vllm_gpu_memory_utilization": 0.85,
            "vllm_dtype": "auto",
            "vllm_max_model_len": 4096,
            "startup_timeout": 180,
            "health_check_interval": 2.0,
            "shutdown_timeout": 10,
            "log_dir": "./logs/services",
        }
        for field_name, wanted in expected_defaults.items():
            assert getattr(cfg, field_name) == wanted, field_name

        # Flags are compared by identity so a falsy non-bool would not pass.
        assert cfg.skip_atropos is False
        assert cfg.skip_vllm is False

    def test_custom_values(self):
        """Keyword arguments passed to ServiceConfig replace the defaults."""
        overrides = {
            "atropos_port": 9000,
            "vllm_port": 8080,
            "model_name": "custom/model",
            "vllm_gpu_memory_utilization": 0.5,
            "startup_timeout": 60,
        }
        cfg = ServiceConfig(**overrides, skip_atropos=True, skip_vllm=True)

        for field_name, wanted in overrides.items():
            assert getattr(cfg, field_name) == wanted, field_name
        assert cfg.skip_atropos is True
        assert cfg.skip_vllm is True

    def test_gpu_memory_boundary_values(self):
        """vllm_gpu_memory_utilization accepts values at and near its bounds."""
        # 0.0 and 1.0 are the valid boundaries; 0.01 sits just above zero.
        for fraction in (0.0, 1.0, 0.01):
            cfg = ServiceConfig(vllm_gpu_memory_utilization=fraction)
            assert cfg.vllm_gpu_memory_utilization == fraction
class TestManagedProcess:
    """Tests for the ManagedProcess dataclass."""

    def test_default_state(self):
        """A ManagedProcess built from just a name has no process, log or URL."""
        proc = ManagedProcess(name="test")

        assert proc.name == "test"
        assert proc.process is None
        assert proc.status == ServiceStatus.STOPPED
        assert proc.log_file is None
        assert proc.log_handle is None
        assert proc.health_url is None
        assert proc.pid is None

    def test_pid_property_with_process(self):
        """pid is taken from the wrapped process object when one is set."""
        mock_process = MagicMock()
        mock_process.pid = 12345

        proc = ManagedProcess(name="test", process=mock_process)
        assert proc.pid == 12345

    def test_pid_property_without_process(self):
        """pid is None when no process is attached."""
        proc = ManagedProcess(name="test")
        assert proc.pid is None

    def test_close_log_with_handle(self, tmp_path):
        """close_log() closes the open log handle and clears the attribute.

        Uses pytest's ``tmp_path`` so the file is removed even if an
        assertion fails (the previous manual unlink was skipped on failure).
        """
        handle = open(tmp_path / "test.log", "w")
        try:
            proc = ManagedProcess(name="test", log_handle=handle)

            assert not handle.closed
            proc.close_log()
            assert handle.closed
            assert proc.log_handle is None
        finally:
            # Safety net if close_log() regresses; closing twice is a no-op.
            handle.close()

    def test_close_log_without_handle(self):
        """close_log() is a safe no-op when there is no handle."""
        proc = ManagedProcess(name="test")
        proc.close_log()  # Should not raise
        assert proc.log_handle is None

    def test_close_log_idempotent(self, tmp_path):
        """close_log() can be called repeatedly and leaves the log closed."""
        handle = open(tmp_path / "test.log", "w")
        try:
            proc = ManagedProcess(name="test", log_handle=handle)

            proc.close_log()
            proc.close_log()  # Second call should not raise

            # Previously nothing was asserted after the second call.
            assert handle.closed
            assert proc.log_handle is None
        finally:
            handle.close()
class TestServiceStatus:
    """Checks on the members of the ServiceStatus enum."""

    def test_all_statuses_defined(self):
        """Each expected member exists and carries its lowercase string value."""
        expected_values = {
            "STOPPED": "stopped",
            "STARTING": "starting",
            "RUNNING": "running",
            "FAILED": "failed",
            "STOPPING": "stopping",
        }
        for member_name, value in expected_values.items():
            assert ServiceStatus[member_name].value == value

    def test_status_count(self):
        """The enum defines exactly five statuses."""
        members = list(ServiceStatus)
        assert len(members) == 5
class TestServiceManagerPortDetection:
    """Tests for ServiceManager._port_in_use."""

    @staticmethod
    def _manager():
        """Build a ServiceManager with both services skipped."""
        return ServiceManager(ServiceConfig(skip_atropos=True, skip_vllm=True))

    @staticmethod
    def _free_port():
        """Return a localhost TCP port the OS just reported as free.

        The probe socket is closed before returning, so nothing is bound to
        the port afterwards (a small race with other processes is unavoidable).
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("localhost", 0))
            return probe.getsockname()[1]

    def test_port_not_in_use(self):
        """A port with no listener is reported as free."""
        manager = self._manager()
        assert manager._port_in_use("localhost", self._free_port()) is False

    def test_port_in_use(self):
        """A port with an active listener is reported as in use."""
        manager = self._manager()

        # `with` guarantees the listener is closed even if the assert fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(("localhost", 0))
            server.listen(1)
            port = server.getsockname()[1]

            assert manager._port_in_use("localhost", port) is True

    def test_port_in_use_timeout(self):
        """The check returns promptly for a port nobody is listening on."""
        manager = self._manager()
        # An OS-assigned free port instead of a hard-coded one: 59999 could
        # legitimately be in use on the host and make `result is False` flaky.
        port = self._free_port()

        # monotonic() is immune to wall-clock adjustments during the call.
        start = time.monotonic()
        result = manager._port_in_use("localhost", port)
        elapsed = time.monotonic() - start

        # NOTE(review): bound assumes _port_in_use uses a ~1 s connect
        # timeout, as the original comment stated; 2 s leaves margin.
        assert elapsed < 2.0
        assert result is False
class TestServiceManagerUrls:
    """ServiceManager turns configured hosts/ports into base URLs."""

    @staticmethod
    def _build(**config_kwargs):
        """Wrap a ServiceConfig built from *config_kwargs* in a ServiceManager."""
        return ServiceManager(ServiceConfig(**config_kwargs))

    def test_get_atropos_url_default(self):
        """Default config yields the local Atropos URL."""
        url = self._build().get_atropos_url()
        assert url == "http://localhost:8000"

    def test_get_atropos_url_custom(self):
        """A custom Atropos host and port are reflected in the URL."""
        url = self._build(atropos_host="192.168.1.1", atropos_port=9000).get_atropos_url()
        assert url == "http://192.168.1.1:9000"

    def test_get_vllm_url_default(self):
        """Default config yields the local vLLM URL."""
        url = self._build().get_vllm_url()
        assert url == "http://localhost:9001"

    def test_get_vllm_url_custom(self):
        """A custom vLLM host and port are reflected in the URL."""
        url = self._build(vllm_host="10.0.0.1", vllm_port=8080).get_vllm_url()
        assert url == "http://10.0.0.1:8080"
class TestServiceManagerStatus:
    """get_status() reports STOPPED for services that are not running."""

    def test_get_status_unknown_service(self):
        """An unregistered service name reports STOPPED rather than raising."""
        mgr = ServiceManager(ServiceConfig(skip_atropos=True, skip_vllm=True))
        status = mgr.get_status("nonexistent")
        assert status == ServiceStatus.STOPPED

    def test_get_status_skipped_service(self):
        """Services skipped by config are still STOPPED after start_all()."""
        mgr = ServiceManager(ServiceConfig(skip_atropos=True, skip_vllm=True))
        mgr.start_all()

        for service in ("atropos", "vllm"):
            assert mgr.get_status(service) == ServiceStatus.STOPPED
class TestServiceManagerSkipBehavior:
    """Tests for the skip_atropos / skip_vllm flags.

    Every test here sets BOTH skip flags, presumably so that start_all() never
    launches a real Atropos or vLLM process. As a result the ``*_only`` tests
    verify that the named service is not registered, but they do NOT show that
    the other service would still be started.
    TODO(review): mock the non-skipped service's launcher so each flag can be
    exercised on its own.
    """

    def test_skip_atropos_only(self):
        """A skipped Atropos is never registered and reports STOPPED.

        vLLM is skipped as well (see the class docstring).
        """
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        result = manager.start_all()

        assert result is True
        assert "atropos" not in manager._processes
        assert manager.get_status("atropos") == ServiceStatus.STOPPED

    def test_skip_vllm_only(self):
        """A skipped vLLM is never registered and reports STOPPED.

        Atropos is skipped as well (see the class docstring).
        """
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        result = manager.start_all()

        assert result is True
        assert "vllm" not in manager._processes
        assert manager.get_status("vllm") == ServiceStatus.STOPPED

    def test_skip_all_services(self):
        """With everything skipped, start_all() succeeds and registers nothing."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        result = manager.start_all()

        assert result is True
        assert len(manager._processes) == 0

    def test_wait_for_ready_no_services(self):
        """wait_for_ready returns True when all services are skipped."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)
        manager.start_all()

        result = manager.wait_for_ready(timeout=1)

        assert result is True
class TestServiceManagerHealthCheck:
    """Health checks on services that have nothing to probe."""

    @staticmethod
    def _idle_manager():
        """A ServiceManager configured to start no services at all."""
        return ServiceManager(ServiceConfig(skip_atropos=True, skip_vllm=True))

    def test_check_health_no_process(self):
        """An unregistered service is reported unhealthy."""
        healthy = self._idle_manager()._check_health("atropos")
        assert healthy is False

    def test_check_health_no_url(self):
        """A registered service without a health_url is reported unhealthy."""
        manager = self._idle_manager()
        manager._processes["test"] = ManagedProcess(name="test", health_url=None)

        assert manager._check_health("test") is False

    def test_is_healthy_delegates_to_check_health(self):
        """The public is_healthy() agrees with the private _check_health()."""
        manager = self._idle_manager()

        via_public = manager.is_healthy("nonexistent")
        via_private = manager._check_health("nonexistent")

        assert via_public == via_private
class TestServiceManagerContextManager:
    """ServiceManager supports the ``with`` statement."""

    def test_context_manager_start_and_stop(self):
        """Entering yields the manager; no processes remain after exiting."""
        cfg = ServiceConfig(skip_atropos=True, skip_vllm=True)

        with ServiceManager(cfg) as entered:
            # The object bound by `as` must be the ServiceManager itself.
            assert isinstance(entered, ServiceManager)

        # Both services are skipped, so this only confirms that the process
        # table is empty once the context has been left.
        remaining = len(entered._processes)
        assert remaining == 0
class TestServiceManagerLogDirectory:
    """Tests for log directory management."""

    def test_creates_log_directory(self):
        """Constructing a ServiceManager creates a missing (nested) log_dir."""
        with tempfile.TemporaryDirectory() as tmpdir:
            log_dir = Path(tmpdir) / "nested" / "logs"
            config = ServiceConfig(log_dir=str(log_dir), skip_atropos=True, skip_vllm=True)

            # Constructed only for its side effect; binding the instance to an
            # unused local (as before) triggers lint F841.
            ServiceManager(config)

            assert log_dir.exists()
            assert log_dir.is_dir()

    def test_existing_log_directory_ok(self):
        """An already existing log_dir is accepted without error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(log_dir=tmpdir, skip_atropos=True, skip_vllm=True)

            ServiceManager(config)  # Should not raise

            # is_dir() is stricter than the previous exists() check.
            assert Path(tmpdir).is_dir()
class TestCheckPrerequisites:
    """Tests for the check_prerequisites function."""

    def test_returns_list(self):
        """check_prerequisites returns a list of error messages."""
        result = check_prerequisites()
        assert isinstance(result, list)

    def test_missing_database_url(self, monkeypatch):
        """An error mentioning DATABASE_URL is reported when it is unset.

        ``monkeypatch`` restores the environment afterwards. The previous
        hand-rolled restore used ``if original:`` and so lost an originally
        empty-string value.
        """
        monkeypatch.delenv("DATABASE_URL", raising=False)

        errors = check_prerequisites()

        # Should have at least the DATABASE_URL error
        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) >= 1

    def test_with_database_url_set(self, monkeypatch):
        """No DATABASE_URL error is reported when the variable is set.

        Previously, an originally empty value was deleted rather than restored.
        """
        monkeypatch.setenv("DATABASE_URL", "postgresql://test:test@localhost/test")

        errors = check_prerequisites()

        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) == 0
class TestServiceManagerStopProcess:
    """Tests for process stopping behavior."""

    def test_stop_process_nonexistent(self):
        """Stopping an unregistered service name is a safe no-op."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        # Should not raise
        manager._stop_process("nonexistent")

    def test_stop_process_no_subprocess(self):
        """Stopping an entry with no subprocess marks it STOPPED."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)
        manager._processes["test"] = ManagedProcess(name="test", process=None)

        # Should not raise
        manager._stop_process("test")
        assert manager._processes["test"].status == ServiceStatus.STOPPED

    def test_stop_process_closes_log_handle(self, tmp_path):
        """_stop_process closes the entry's log handle.

        ``tmp_path`` is cleaned up by pytest, so the file no longer leaks
        when the assertion fails (the manual unlink used to be skipped).
        """
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        handle = open(tmp_path / "test.log", "w")
        try:
            manager._processes["test"] = ManagedProcess(
                name="test",
                process=None,
                log_handle=handle,
            )

            manager._stop_process("test")

            assert handle.closed
        finally:
            # Safety net if _stop_process regresses; closing twice is a no-op.
            handle.close()
class TestServiceManagerRealProcess:
    """Tests with real subprocesses (integration tests).

    Each test reaps its children and closes its log handles in ``finally``.
    Otherwise a failing assertion would orphan a 60-second sleeper process.
    An open handle would also stop TemporaryDirectory from cleaning up on
    Windows.
    """

    # A child that idles long enough to still be alive when the test stops it.
    _SLEEPER = [sys.executable, "-c", "import time; time.sleep(60)"]

    @staticmethod
    def _reap(process):
        """Kill *process* if it is still running and wait for it to exit."""
        if process.poll() is None:
            process.kill()
            process.wait()

    def test_start_and_stop_real_process(self):
        """_stop_process terminates a live child, marks it STOPPED, closes its log."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(
                log_dir=tmpdir,
                skip_atropos=True,
                skip_vllm=True,
            )
            manager = ServiceManager(config)

            log_file = Path(tmpdir) / "test.log"
            log_handle = open(log_file, "w")
            process = None
            try:
                process = subprocess.Popen(
                    self._SLEEPER,
                    stdout=log_handle,
                    stderr=subprocess.STDOUT,
                )
                manager._processes["test"] = ManagedProcess(
                    name="test",
                    process=process,
                    status=ServiceStatus.RUNNING,
                    log_file=log_file,
                    log_handle=log_handle,
                )

                # Verify process is running
                assert process.poll() is None

                manager._stop_process("test")

                # Verify process stopped
                assert process.poll() is not None
                assert manager._processes["test"].status == ServiceStatus.STOPPED
                assert log_handle.closed
            finally:
                if process is not None:
                    self._reap(process)
                log_handle.close()

    def test_stop_all_with_real_processes(self):
        """stop_all terminates every registered live child."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(
                log_dir=tmpdir,
                skip_atropos=True,
                skip_vllm=True,
            )
            manager = ServiceManager(config)

            processes = []
            handles = []
            try:
                for name in ("proc1", "proc2"):
                    log_handle = open(Path(tmpdir) / f"{name}.log", "w")
                    handles.append(log_handle)

                    process = subprocess.Popen(
                        self._SLEEPER,
                        stdout=log_handle,
                        stderr=subprocess.STDOUT,
                    )
                    processes.append(process)

                    manager._processes[name] = ManagedProcess(
                        name=name,
                        process=process,
                        status=ServiceStatus.RUNNING,
                        log_handle=log_handle,
                    )

                # Verify both running
                for p in processes:
                    assert p.poll() is None

                manager.stop_all()

                # Verify all stopped
                for p in processes:
                    assert p.poll() is not None
            finally:
                for p in processes:
                    self._reap(p)
                for h in handles:
                    h.close()
class TestConcurrentAccess:
    """Smoke tests for calling ServiceManager from several threads."""

    def test_concurrent_health_checks(self):
        """Ten threads calling _check_health at once all succeed and return False."""
        manager = ServiceManager(ServiceConfig(skip_atropos=True, skip_vllm=True))
        thread_count = 10

        outcomes = []
        failures = []

        def worker():
            try:
                outcomes.append(manager._check_health("nonexistent"))
            except Exception as exc:
                # Recorded so the main thread can assert no worker blew up.
                failures.append(exc)

        workers = [threading.Thread(target=worker) for _ in range(thread_count)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert not failures
        assert len(outcomes) == thread_count
        assert all(outcome is False for outcome in outcomes)