@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,497 @@
1
+ """
2
+ Tests for TrainingOrchestrator from run_training.py
3
+
4
+ Tests cover:
5
+ - Initialization and configuration
6
+ - Service lifecycle management
7
+ - Cleanup behavior
8
+ - Signal handling
9
+ - Environment validation
10
+ """
11
+
12
+ import os
13
+ import signal
14
+ import subprocess
15
+ import sys
16
+ import tempfile
17
+ import time
18
+ from pathlib import Path
19
+ from unittest.mock import MagicMock, patch
20
+
21
+ import pytest
22
+
23
+ # Add src and scripts to path
24
+ sys.path.insert(0, str(Path(__file__).parent.parent))
25
+ sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
26
+
27
+ from run_training import TrainingOrchestrator, validate_environment
28
+
29
+
30
class TestValidateEnvironment:
    """Tests for the validate_environment function.

    Tests that modify the process environment do so inside
    ``patch.dict(os.environ)``, which snapshots the environment on entry and
    restores it on exit. This also restores variables whose original value
    was an empty string, which a truthiness-based manual restore would drop.
    """

    def test_returns_list(self):
        """validate_environment returns a list."""
        result = validate_environment()
        assert isinstance(result, list)

    def test_missing_database_url(self):
        """An error mentioning DATABASE_URL is reported when it is unset."""
        with patch.dict(os.environ):
            # pop() with a default is a no-op when the variable is absent.
            os.environ.pop("DATABASE_URL", None)
            errors = validate_environment()

        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) >= 1

    def test_missing_openai_api_key(self):
        """An error mentioning OPENAI_API_KEY is reported when it is unset."""
        with patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            errors = validate_environment()

        key_errors = [e for e in errors if "OPENAI_API_KEY" in e]
        assert len(key_errors) >= 1

    def test_with_all_env_vars_set(self):
        """No DATABASE_URL / OPENAI_API_KEY errors when both are set."""
        overrides = {
            "DATABASE_URL": "postgresql://test:test@localhost/test",
            "OPENAI_API_KEY": "sk-test123",
        }
        with patch.dict(os.environ, overrides):
            errors = validate_environment()

        # Other, unrelated errors may still be present; only these two
        # categories must be gone.
        db_errors = [e for e in errors if "DATABASE_URL" in e]
        key_errors = [e for e in errors if "OPENAI_API_KEY" in e]
        assert len(db_errors) == 0
        assert len(key_errors) == 0
90
+
91
+
92
class TestTrainingOrchestratorInit:
    """Construction-time behaviour of TrainingOrchestrator."""

    def test_default_values(self):
        """An orchestrator built with only log_dir has the pinned defaults."""
        # Attributes compared with ``==``.
        equal_to = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "training_steps": 100,
            "batch_size": 4,
            "learning_rate": 1e-5,
            "min_learning_rate": 1e-7,
            "lr_scheduler": "cosine",
            "warmup_steps": 10,
            "api_port": 8000,
            "vllm_port": 9001,
            "vllm_gpu_memory": 0.45,
            "save_path": "./trained_models",
            "save_every": 5,
            "keep_checkpoints": 3,
            "wandb_project": "eliza-training",
        }
        # Attributes compared with ``is`` (singletons: None / True / False).
        identical_to = {
            "resume_from": None,
            "use_wandb": True,
            "wandb_entity": None,
            "wandb_run_name": None,
            "skip_services": False,
        }

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            for attr, wanted in equal_to.items():
                assert getattr(orchestrator, attr) == wanted
            for attr, wanted in identical_to.items():
                assert getattr(orchestrator, attr) is wanted

    def test_custom_values(self):
        """Keyword arguments passed to the constructor are stored unchanged."""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                model_name="custom/model",
                training_steps=50,
                batch_size=8,
                learning_rate=5e-5,
                lr_scheduler="linear",
                use_wandb=False,
                skip_services=True,
                log_dir=scratch,
            )

            stored = (
                orchestrator.model_name,
                orchestrator.training_steps,
                orchestrator.batch_size,
                orchestrator.learning_rate,
                orchestrator.lr_scheduler,
            )
            assert stored == ("custom/model", 50, 8, 5e-5, "linear")
            assert orchestrator.use_wandb is False
            assert orchestrator.skip_services is True

    def test_creates_log_directory(self):
        """Constructing with a not-yet-existing nested log_dir creates it."""
        with tempfile.TemporaryDirectory() as scratch:
            nested = Path(scratch) / "nested" / "logs"

            orchestrator = TrainingOrchestrator(log_dir=str(nested))

            assert nested.exists()
            assert nested.is_dir()

    def test_initializes_process_tracking(self):
        """Process and handle bookkeeping starts out empty."""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            for attr in ("env_process", "trainer_process", "_service_manager"):
                assert getattr(orchestrator, attr) is None
            assert orchestrator._shutdown_requested is False
            assert orchestrator._log_handles == []
162
+
163
+
164
class TestTrainingOrchestratorCleanup:
    """Tests for TrainingOrchestrator.cleanup().

    Tests that spawn real child processes always reap them in a ``finally``
    block, so a failing assertion or an exception inside ``cleanup()`` does
    not leave a 60-second sleeper running after the test.
    """

    @staticmethod
    def _spawn_sleeper():
        """Start a child Python process that sleeps for 60 seconds."""
        return subprocess.Popen(
            [sys.executable, "-c", "import time; time.sleep(60)"]
        )

    @staticmethod
    def _reap(proc):
        """Force-kill ``proc`` if it is still running, then collect it."""
        if proc.poll() is None:
            proc.kill()
        proc.wait()

    def test_cleanup_with_no_processes(self):
        """cleanup() is safe to call when nothing was started."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Should not raise
            orch.cleanup()

    def test_cleanup_stops_real_process(self):
        """cleanup() terminates a running trainer process."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Keep a local reference so the process can be checked and
            # reaped independently of the orchestrator's attribute.
            proc = self._spawn_sleeper()
            orch.trainer_process = proc
            try:
                assert proc.poll() is None

                orch.cleanup()

                assert proc.poll() is not None
            finally:
                self._reap(proc)

    def test_cleanup_closes_log_handles(self):
        """cleanup() closes tracked log handles and empties the list."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            log_file = Path(tmpdir) / "test.log"
            handle = open(log_file, 'w')
            try:
                orch._log_handles.append(handle)

                assert not handle.closed

                orch.cleanup()

                assert handle.closed
                assert len(orch._log_handles) == 0
            finally:
                # close() is idempotent; this only matters when the test
                # fails before cleanup() has closed the handle.
                handle.close()

    def test_cleanup_multiple_processes(self):
        """cleanup() terminates both the trainer and environment processes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            trainer = self._spawn_sleeper()
            try:
                env = self._spawn_sleeper()
            except BaseException:
                # The second spawn failed; do not leak the first child.
                self._reap(trainer)
                raise

            orch.trainer_process = trainer
            orch.env_process = env
            try:
                assert trainer.poll() is None
                assert env.poll() is None

                orch.cleanup()

                assert trainer.poll() is not None
                assert env.poll() is not None
            finally:
                self._reap(trainer)
                self._reap(env)
227
+
228
+
229
class TestTrainingOrchestratorStopProcess:
    """Tests for the _stop_process helper method.

    Child processes are always reaped in ``finally`` so a failing test does
    not leave a long-running sleeper behind.
    """

    @staticmethod
    def _reap(proc):
        """Force-kill ``proc`` if it is still running, then collect it."""
        if proc.poll() is None:
            proc.kill()
        proc.wait()

    def test_stop_process_none(self):
        """_stop_process accepts None as the process."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Should not raise
            orch._stop_process(None, "test")

    def test_stop_process_terminates_gracefully(self):
        """_stop_process stops a cooperative process well within the timeout."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            proc = subprocess.Popen(
                [sys.executable, "-c", "import time; time.sleep(60)"]
            )
            try:
                # monotonic() is immune to wall-clock adjustments, so the
                # elapsed measurement cannot go negative or jump.
                start = time.monotonic()
                orch._stop_process(proc, "test", timeout=5)
                elapsed = time.monotonic() - start

                assert proc.poll() is not None
                # Graceful termination should finish well before the 5 s
                # timeout; allow up to 2 s of slack for slow machines.
                assert elapsed < 2
            finally:
                self._reap(proc)

    def test_stop_process_kills_stubborn_process(self):
        """_stop_process ends a process that ignores SIGTERM."""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # The child announces "ready" only after SIGTERM is ignored, so
            # the test never signals it before the handler is installed.
            proc = subprocess.Popen(
                [
                    sys.executable, "-c",
                    "import signal, time; "
                    "signal.signal(signal.SIGTERM, signal.SIG_IGN); "
                    "print('ready', flush=True); "
                    "time.sleep(60)",
                ],
                stdout=subprocess.PIPE,
            )
            try:
                assert proc.stdout.readline().strip() == b"ready"

                orch._stop_process(proc, "stubborn", timeout=2)

                # Key assertion: process should be terminated
                assert proc.poll() is not None
            finally:
                self._reap(proc)
                proc.stdout.close()
275
+
276
+
277
class TestTrainingOrchestratorSkipServices:
    """Behaviour of start_services() when skip_services is enabled."""

    def test_start_services_skipped(self):
        """start_services() returns True and creates no service manager."""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch, skip_services=True)

            started = orchestrator.start_services()

            assert started is True
            assert orchestrator._service_manager is None
292
+
293
+
294
class TestTrainingOrchestratorLogConfig:
    """Smoke tests for _log_config: it only has to run without raising."""

    def test_log_config_runs(self):
        """_log_config completes when a resume checkpoint is configured."""
        with tempfile.TemporaryDirectory() as scratch:
            settings = {
                "model_name": "test/model",
                "training_steps": 50,
                "resume_from": "./checkpoint",
                "log_dir": scratch,
            }
            orchestrator = TrainingOrchestrator(**settings)

            # Passing means no exception was raised.
            orchestrator._log_config()

    def test_log_config_without_resume(self):
        """_log_config completes when resume_from is None."""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch, resume_from=None)

            # Passing means no exception was raised.
            orchestrator._log_config()
320
+
321
+
322
class TestIntegrationStartEnvironment:
    """start_environment() exercised with subprocess.Popen mocked out."""

    @patch("subprocess.Popen")
    def test_start_environment_builds_correct_command(self, mock_popen):
        """The launched command contains the expected module and flags."""
        fake_proc = MagicMock(pid=12345)
        fake_proc.poll.return_value = None  # still running
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                model_name="test/model",
                api_port=9000,
                vllm_port=8080,
                use_wandb=False,
                log_dir=scratch,
            )

            started = orchestrator.start_environment()
            assert started is True

            # The command is the first positional argument of the Popen call.
            positional, _keywords = mock_popen.call_args
            cmd = positional[0]

            expected_tokens = (
                "-m",
                "src.training.rlaif_env",
                "serve",
                "--env.tokenizer_name",
                "test/model",
                "--env.use_wandb",
                "false",
            )
            for token in expected_tokens:
                assert token in cmd

    @patch("subprocess.Popen")
    def test_start_environment_tracks_log_handle(self, mock_popen):
        """Exactly one log handle is tracked after start_environment()."""
        fake_proc = MagicMock(pid=12345)
        fake_proc.poll.return_value = None  # still running
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            orchestrator.start_environment()

            assert len(orchestrator._log_handles) == 1

    @patch("subprocess.Popen")
    def test_start_environment_failure(self, mock_popen):
        """start_environment() returns False when the process exits at once."""
        fake_proc = MagicMock(returncode=1)
        fake_proc.poll.return_value = 1  # Process exited
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            started = orchestrator.start_environment()

            assert started is False
387
+
388
+
389
class TestIntegrationStartTrainer:
    """start_trainer() exercised with subprocess.Popen mocked out."""

    @patch("subprocess.Popen")
    def test_start_trainer_builds_correct_command(self, mock_popen):
        """The trainer command carries the configured options."""
        fake_proc = MagicMock()
        fake_proc.pid = 12345
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                model_name="test/model",
                training_steps=50,
                batch_size=8,
                learning_rate=5e-5,
                min_learning_rate=1e-7,
                lr_scheduler="linear",
                warmup_steps=5,
                save_path="./models",
                save_every=10,
                keep_checkpoints=2,
                use_wandb=False,
                log_dir=scratch,
            )

            started = orchestrator.start_trainer()
            assert started is True

            # The command is the first positional argument of the Popen call.
            positional, _keywords = mock_popen.call_args
            cmd = positional[0]

            expected_tokens = (
                "--model", "test/model",
                "--steps", "50",
                "--batch-size", "8",
                "--lr-scheduler", "linear",
                "--no-wandb",
            )
            for token in expected_tokens:
                assert token in cmd

    @patch("subprocess.Popen")
    def test_start_trainer_with_resume(self, mock_popen):
        """A configured resume_from appears as --resume plus its path."""
        fake_proc = MagicMock()
        fake_proc.pid = 12345
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                resume_from="./checkpoint/step_50",
                log_dir=scratch,
            )

            orchestrator.start_trainer()

            positional, _keywords = mock_popen.call_args
            cmd = positional[0]

            for token in ("--resume", "./checkpoint/step_50"):
                assert token in cmd

    @patch("subprocess.Popen")
    def test_start_trainer_with_wandb_options(self, mock_popen):
        """W&B project, entity and run name are passed; --no-wandb is not."""
        fake_proc = MagicMock()
        fake_proc.pid = 12345
        mock_popen.return_value = fake_proc

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                use_wandb=True,
                wandb_project="my-project",
                wandb_entity="my-team",
                wandb_run_name="my-run",
                log_dir=scratch,
            )

            orchestrator.start_trainer()

            positional, _keywords = mock_popen.call_args
            cmd = positional[0]

            expected_tokens = (
                "--wandb-project", "my-project",
                "--wandb-entity", "my-team",
                "--wandb-run-name", "my-run",
            )
            for token in expected_tokens:
                assert token in cmd
            assert "--no-wandb" not in cmd
481
+
482
+
483
class TestSignalHandling:
    """Checks around the orchestrator's shutdown flag."""

    def test_shutdown_requested_flag(self):
        """The _shutdown_requested flag starts False and can be set to True.

        No real signal is delivered; the flag is flipped directly to mimic
        what a signal handler would do.
        """
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)
            assert orchestrator._shutdown_requested is False

            # Simulate signal handler behavior (not the actual signal)
            orchestrator._shutdown_requested = True
            assert orchestrator._shutdown_requested is True
497
+