@elizaos/training 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/LICENSE +21 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +57 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/benchmark_should_respond.py +190 -0
- package/python/scripts/debug_inference.py +62 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/optimize_prompt_grpo.py +269 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_generation.py +29 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_grpo.py +360 -0
- package/python/scripts/train_jsonl.py +223 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/research-output/training-runs/training-run-1771276293257.json +38 -0
- package/research-output/training-runs/training-run-1771276389280.json +38 -0
- package/research-output/training-runs/training-run-1771276502776.json +38 -0
- package/research-output/training-runs/training-run-1771277340748.json +38 -0
- package/research-output/training-runs/training-run-1773013658993.json +38 -0
- package/research-output/training-runs/training-run-1773013861014.json +38 -0
- package/research-output/training-runs/training-run-1773014215983.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/generate_should_respond.ts +267 -0
- package/scripts/generate_should_respond_dataset.ts +162 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/rank_trajectories.ts +207 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/run_rlaif_loop.ts +78 -0
- package/scripts/run_task_benchmark.ts +247 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +204 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/TaskRunner.ts +94 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +91 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +475 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,579 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for Learning Rate Scheduling and AtroposTrainingConfig.
|
|
3
|
+
|
|
4
|
+
Tests cover:
|
|
5
|
+
- LR scheduler types (constant, linear, cosine)
|
|
6
|
+
- Warmup behavior
|
|
7
|
+
- Boundary conditions (step 0, step == total_steps)
|
|
8
|
+
- Minimum LR ratio enforcement
|
|
9
|
+
- Config validation
|
|
10
|
+
- Checkpoint resume logic
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import math
|
|
14
|
+
import sys
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from unittest.mock import MagicMock
|
|
17
|
+
|
|
18
|
+
import pytest
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import torch
|
|
22
|
+
from torch.optim import AdamW
|
|
23
|
+
except ImportError:
|
|
24
|
+
pytest.skip("torch not installed", allow_module_level=True)
|
|
25
|
+
|
|
26
|
+
# Add src to path
|
|
27
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
28
|
+
|
|
29
|
+
from src.training.atropos_trainer import (
|
|
30
|
+
AtroposTrainingConfig,
|
|
31
|
+
AtroposTrainer,
|
|
32
|
+
LRSchedulerType,
|
|
33
|
+
get_lr_scheduler,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TestLRSchedulerType:
    """Sanity checks for the members of the ``LRSchedulerType`` enum."""

    def test_all_types_defined(self):
        """Each expected member exists and carries its lowercase string value."""
        member_values = (
            (LRSchedulerType.CONSTANT, "constant"),
            (LRSchedulerType.LINEAR, "linear"),
            (LRSchedulerType.COSINE, "cosine"),
        )
        for member, text in member_values:
            assert member.value == text

    def test_type_count(self):
        """The enum exposes exactly three scheduler types."""
        member_count = len(LRSchedulerType)
        assert member_count == 3

    def test_from_string(self):
        """Constructing the enum from a string value yields the matching member."""
        lookups = (
            ("constant", LRSchedulerType.CONSTANT),
            ("linear", LRSchedulerType.LINEAR),
            ("cosine", LRSchedulerType.COSINE),
        )
        for text, member in lookups:
            assert LRSchedulerType(text) == member

    def test_invalid_type_raises(self):
        """An unknown string value is rejected with ``ValueError``."""
        unknown_value = "invalid"
        with pytest.raises(ValueError):
            LRSchedulerType(unknown_value)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class TestConstantScheduler:
    """Behaviour of the ``CONSTANT`` schedule, with and without warmup."""

    @pytest.fixture
    def optimizer(self):
        """Provide an AdamW optimizer (lr=1e-4) over a small linear layer."""
        layer = torch.nn.Linear(10, 10)
        return AdamW(layer.parameters(), lr=1e-4)

    @staticmethod
    def _record(scheduler, count):
        """Return the LR read before each of ``count`` scheduler steps."""
        readings = []
        for _ in range(count):
            readings.append(scheduler.get_last_lr()[0])
            scheduler.step()
        return readings

    def test_constant_no_decay(self, optimizer):
        """Without warmup the LR stays at the base value for all 100 steps."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.CONSTANT,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        history = self._record(sched, 100)

        # Every reading must equal the base LR up to floating point noise.
        for value in history:
            assert abs(value - 1e-4) < 1e-10

    def test_constant_with_warmup(self, optimizer):
        """A 10-step warmup ramps the LR up, after which it holds at the base value."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.CONSTANT,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # First 10 readings cover the warmup, the remaining 90 the plateau.
        ramp = self._record(sched, 10)
        plateau = self._record(sched, 90)

        # The ramp ends higher than it starts.
        first_ramp, last_ramp = ramp[0], ramp[-1]
        assert first_ramp < last_ramp

        # After warmup every reading sits at the full LR.
        for value in plateau:
            assert abs(value - 1e-4) < 1e-10
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TestLinearScheduler:
    """Behaviour of the ``LINEAR`` schedule: decay, warmup and the LR floor."""

    @pytest.fixture
    def optimizer(self):
        """Provide an AdamW optimizer (lr=1e-4) over a small linear layer."""
        net = torch.nn.Linear(10, 10)
        return AdamW(net.parameters(), lr=1e-4)

    def test_linear_decay(self, optimizer):
        """LR starts at the base value and never increases on its way to the floor."""
        min_lr_ratio = 0.1
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=min_lr_ratio,
        )

        history = []
        for _ in range(100):
            current = sched.get_last_lr()[0]
            sched.step()
            history.append(current)

        first, last = history[0], history[-1]

        # The first reading is the optimizer's base LR.
        assert abs(first - 1e-4) < 1e-10

        # The last reading lies within 10% (relative) of the floor.
        floor = 1e-4 * min_lr_ratio
        assert abs(last - floor) / floor < 0.1

        # No reading exceeds its predecessor, allowing for float noise.
        for earlier, later in zip(history, history[1:]):
            assert later <= earlier + 1e-12

    def test_linear_with_warmup(self, optimizer):
        """During a 20-step warmup the LR at step ``s`` equals ``base * s / 20``."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=20,
            min_lr_ratio=0.1,
        )

        # Check every warmup step against the expected linear ramp.
        for step in range(20):
            expected = 1e-4 * (step / 20)
            lr = sched.get_last_lr()[0]
            assert abs(lr - expected) < 1e-10, f"Step {step}: expected {expected}, got {lr}"
            sched.step()

        # Once warmup is over the LR has reached the base value.
        after_warmup = sched.get_last_lr()[0]
        assert abs(after_warmup - 1e-4) < 1e-10

    def test_linear_min_lr_respected(self, optimizer):
        """The LR never drops below ``base * min_lr_ratio``, even past the last step."""
        min_lr_ratio = 0.2
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=min_lr_ratio,
        )

        floor = 1e-4 * min_lr_ratio

        # 150 iterations deliberately overshoot the 100 configured steps.
        for _ in range(150):
            current = sched.get_last_lr()[0]
            assert current >= floor - 1e-12
            sched.step()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class TestCosineScheduler:
    """Behaviour of the ``COSINE`` schedule: curve shape, warmup, floor, smoothness."""

    @pytest.fixture
    def optimizer(self):
        """Provide an AdamW optimizer (lr=1e-4) over a small linear layer."""
        module = torch.nn.Linear(10, 10)
        return AdamW(module.parameters(), lr=1e-4)

    @staticmethod
    def _trace(scheduler, count):
        """Return the LR read before each of ``count`` scheduler steps."""
        trace = []
        for _ in range(count):
            trace.append(scheduler.get_last_lr()[0])
            scheduler.step()
        return trace

    def test_cosine_decay(self, optimizer):
        """LR starts at the base value, passes the midpoint at step 50, ends near the floor."""
        min_lr_ratio = 0.1
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=min_lr_ratio,
        )

        trace = self._trace(sched, 100)

        # The first reading is the optimizer's base LR.
        assert abs(trace[0] - 1e-4) < 1e-10

        # The final reading lies within 10% (relative) of the floor.
        floor = 1e-4 * min_lr_ratio
        assert abs(trace[-1] - floor) / floor < 0.1

        # Halfway through (step 50) the LR sits midway between floor and base.
        midpoint = 0.5 * (1e-4 * min_lr_ratio + 1e-4)
        assert abs(trace[50] - midpoint) < 1e-6

    def test_cosine_with_warmup(self, optimizer):
        """A 10-step warmup ramps linearly and hands over at the full LR."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # Each warmup step should match the linear ramp base * step / 10.
        for step in range(10):
            target = 1e-4 * (step / 10)
            observed = sched.get_last_lr()[0]
            assert abs(observed - target) < 1e-10
            sched.step()

        # The reading right after warmup is the full LR.
        after_warmup = sched.get_last_lr()[0]
        assert abs(after_warmup - 1e-4) < 1e-10

    def test_cosine_min_lr_respected(self, optimizer):
        """The LR never drops below ``base * min_lr_ratio``, even past the last step."""
        min_lr_ratio = 0.3
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=min_lr_ratio,
        )

        floor = 1e-4 * min_lr_ratio

        # 150 iterations deliberately overshoot the 100 configured steps.
        for _ in range(150):
            observed = sched.get_last_lr()[0]
            assert observed >= floor - 1e-10
            sched.step()

    def test_cosine_smooth_transition(self, optimizer):
        """Consecutive readings differ by less than 5% of the base LR."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        trace = self._trace(sched, 100)

        # Walk neighbouring pairs and bound the per-step change.
        for before, after in zip(trace, trace[1:]):
            jump = abs(after - before)
            assert jump < 1e-4 * 0.05
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class TestWarmupBehavior:
    """Tests specifically for warmup behavior.

    The two edge-case tests (warmup == total steps, warmup > total steps)
    assert that every learning rate read from the scheduler is a finite
    number, instead of only relying on the absence of an exception.
    """

    @pytest.fixture
    def optimizer(self):
        """Create an AdamW optimizer (lr=1e-4) over a small linear layer."""
        model = torch.nn.Linear(10, 10)
        return AdamW(model.parameters(), lr=1e-4)

    def test_zero_warmup_steps(self, optimizer):
        """Test scheduler works with zero warmup steps"""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        # Should start at full LR immediately
        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_warmup_at_step_zero(self, optimizer):
        """Test warmup starts at zero LR"""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # At step 0, LR should be 0
        assert scheduler.get_last_lr()[0] == 0.0

    def test_warmup_reaches_full_lr(self, optimizer):
        """Test warmup reaches full LR at end of warmup"""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # Step through warmup
        for _ in range(10):
            scheduler.step()

        # Should be at full LR
        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_warmup_equal_to_total_steps(self, optimizer):
        """Test edge case where warmup == total steps"""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=10,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # Step past the end of the schedule (15 > 10). Besides not crashing,
        # every LR produced along the way must be a finite number.
        for step in range(15):
            scheduler.step()
            lr = scheduler.get_last_lr()[0]
            assert math.isfinite(lr), f"non-finite LR {lr} after {step + 1} steps"

    def test_warmup_greater_than_total_steps(self, optimizer):
        """Test edge case where warmup > total steps"""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=10,
            warmup_steps=20,
            min_lr_ratio=0.1,
        )

        # Run past both the configured total (10) and the warmup length (20).
        # Besides not crashing, every LR read must be a finite number.
        for step in range(25):
            lr = scheduler.get_last_lr()[0]
            assert math.isfinite(lr), f"non-finite LR {lr} at step {step}"
            scheduler.step()
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class TestAtroposTrainingConfig:
    """Checks on ``AtroposTrainingConfig`` defaults, overrides and device selection."""

    def test_default_values(self):
        """A config built with no arguments exposes the expected defaults."""
        config = AtroposTrainingConfig()

        # Fields compared by equality.
        expected_defaults = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "learning_rate": 1e-5,
            "min_learning_rate": 1e-7,
            "training_steps": 100,
            "batch_size": 4,
            "gradient_accumulation_steps": 8,
            "seq_len": 4096,
            "max_grad_norm": 1.0,
            "lr_scheduler": LRSchedulerType.COSINE,
            "warmup_steps": 10,
            "vllm_port": 9001,
            "vllm_restart_interval": 5,
            "vllm_gpu_utilization": 0.45,
            "save_path": "./trained_models",
            "save_every_steps": 5,
            "keep_checkpoints": 3,
            "api_url": "http://localhost:8000",
            "log_file": "./logs/training_metrics.jsonl",
            "wandb_project": "eliza-training",
        }
        for field_name, wanted in expected_defaults.items():
            assert getattr(config, field_name) == wanted

        # Flags that must be the literal ``True``.
        for flag_name in ("log_to_file", "use_wandb"):
            assert getattr(config, flag_name) is True

        # Optional fields that default to ``None``.
        for optional_name in ("resume_from", "wandb_entity", "wandb_run_name"):
            assert getattr(config, optional_name) is None

    def test_custom_values(self):
        """Keyword arguments passed to the constructor replace the defaults."""
        overrides = {
            "model_name": "custom/model",
            "learning_rate": 5e-5,
            "training_steps": 50,
            "lr_scheduler": LRSchedulerType.LINEAR,
        }
        config = AtroposTrainingConfig(use_wandb=False, **overrides)

        for field_name, wanted in overrides.items():
            assert getattr(config, field_name) == wanted
        assert config.use_wandb is False

    def test_min_lr_ratio_calculation(self):
        """``min_learning_rate / learning_rate`` gives 0.01 for 1e-6 over 1e-4."""
        config = AtroposTrainingConfig(
            learning_rate=1e-4,
            min_learning_rate=1e-6,
        )

        ratio = config.min_learning_rate / config.learning_rate
        assert abs(ratio - 0.01) < 1e-10

    def test_device_auto_detection(self):
        """The default device is "cuda" when CUDA is available, otherwise "cpu"."""
        config = AtroposTrainingConfig()

        wanted_device = "cuda" if torch.cuda.is_available() else "cpu"
        assert config.device == wanted_device

    def test_device_override(self):
        """An explicit ``device`` argument is stored as given."""
        forced = AtroposTrainingConfig(device="cpu")
        assert forced.device == "cpu"
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class TestAtroposTrainer:
    """Checks on ``AtroposTrainer`` construction and its checkpoint-path helper."""

    @staticmethod
    def _fresh_trainer():
        """Build a trainer from a default config and return both objects."""
        config = AtroposTrainingConfig()
        return config, AtroposTrainer(config)

    def test_initialization(self):
        """A new trainer keeps its config and starts with empty runtime state."""
        config, trainer = self._fresh_trainer()

        assert trainer.config == config

        # Nothing is loaded or launched until training begins.
        for attr in ("model", "tokenizer", "optimizer", "scheduler"):
            assert getattr(trainer, attr) is None
        assert trainer.current_step == 0
        assert trainer.vllm_process is None
        assert trainer._wandb_initialized is False
        assert trainer._checkpoint_history == []
        assert len(trainer.run_id) > 0

    def test_extract_step_from_path_valid(self):
        """Paths ending in ``step_<digits>`` yield that number."""
        _, trainer = self._fresh_trainer()

        cases = (
            ("./models/step_50", 50),
            ("/path/to/step_100", 100),
            ("step_0", 0),
            ("step_999", 999),
        )
        for path, wanted in cases:
            assert trainer._extract_step_from_path(path) == wanted

    def test_extract_step_from_path_invalid(self):
        """Paths without a numeric step suffix yield 0."""
        _, trainer = self._fresh_trainer()

        non_step_paths = (
            "./models/final_model",
            "./models/checkpoint",
            "./step_abc",
            "",
        )
        for path in non_step_paths:
            assert trainer._extract_step_from_path(path) == 0

    def test_extract_step_from_path_edge_cases(self):
        """Bare prefix, negative-looking suffix and leading zeros."""
        _, trainer = self._fresh_trainer()

        cases = (
            ("step_", 0),     # prefix with no digits
            ("step_-5", 0),   # negative-looking number does not match
            ("step_007", 7),  # leading zeros are ignored
        )
        for path, wanted in cases:
            assert trainer._extract_step_from_path(path) == wanted

    def test_run_id_format(self):
        """``run_id`` has the shape ``YYYYMMDD-HHMMSS``: 8 digits, a dash, 6 digits."""
        _, trainer = self._fresh_trainer()
        run_id = trainer.run_id

        assert len(run_id) == 15
        assert run_id[8] == '-'

        date_part, time_part = run_id[:8], run_id[9:]
        assert date_part.isdigit()
        assert time_part.isdigit()
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
class TestBoundaryConditions:
    """Scheduler behaviour at extreme step counts and ``min_lr_ratio`` values."""

    @pytest.fixture
    def optimizer(self):
        """Provide an AdamW optimizer (lr=1e-4) over a small linear layer."""
        layer = torch.nn.Linear(10, 10)
        return AdamW(layer.parameters(), lr=1e-4)

    def test_single_training_step(self, optimizer):
        """A one-step cosine schedule can be read and stepped without error."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=1,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        # Read the initial LR, then take the single step.
        initial = sched.get_last_lr()[0]
        sched.step()

        assert initial >= 0

    def test_min_lr_ratio_zero(self, optimizer):
        """With ``min_lr_ratio=0`` a linear schedule reaches (nearly) zero after all steps."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.0,
        )

        remaining = 100
        while remaining:
            sched.step()
            remaining -= 1

        # After the full schedule the LR is at or very near zero.
        final = sched.get_last_lr()[0]
        assert final < 1e-12

    def test_min_lr_ratio_one(self, optimizer):
        """With ``min_lr_ratio=1`` a linear schedule never leaves the base LR."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=1.0,
        )

        history = []
        for _ in range(100):
            history.append(sched.get_last_lr()[0])
            sched.step()

        # Every reading equals the base LR up to floating point noise.
        for value in history:
            assert abs(value - 1e-4) < 1e-10

    def test_very_large_step_count(self, optimizer):
        """A million-step schedule yields finite LRs over its first 1000 steps."""
        sched = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=1000000,
            warmup_steps=1000,
            min_lr_ratio=0.01,
        )

        # Only checks that values are neither NaN nor infinite.
        for _ in range(1000):
            current = sched.get_last_lr()[0]
            assert math.isfinite(current)
            sched.step()
|