@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,497 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for TrainingOrchestrator from run_training.py
|
|
3
|
+
|
|
4
|
+
Tests cover:
|
|
5
|
+
- Initialization and configuration
|
|
6
|
+
- Service lifecycle management
|
|
7
|
+
- Cleanup behavior
|
|
8
|
+
- Signal handling
|
|
9
|
+
- Environment validation
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
import signal
|
|
14
|
+
import subprocess
|
|
15
|
+
import sys
|
|
16
|
+
import tempfile
|
|
17
|
+
import time
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from unittest.mock import MagicMock, patch
|
|
20
|
+
|
|
21
|
+
import pytest
|
|
22
|
+
|
|
23
|
+
# Add the parent of tests/ (so `src` is importable) and its scripts/ directory to sys.path
|
|
24
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
25
|
+
sys.path.insert(0, str(Path(__file__).parent.parent / "scripts"))
|
|
26
|
+
|
|
27
|
+
from run_training import TrainingOrchestrator, validate_environment
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TestValidateEnvironment:
    """Tests for the ``validate_environment`` function.

    Every test that modifies ``os.environ`` does so inside ``patch.dict`` so
    the process environment is restored to exactly its previous state
    (including variables that were set to an empty string) however the test
    exits.
    """

    def test_returns_list(self):
        """Test that validate_environment returns a list"""
        result = validate_environment()
        assert isinstance(result, list)

    def test_missing_database_url(self):
        """Test error when DATABASE_URL not set"""
        # patch.dict snapshots os.environ and restores it on exit, so the
        # removal below cannot leak into other tests.
        with patch.dict(os.environ):
            os.environ.pop("DATABASE_URL", None)

            errors = validate_environment()

        db_errors = [e for e in errors if "DATABASE_URL" in e]
        assert len(db_errors) >= 1

    def test_missing_openai_api_key(self):
        """Test error when OPENAI_API_KEY not set"""
        with patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)

            errors = validate_environment()

        key_errors = [e for e in errors if "OPENAI_API_KEY" in e]
        assert len(key_errors) >= 1

    def test_with_all_env_vars_set(self):
        """Test no DATABASE_URL / OPENAI_API_KEY errors when both are set"""
        overrides = {
            "DATABASE_URL": "postgresql://test:test@localhost/test",
            "OPENAI_API_KEY": "sk-test123",
        }

        with patch.dict(os.environ, overrides):
            errors = validate_environment()

        # Should not have DB or OpenAI key errors
        db_errors = [e for e in errors if "DATABASE_URL" in e]
        key_errors = [e for e in errors if "OPENAI_API_KEY" in e]
        assert len(db_errors) == 0
        assert len(key_errors) == 0
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TestTrainingOrchestratorInit:
    """Construction-time behaviour of TrainingOrchestrator"""

    def test_default_values(self):
        """An orchestrator built with only log_dir exposes the expected defaults"""
        # Attributes compared by equality.
        equal_to = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "training_steps": 100,
            "batch_size": 4,
            "learning_rate": 1e-5,
            "min_learning_rate": 1e-7,
            "lr_scheduler": "cosine",
            "warmup_steps": 10,
            "api_port": 8000,
            "vllm_port": 9001,
            "vllm_gpu_memory": 0.45,
            "save_path": "./trained_models",
            "save_every": 5,
            "keep_checkpoints": 3,
            "wandb_project": "eliza-training",
        }
        # Attributes that must be the None singleton.
        unset = ("resume_from", "wandb_entity", "wandb_run_name")

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            for attr, wanted in equal_to.items():
                assert getattr(orchestrator, attr) == wanted
            for attr in unset:
                assert getattr(orchestrator, attr) is None
            # Booleans are checked by identity, not truthiness.
            assert orchestrator.use_wandb is True
            assert orchestrator.skip_services is False

    def test_custom_values(self):
        """Constructor keyword arguments are stored on the instance"""
        overrides = {
            "model_name": "custom/model",
            "training_steps": 50,
            "batch_size": 8,
            "learning_rate": 5e-5,
            "lr_scheduler": "linear",
        }

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                use_wandb=False,
                skip_services=True,
                log_dir=scratch,
                **overrides,
            )

            for attr, wanted in overrides.items():
                assert getattr(orchestrator, attr) == wanted
            assert orchestrator.use_wandb is False
            assert orchestrator.skip_services is True

    def test_creates_log_directory(self):
        """Constructing the orchestrator creates a missing, nested log_dir"""
        with tempfile.TemporaryDirectory() as scratch:
            target = Path(scratch).joinpath("nested", "logs")

            # Only the constructor's side effect matters here.
            TrainingOrchestrator(log_dir=str(target))

            assert target.exists()
            assert target.is_dir()

    def test_initializes_process_tracking(self):
        """Process/bookkeeping attributes start out empty"""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            for attr in ("env_process", "trainer_process", "_service_manager"):
                assert getattr(orchestrator, attr) is None
            assert orchestrator._shutdown_requested is False
            assert orchestrator._log_handles == []
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TestTrainingOrchestratorCleanup:
    """Tests for TrainingOrchestrator cleanup behavior

    Tests that start real subprocesses or open real files release them in a
    ``finally`` block, so a failing assertion (or a raising ``cleanup()``)
    does not leave 60-second sleepers or open handles behind.
    """

    @staticmethod
    def _spawn_sleeper():
        """Start a child Python process that just sleeps for 60 seconds."""
        return subprocess.Popen(
            [sys.executable, "-c", "import time; time.sleep(60)"]
        )

    @staticmethod
    def _reap(proc):
        """Force-stop ``proc`` if it is still running and collect its status."""
        if proc.poll() is None:
            proc.kill()
        proc.wait(timeout=5)

    def test_cleanup_with_no_processes(self):
        """Test cleanup is safe when no processes running"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Should not raise
            orch.cleanup()

    def test_cleanup_stops_real_process(self):
        """Test cleanup terminates real processes"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Start a real process
            proc = self._spawn_sleeper()
            try:
                orch.trainer_process = proc

                assert orch.trainer_process.poll() is None

                orch.cleanup()

                assert orch.trainer_process.poll() is not None
            finally:
                # Safety net: never leak the child if anything above failed.
                self._reap(proc)

    def test_cleanup_closes_log_handles(self):
        """Test cleanup closes log handles"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Create a log handle
            log_file = Path(tmpdir) / "test.log"
            handle = open(log_file, 'w')
            try:
                orch._log_handles.append(handle)

                assert not handle.closed

                orch.cleanup()

                assert handle.closed
                assert len(orch._log_handles) == 0
            finally:
                # close() on an already-closed file is a no-op.
                handle.close()

    def test_cleanup_multiple_processes(self):
        """Test cleanup handles multiple processes"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            spawned = []
            try:
                orch.trainer_process = self._spawn_sleeper()
                spawned.append(orch.trainer_process)
                orch.env_process = self._spawn_sleeper()
                spawned.append(orch.env_process)

                assert orch.trainer_process.poll() is None
                assert orch.env_process.poll() is None

                orch.cleanup()

                assert orch.trainer_process.poll() is not None
                assert orch.env_process.poll() is not None
            finally:
                for proc in spawned:
                    self._reap(proc)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class TestTrainingOrchestratorStopProcess:
    """Tests for _stop_process helper method

    Child processes started here are always reaped in ``finally`` so a
    failing test cannot leave a 60-second sleeper running.
    """

    @staticmethod
    def _reap(proc):
        """Force-stop ``proc`` if it is still running and collect its status."""
        if proc.poll() is None:
            proc.kill()
        proc.wait(timeout=5)

    def test_stop_process_none(self):
        """Test _stop_process handles None"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Should not raise
            orch._stop_process(None, "test")

    def test_stop_process_terminates_gracefully(self):
        """Test _stop_process terminates process gracefully"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            proc = subprocess.Popen(
                [sys.executable, "-c", "import time; time.sleep(60)"]
            )
            try:
                # monotonic() is immune to wall-clock adjustments.
                start = time.monotonic()
                orch._stop_process(proc, "test", timeout=5)
                elapsed = time.monotonic() - start

                assert proc.poll() is not None
                # Graceful termination should return well before the
                # 5 second timeout passed above.
                assert elapsed < 2
            finally:
                self._reap(proc)

    def test_stop_process_kills_stubborn_process(self):
        """Test _stop_process kills process that ignores SIGTERM"""
        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)

            # Create a process that ignores SIGTERM and announces when the
            # handler is installed.
            proc = subprocess.Popen(
                [
                    sys.executable, "-c",
                    "import signal, time; "
                    "signal.signal(signal.SIGTERM, signal.SIG_IGN); "
                    "print('ready', flush=True); "
                    "time.sleep(60)",
                ],
                stdout=subprocess.PIPE,
            )
            try:
                # Block until the child confirms SIGTERM is ignored. A fixed
                # sleep could lose the race on a slow machine, in which case
                # SIGTERM would simply kill the child and the escalation
                # path would never be exercised.
                assert proc.stdout.readline().strip() == b"ready"

                orch._stop_process(proc, "stubborn", timeout=2)

                # Key assertion: process should be terminated
                assert proc.poll() is not None
            finally:
                self._reap(proc)
                proc.stdout.close()
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class TestTrainingOrchestratorSkipServices:
    """Behaviour of start_services when skip_services is enabled"""

    def test_start_services_skipped(self):
        """start_services returns True and leaves _service_manager unset"""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch, skip_services=True)

            outcome = orchestrator.start_services()

            assert outcome is True
            assert orchestrator._service_manager is None
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class TestTrainingOrchestratorLogConfig:
    """Smoke tests for the _log_config method"""

    def test_log_config_runs(self):
        """_log_config completes without raising when resume_from is set"""
        settings = {
            "model_name": "test/model",
            "training_steps": 50,
            "resume_from": "./checkpoint",
        }
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch, **settings)

            # Passing means no exception was raised.
            orchestrator._log_config()

    def test_log_config_without_resume(self):
        """_log_config completes without raising when resume_from is None"""
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch, resume_from=None)

            # Passing means no exception was raised.
            orchestrator._log_config()
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class TestIntegrationStartEnvironment:
    """Integration tests for starting environment (mocked subprocess)

    ``subprocess.Popen`` is mocked, but ``start_environment`` still registers
    a log handle in ``orch._log_handles`` (see
    ``test_start_environment_tracks_log_handle``). Each test closes those
    handles itself in ``finally`` so no open file outlives the temporary
    directory. ``orch.cleanup()`` is deliberately not used for this because
    its effect on a mocked process is not something these tests control.
    """

    @staticmethod
    def _close_log_handles(orch):
        """Close every handle the orchestrator registered during the test."""
        for handle in list(orch._log_handles):
            handle.close()

    @patch('subprocess.Popen')
    def test_start_environment_builds_correct_command(self, mock_popen):
        """Test start_environment builds correct command"""
        mock_process = MagicMock()
        mock_process.poll.return_value = None  # still running
        mock_process.pid = 12345
        mock_popen.return_value = mock_process

        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(
                model_name="test/model",
                api_port=9000,
                vllm_port=8080,
                use_wandb=False,
                log_dir=tmpdir,
            )
            try:
                result = orch.start_environment()

                assert result is True

                # Verify command was called: first positional arg of Popen
                call_args = mock_popen.call_args
                cmd = call_args[0][0]

                assert "-m" in cmd
                assert "src.training.rlaif_env" in cmd
                assert "serve" in cmd
                assert "--env.tokenizer_name" in cmd
                assert "test/model" in cmd
                assert "--env.use_wandb" in cmd
                assert "false" in cmd
            finally:
                self._close_log_handles(orch)

    @patch('subprocess.Popen')
    def test_start_environment_tracks_log_handle(self, mock_popen):
        """Test start_environment tracks log handle"""
        mock_process = MagicMock()
        mock_process.poll.return_value = None  # still running
        mock_process.pid = 12345
        mock_popen.return_value = mock_process

        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)
            try:
                orch.start_environment()

                assert len(orch._log_handles) == 1
            finally:
                self._close_log_handles(orch)

    @patch('subprocess.Popen')
    def test_start_environment_failure(self, mock_popen):
        """Test start_environment handles immediate failure"""
        mock_process = MagicMock()
        mock_process.poll.return_value = 1  # Process exited
        mock_process.returncode = 1
        mock_popen.return_value = mock_process

        with tempfile.TemporaryDirectory() as tmpdir:
            orch = TrainingOrchestrator(log_dir=tmpdir)
            try:
                result = orch.start_environment()

                assert result is False
            finally:
                # Closing is harmless if the orchestrator already closed or
                # dropped the handle on failure.
                self._close_log_handles(orch)
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class TestIntegrationStartTrainer:
    """Command-construction tests for start_trainer (subprocess.Popen is mocked)"""

    @patch('subprocess.Popen')
    def test_start_trainer_builds_correct_command(self, mock_popen):
        """Core training options show up in the command handed to Popen"""
        mock_popen.return_value = MagicMock(pid=12345)

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                model_name="test/model",
                training_steps=50,
                batch_size=8,
                learning_rate=5e-5,
                min_learning_rate=1e-7,
                lr_scheduler="linear",
                warmup_steps=5,
                save_path="./models",
                save_every=10,
                keep_checkpoints=2,
                use_wandb=False,
                log_dir=scratch,
            )

            started = orchestrator.start_trainer()
            assert started is True

            # First positional argument of the Popen call is the argv list.
            argv = mock_popen.call_args[0][0]
            expected_tokens = (
                "--model", "test/model",
                "--steps", "50",
                "--batch-size", "8",
                "--lr-scheduler", "linear",
                "--no-wandb",
            )
            for token in expected_tokens:
                assert token in argv

    @patch('subprocess.Popen')
    def test_start_trainer_with_resume(self, mock_popen):
        """resume_from is forwarded as --resume plus the checkpoint path"""
        mock_popen.return_value = MagicMock(pid=12345)

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                log_dir=scratch,
                resume_from="./checkpoint/step_50",
            )

            orchestrator.start_trainer()

            argv = mock_popen.call_args[0][0]
            for token in ("--resume", "./checkpoint/step_50"):
                assert token in argv

    @patch('subprocess.Popen')
    def test_start_trainer_with_wandb_options(self, mock_popen):
        """W&B settings are forwarded and --no-wandb is absent"""
        mock_popen.return_value = MagicMock(pid=12345)

        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(
                use_wandb=True,
                wandb_project="my-project",
                wandb_entity="my-team",
                wandb_run_name="my-run",
                log_dir=scratch,
            )

            orchestrator.start_trainer()

            argv = mock_popen.call_args[0][0]
            expected_tokens = (
                "--wandb-project", "my-project",
                "--wandb-entity", "my-team",
                "--wandb-run-name", "my-run",
            )
            for token in expected_tokens:
                assert token in argv
            assert "--no-wandb" not in argv
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class TestSignalHandling:
    """Shutdown-flag behaviour associated with signal handling"""

    def test_shutdown_requested_flag(self):
        """The shutdown flag starts False and reads True once set

        NOTE(review): no signal is delivered and no handler is invoked here;
        the flag is assigned directly, so this only covers the attribute's
        initial value and that it can be flipped.
        """
        with tempfile.TemporaryDirectory() as scratch:
            orchestrator = TrainingOrchestrator(log_dir=scratch)

            initial = orchestrator._shutdown_requested
            assert initial is False

            # Stand-in for what a signal handler would presumably do.
            orchestrator._shutdown_requested = True

            updated = orchestrator._shutdown_requested
            assert updated is True
|
|
497
|
+
|