@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,585 @@
1
+ """
2
+ Tests for ServiceManager - Process lifecycle, health checks, port detection.
3
+
4
+ Tests cover:
5
+ - ServiceConfig validation and defaults
6
+ - ManagedProcess state management
7
+ - Port-in-use detection
8
+ - Process start/stop lifecycle
9
+ - Health check behavior
10
+ - Resource cleanup (file handles)
11
+ - Signal handling
12
+ """
13
+
14
+ import socket
15
+ import subprocess
16
+ import sys
17
+ import tempfile
18
+ import threading
19
+ import time
20
+ from pathlib import Path
21
+ from unittest.mock import MagicMock, patch
22
+
23
+ import pytest
24
+
25
+ # Add src to path
26
+ sys.path.insert(0, str(Path(__file__).parent.parent))
27
+
28
+ from src.training.service_manager import (
29
+ ManagedProcess,
30
+ ServiceConfig,
31
+ ServiceManager,
32
+ ServiceStatus,
33
+ check_prerequisites,
34
+ )
35
+
36
+
37
class TestServiceConfig:
    """Checks on the ServiceConfig dataclass: defaults, overrides, boundaries."""

    def test_default_values(self):
        """A bare ServiceConfig() carries the default values these tests expect."""
        cfg = ServiceConfig()

        expected_defaults = {
            "atropos_port": 8000,
            "atropos_host": "localhost",
            "vllm_port": 9001,
            "vllm_host": "localhost",
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "vllm_gpu_memory_utilization": 0.85,
            "vllm_dtype": "auto",
            "vllm_max_model_len": 4096,
            "startup_timeout": 180,
            "health_check_interval": 2.0,
            "shutdown_timeout": 10,
            "log_dir": "./logs/services",
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected

        # The skip flags must be the False singleton, not merely falsy.
        assert cfg.skip_atropos is False
        assert cfg.skip_vllm is False

    def test_custom_values(self):
        """Keyword arguments passed to ServiceConfig replace the defaults."""
        overrides = {
            "atropos_port": 9000,
            "vllm_port": 8080,
            "model_name": "custom/model",
            "vllm_gpu_memory_utilization": 0.5,
            "startup_timeout": 60,
        }
        cfg = ServiceConfig(skip_atropos=True, skip_vllm=True, **overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected

        # Identity check: the flags must come back as the True singleton.
        assert cfg.skip_atropos is True
        assert cfg.skip_vllm is True

    def test_gpu_memory_boundary_values(self):
        """GPU memory utilization is stored unchanged at 0.0, 1.0 and just above 0."""
        # Lower bound, upper bound, then a value slightly above zero.
        for fraction in (0.0, 1.0, 0.01):
            cfg = ServiceConfig(vllm_gpu_memory_utilization=fraction)
            assert cfg.vllm_gpu_memory_utilization == fraction
91
+
92
+
93
class TestManagedProcess:
    """Tests for the ManagedProcess dataclass: default state, pid, close_log."""

    def test_default_state(self):
        """A ManagedProcess built with only a name has empty/stopped defaults."""
        proc = ManagedProcess(name="test")

        assert proc.name == "test"
        assert proc.process is None
        assert proc.status == ServiceStatus.STOPPED
        assert proc.log_file is None
        assert proc.log_handle is None
        assert proc.health_url is None
        assert proc.pid is None

    def test_pid_property_with_process(self):
        """pid mirrors process.pid when a process object is attached."""
        mock_process = MagicMock()
        mock_process.pid = 12345

        proc = ManagedProcess(name="test", process=mock_process)
        assert proc.pid == 12345

    def test_pid_property_without_process(self):
        """pid is None when no process object is attached."""
        proc = ManagedProcess(name="test")
        assert proc.pid is None

    def test_close_log_with_handle(self):
        """close_log closes the attached file handle and clears the reference."""
        # TemporaryDirectory removes the log file even if an assertion fails,
        # unlike a NamedTemporaryFile(delete=False) + manual unlink at the end.
        with tempfile.TemporaryDirectory() as tmpdir:
            handle = open(Path(tmpdir) / "test.log", "w")
            try:
                proc = ManagedProcess(name="test", log_handle=handle)

                assert not handle.closed
                proc.close_log()
                assert handle.closed
                assert proc.log_handle is None
            finally:
                # No-op when close_log already closed it; prevents a leaked
                # descriptor (and a failed directory cleanup) on test failure.
                handle.close()

    def test_close_log_without_handle(self):
        """close_log is safe to call when no handle was ever attached."""
        proc = ManagedProcess(name="test")
        proc.close_log()  # Should not raise
        assert proc.log_handle is None

    def test_close_log_idempotent(self):
        """Calling close_log twice neither raises nor reopens the handle."""
        with tempfile.TemporaryDirectory() as tmpdir:
            handle = open(Path(tmpdir) / "test.log", "w")
            try:
                proc = ManagedProcess(name="test", log_handle=handle)

                proc.close_log()
                proc.close_log()  # Second call should not raise

                # State after the repeated call matches the single-call case.
                assert handle.closed
                assert proc.log_handle is None
            finally:
                handle.close()
155
+
156
+
157
class TestServiceStatus:
    """Checks on the members of the ServiceStatus enum."""

    def test_all_statuses_defined(self):
        """Each expected member exists and maps to its lowercase string value."""
        expected_values = {
            "STOPPED": "stopped",
            "STARTING": "starting",
            "RUNNING": "running",
            "FAILED": "failed",
            "STOPPING": "stopping",
        }
        for member_name, raw_value in expected_values.items():
            assert getattr(ServiceStatus, member_name).value == raw_value

    def test_status_count(self):
        """The enum has exactly five members."""
        member_total = len(ServiceStatus)
        assert member_total == 5
171
+
172
+
173
class TestServiceManagerPortDetection:
    """Tests for ServiceManager._port_in_use."""

    @staticmethod
    def _find_free_port():
        """Return a TCP port on localhost that was free a moment ago.

        The OS picks the port (bind to port 0); the probing socket is closed
        before returning, so another process could in principle grab the port
        afterwards. That window is small and acceptable for these tests.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("localhost", 0))
            return probe.getsockname()[1]

    def test_port_not_in_use(self):
        """A port nobody listens on is reported as not in use."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        free_port = self._find_free_port()

        assert manager._port_in_use("localhost", free_port) is False

    def test_port_in_use(self):
        """A port with an active listener is reported as in use."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        # The context manager closes the listener even if the assertion fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(("localhost", 0))
            server.listen(1)
            port = server.getsockname()[1]

            assert manager._port_in_use("localhost", port) is True

    def test_port_in_use_timeout(self):
        """The port check returns promptly for a port with no listener."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        # An OS-assigned free port instead of a hard-coded number: a fixed
        # port may be occupied on a developer or CI machine, which would make
        # the `result is False` assertion fail spuriously.
        port = self._find_free_port()

        # monotonic() is immune to wall-clock adjustments during the test.
        start = time.monotonic()
        result = manager._port_in_use("localhost", port)
        elapsed = time.monotonic() - start

        # Should complete within timeout (1 second) plus margin
        assert elapsed < 2.0
        assert result is False
218
+
219
+
220
class TestServiceManagerUrls:
    """Checks on the URL strings built from the configured host and port."""

    def test_get_atropos_url_default(self):
        """With default config the Atropos URL is http://localhost:8000."""
        url = ServiceManager(ServiceConfig()).get_atropos_url()

        assert url == "http://localhost:8000"

    def test_get_atropos_url_custom(self):
        """Custom Atropos host and port appear in the generated URL."""
        custom = ServiceConfig(atropos_host="192.168.1.1", atropos_port=9000)
        url = ServiceManager(custom).get_atropos_url()

        assert url == "http://192.168.1.1:9000"

    def test_get_vllm_url_default(self):
        """With default config the vLLM URL is http://localhost:9001."""
        url = ServiceManager(ServiceConfig()).get_vllm_url()

        assert url == "http://localhost:9001"

    def test_get_vllm_url_custom(self):
        """Custom vLLM host and port appear in the generated URL."""
        custom = ServiceConfig(vllm_host="10.0.0.1", vllm_port=8080)
        url = ServiceManager(custom).get_vllm_url()

        assert url == "http://10.0.0.1:8080"
250
+
251
+
252
class TestServiceManagerStatus:
    """Checks on ServiceManager.get_status for services that never ran."""

    def test_get_status_unknown_service(self):
        """An unregistered service name reports STOPPED."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        reported = svc_manager.get_status("nonexistent")

        assert reported == ServiceStatus.STOPPED

    def test_get_status_skipped_service(self):
        """Skipped services still report STOPPED after start_all()."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )
        svc_manager.start_all()

        # Same order as before: atropos first, then vllm.
        for service_name in ("atropos", "vllm"):
            assert svc_manager.get_status(service_name) == ServiceStatus.STOPPED
270
+
271
+
272
class TestServiceManagerSkipBehavior:
    """Checks on start_all() / wait_for_ready() when services are skipped.

    NOTE(review): every test in this class sets BOTH skip flags, including
    the two whose names say "only". A true single-skip test would have to
    launch (or stub) the other service, which needs knowledge of
    ServiceManager internals not visible from this file.
    """

    def test_skip_atropos_only(self):
        """With both services skipped, start_all() succeeds and no
        "atropos" entry is registered."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        started = svc_manager.start_all()

        assert started is True
        assert "atropos" not in svc_manager._processes

    def test_skip_vllm_only(self):
        """With both services skipped, start_all() succeeds and no
        "vllm" entry is registered."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        started = svc_manager.start_all()

        assert started is True
        assert "vllm" not in svc_manager._processes

    def test_skip_all_services(self):
        """With both services skipped, the process registry stays empty."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        started = svc_manager.start_all()

        assert started is True
        assert len(svc_manager._processes) == 0

    def test_wait_for_ready_no_services(self):
        """wait_for_ready() returns True when nothing was started."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )
        svc_manager.start_all()

        ready = svc_manager.wait_for_ready(timeout=1)

        assert ready is True
314
+
315
+
316
class TestServiceManagerHealthCheck:
    """Checks on health probing for missing or URL-less services."""

    def test_check_health_no_process(self):
        """_check_health is False for a service that was never registered."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        healthy = svc_manager._check_health("atropos")

        assert healthy is False

    def test_check_health_no_url(self):
        """_check_health is False for a registered entry without a health_url."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )
        entry = ManagedProcess(name="test", health_url=None)
        svc_manager._processes["test"] = entry

        healthy = svc_manager._check_health("test")

        assert healthy is False

    def test_is_healthy_delegates_to_check_health(self):
        """is_healthy and _check_health agree for the same service name."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        # Public call first, private call second, as in the comparison order.
        public_answer = svc_manager.is_healthy("nonexistent")
        private_answer = svc_manager._check_health("nonexistent")

        assert public_answer == private_answer
345
+
346
+
347
class TestServiceManagerContextManager:
    """Checks on using ServiceManager in a `with` statement."""

    def test_context_manager_start_and_stop(self):
        """Entering yields a ServiceManager; after exit no processes remain."""
        skip_everything = ServiceConfig(skip_atropos=True, skip_vllm=True)

        with ServiceManager(skip_everything) as svc_manager:
            # The value bound by `as` must be a ServiceManager instance.
            assert isinstance(svc_manager, ServiceManager)

        # Once the block has exited, the registry must be empty.
        remaining = len(svc_manager._processes)
        assert remaining == 0
360
+
361
+
362
class TestServiceManagerLogDirectory:
    """Checks on how ServiceManager treats the configured log directory."""

    def test_creates_log_directory(self):
        """Constructing a manager creates a missing, nested log directory."""
        with tempfile.TemporaryDirectory() as scratch:
            target_dir = Path(scratch) / "nested" / "logs"
            cfg = ServiceConfig(
                log_dir=str(target_dir),
                skip_atropos=True,
                skip_vllm=True,
            )

            # Kept bound so the instance lives through the assertions below.
            svc_manager = ServiceManager(cfg)

            assert target_dir.exists()
            assert target_dir.is_dir()

    def test_existing_log_directory_ok(self):
        """Constructing a manager over an existing directory does not raise."""
        with tempfile.TemporaryDirectory() as scratch:
            cfg = ServiceConfig(
                log_dir=scratch,
                skip_atropos=True,
                skip_vllm=True,
            )

            # Construction itself is the check: it must not raise.
            svc_manager = ServiceManager(cfg)
            assert Path(scratch).exists()
384
+
385
+
386
class TestCheckPrerequisites:
    """Tests for the check_prerequisites function."""

    def test_returns_list(self):
        """check_prerequisites returns a list."""
        result = check_prerequisites()
        assert isinstance(result, list)

    def test_missing_database_url(self):
        """An error mentioning DATABASE_URL is reported when it is unset."""
        import os

        # patch.dict snapshots os.environ and restores it on exit, whatever
        # happens inside the block. The earlier manual save/restore used a
        # truthiness test and therefore failed to restore a variable that was
        # set to the empty string.
        with patch.dict(os.environ):
            os.environ.pop("DATABASE_URL", None)

            errors = check_prerequisites()

        # Should have at least the DATABASE_URL error
        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) >= 1

    def test_with_database_url_set(self):
        """No error mentions DATABASE_URL when the variable is set."""
        import os

        # Same snapshot/restore guarantee as above: a pre-existing value
        # (including "") comes back, and an absent variable stays absent.
        with patch.dict(
            os.environ,
            {"DATABASE_URL": "postgresql://test:test@localhost/test"},
        ):
            errors = check_prerequisites()

        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) == 0
429
+
430
+
431
class TestServiceManagerStopProcess:
    """Tests for ServiceManager._stop_process on entries without a live process."""

    def test_stop_process_nonexistent(self):
        """Stopping an unregistered service name does not raise."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)

        # Should not raise
        manager._stop_process("nonexistent")

    def test_stop_process_no_subprocess(self):
        """Stopping an entry whose process is None leaves it STOPPED."""
        config = ServiceConfig(skip_atropos=True, skip_vllm=True)
        manager = ServiceManager(config)
        manager._processes["test"] = ManagedProcess(name="test", process=None)

        # Should not raise
        manager._stop_process("test")
        assert manager._processes["test"].status == ServiceStatus.STOPPED

    def test_stop_process_closes_log_handle(self):
        """Stopping an entry closes its attached log handle."""
        # A temporary directory (instead of NamedTemporaryFile(delete=False)
        # plus a trailing unlink) guarantees the file is removed even when
        # the assertion fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(skip_atropos=True, skip_vllm=True)
            manager = ServiceManager(config)

            handle = open(Path(tmpdir) / "test.log", "w")
            try:
                manager._processes["test"] = ManagedProcess(
                    name="test",
                    process=None,
                    log_handle=handle,
                )

                manager._stop_process("test")

                assert handle.closed
            finally:
                # No-op after a successful stop; avoids a leaked descriptor
                # (and a blocked directory removal) if the stop did not close it.
                handle.close()
471
+
472
+
473
class TestServiceManagerRealProcess:
    """Tests with real subprocesses (integration tests)."""

    @staticmethod
    def _force_cleanup(process, log_handle):
        """Kill *process* if it is still alive and close *log_handle*.

        Safety net for the tests below: without it, a failed assertion (or an
        exception from the manager) would leave a 60-second child process and
        an open file behind, and the open file can block removal of the
        temporary directory.
        """
        if process is not None and process.poll() is None:
            process.kill()
            process.wait(timeout=5)
        # close() on an already-closed file is a no-op.
        log_handle.close()

    def test_start_and_stop_real_process(self):
        """_stop_process terminates a live child and closes its log handle."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(
                log_dir=tmpdir,
                skip_atropos=True,
                skip_vllm=True,
            )
            manager = ServiceManager(config)

            log_file = Path(tmpdir) / "test.log"
            log_handle = open(log_file, "w")
            process = None

            try:
                # A simple long-running child: the interpreter sleeping.
                process = subprocess.Popen(
                    [sys.executable, "-c", "import time; time.sleep(60)"],
                    stdout=log_handle,
                    stderr=subprocess.STDOUT,
                )

                manager._processes["test"] = ManagedProcess(
                    name="test",
                    process=process,
                    status=ServiceStatus.RUNNING,
                    log_file=log_file,
                    log_handle=log_handle,
                )

                # Verify process is running
                assert process.poll() is None

                # Stop it
                manager._stop_process("test")

                # Verify process stopped
                assert process.poll() is not None
                assert manager._processes["test"].status == ServiceStatus.STOPPED
                assert log_handle.closed
            finally:
                self._force_cleanup(process, log_handle)

    def test_stop_all_with_real_processes(self):
        """stop_all terminates every registered live child process."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ServiceConfig(
                log_dir=tmpdir,
                skip_atropos=True,
                skip_vllm=True,
            )
            manager = ServiceManager(config)

            processes = []
            # (process-or-None, handle) pairs for the cleanup in `finally`.
            spawned = []

            try:
                for name in ["proc1", "proc2"]:
                    log_file = Path(tmpdir) / f"{name}.log"
                    log_handle = open(log_file, "w")
                    # Registered before Popen so the handle is closed even if
                    # spawning the child fails.
                    spawned.append([None, log_handle])

                    process = subprocess.Popen(
                        [sys.executable, "-c", "import time; time.sleep(60)"],
                        stdout=log_handle,
                        stderr=subprocess.STDOUT,
                    )
                    spawned[-1][0] = process

                    manager._processes[name] = ManagedProcess(
                        name=name,
                        process=process,
                        status=ServiceStatus.RUNNING,
                        log_handle=log_handle,
                    )
                    processes.append(process)

                # Verify both running
                for p in processes:
                    assert p.poll() is None

                # Stop all
                manager.stop_all()

                # Verify all stopped
                for p in processes:
                    assert p.poll() is not None
            finally:
                for process, log_handle in spawned:
                    self._force_cleanup(process, log_handle)
556
+
557
+
558
class TestConcurrentAccess:
    """Checks that the manager tolerates being queried from several threads."""

    def test_concurrent_health_checks(self):
        """Ten threads probing an unknown service all get False, with no errors."""
        svc_manager = ServiceManager(
            ServiceConfig(skip_atropos=True, skip_vllm=True)
        )

        outcomes = []
        failures = []

        def probe():
            # Record the answer, or the exception if the call blows up.
            try:
                outcomes.append(svc_manager._check_health("nonexistent"))
            except Exception as exc:
                failures.append(exc)

        workers = []
        for _ in range(10):
            workers.append(threading.Thread(target=probe))

        # Launch every worker before joining any, so the calls can overlap.
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert len(failures) == 0
        assert len(outcomes) == 10
        assert all(outcome is False for outcome in outcomes)
585
+