@elizaos/training 2.0.0-alpha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. package/Dockerfile +75 -0
  2. package/LICENSE +21 -0
  3. package/Makefile +374 -0
  4. package/README.md +346 -0
  5. package/config/rubrics.json +137 -0
  6. package/docker-compose.test.yml +57 -0
  7. package/package.json +57 -0
  8. package/python/config/babylon_atropos.yaml +90 -0
  9. package/python/config/profiles/12gb.json +11 -0
  10. package/python/config/profiles/16gb.json +10 -0
  11. package/python/config/profiles/24gb.json +10 -0
  12. package/python/config/profiles/48gb.json +10 -0
  13. package/python/config/profiles/cpu.json +11 -0
  14. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  15. package/python/config/profiles/l40-2gpu.json +22 -0
  16. package/python/config/profiles/l40-4gpu.json +21 -0
  17. package/python/config/profiles/l40.json +17 -0
  18. package/python/config/tinker_training.yaml +143 -0
  19. package/python/curriculum_state.json +165 -0
  20. package/python/env.template +86 -0
  21. package/python/env.training.template +46 -0
  22. package/python/pyproject.toml +41 -0
  23. package/python/requirements-ci.txt +31 -0
  24. package/python/requirements.txt +87 -0
  25. package/python/scripts/__init__.py +4 -0
  26. package/python/scripts/benchmark_should_respond.py +190 -0
  27. package/python/scripts/debug_inference.py +62 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/optimize_prompt_grpo.py +269 -0
  36. package/python/scripts/run_ab_test.py +143 -0
  37. package/python/scripts/run_full_pipeline.py +544 -0
  38. package/python/scripts/run_tinker_training.py +192 -0
  39. package/python/scripts/run_training.py +914 -0
  40. package/python/scripts/test_generation.py +29 -0
  41. package/python/scripts/test_judge.py +155 -0
  42. package/python/scripts/test_pipeline.py +356 -0
  43. package/python/scripts/test_trained_model.py +380 -0
  44. package/python/scripts/train_grpo.py +360 -0
  45. package/python/scripts/train_jsonl.py +223 -0
  46. package/python/scripts/train_local.py +528 -0
  47. package/python/setup.py +20 -0
  48. package/python/src/__init__.py +190 -0
  49. package/python/src/data_bridge/__init__.py +24 -0
  50. package/python/src/data_bridge/converter.py +435 -0
  51. package/python/src/data_bridge/reader.py +393 -0
  52. package/python/src/models.py +283 -0
  53. package/python/src/training/__init__.py +605 -0
  54. package/python/src/training/ab_testing.py +404 -0
  55. package/python/src/training/action_executor.py +621 -0
  56. package/python/src/training/archetype_trainer.py +347 -0
  57. package/python/src/training/atropos_trainer.py +980 -0
  58. package/python/src/training/babylon_env.py +1254 -0
  59. package/python/src/training/error_recovery.py +647 -0
  60. package/python/src/training/evaluation.py +856 -0
  61. package/python/src/training/fast_simulator.py +880 -0
  62. package/python/src/training/format_validator.py +584 -0
  63. package/python/src/training/hybrid_env.py +522 -0
  64. package/python/src/training/kl_controller.py +628 -0
  65. package/python/src/training/multi_prompt_dataset.py +883 -0
  66. package/python/src/training/multi_turn.py +656 -0
  67. package/python/src/training/online_env.py +1084 -0
  68. package/python/src/training/quality_scorer.py +391 -0
  69. package/python/src/training/quality_utils.py +633 -0
  70. package/python/src/training/rewards.py +1344 -0
  71. package/python/src/training/rlaif_env.py +17 -0
  72. package/python/src/training/rollout_generator.py +502 -0
  73. package/python/src/training/rubric_loader.py +198 -0
  74. package/python/src/training/scenario_pool.py +1072 -0
  75. package/python/src/training/schemas.py +481 -0
  76. package/python/src/training/service_manager.py +552 -0
  77. package/python/src/training/simulation_bridge.py +535 -0
  78. package/python/src/training/tick_reward_attribution.py +399 -0
  79. package/python/src/training/tinker_client.py +575 -0
  80. package/python/src/training/tinker_trainer.py +646 -0
  81. package/python/src/training/tokenization_utils.py +402 -0
  82. package/python/tests/e2e/__init__.py +13 -0
  83. package/python/tests/e2e/conftest.py +258 -0
  84. package/python/tests/e2e/test_full_pipeline.py +643 -0
  85. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  86. package/python/tests/integration/__init__.py +12 -0
  87. package/python/tests/integration/conftest.py +383 -0
  88. package/python/tests/integration/test_db_integration.py +649 -0
  89. package/python/tests/integration/test_json_mode_integration.py +554 -0
  90. package/python/tests/test_action_executor.py +594 -0
  91. package/python/tests/test_archetype_scoring.py +1027 -0
  92. package/python/tests/test_atropos_integration.py +360 -0
  93. package/python/tests/test_evaluation.py +727 -0
  94. package/python/tests/test_format_validator.py +486 -0
  95. package/python/tests/test_kl_controller.py +432 -0
  96. package/python/tests/test_lr_scheduler.py +579 -0
  97. package/python/tests/test_multi_turn.py +590 -0
  98. package/python/tests/test_online_env.py +519 -0
  99. package/python/tests/test_quality_scorer.py +474 -0
  100. package/python/tests/test_scenario_pool.py +735 -0
  101. package/python/tests/test_service_manager.py +585 -0
  102. package/python/tests/test_simulation_rollout.py +581 -0
  103. package/python/tests/test_tokenization_utils.py +501 -0
  104. package/python/tests/test_training_orchestrator.py +497 -0
  105. package/python/tests/test_training_output_structure.py +661 -0
  106. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  107. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  108. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  109. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  110. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  111. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  112. package/research-output/training-runs/training-run-1771276293257.json +38 -0
  113. package/research-output/training-runs/training-run-1771276389280.json +38 -0
  114. package/research-output/training-runs/training-run-1771276502776.json +38 -0
  115. package/research-output/training-runs/training-run-1771277340748.json +38 -0
  116. package/research-output/training-runs/training-run-1773013658993.json +38 -0
  117. package/research-output/training-runs/training-run-1773013861014.json +38 -0
  118. package/research-output/training-runs/training-run-1773014215983.json +38 -0
  119. package/scripts/assess-training-data.ts +422 -0
  120. package/scripts/e2e-training-test.ts +550 -0
  121. package/scripts/export-rubrics.ts +64 -0
  122. package/scripts/generate-research-report.ts +1523 -0
  123. package/scripts/generate_dataset.sh +173 -0
  124. package/scripts/generate_should_respond.ts +267 -0
  125. package/scripts/generate_should_respond_dataset.ts +162 -0
  126. package/scripts/json-mode-benchmark.ts +399 -0
  127. package/scripts/rank_trajectories.ts +207 -0
  128. package/scripts/real-archetype-benchmark.ts +210 -0
  129. package/scripts/run-baseline-comparison.ts +116 -0
  130. package/scripts/run-full-pipeline.ts +272 -0
  131. package/scripts/run_rlaif_loop.ts +78 -0
  132. package/scripts/run_task_benchmark.ts +247 -0
  133. package/scripts/runpod_setup.sh +137 -0
  134. package/scripts/runpod_validate.sh +147 -0
  135. package/scripts/test-model-in-game.ts +955 -0
  136. package/scripts/test-scoring.ts +73 -0
  137. package/scripts/test-trained-model.ts +209 -0
  138. package/scripts/train-and-test.ts +824 -0
  139. package/scripts/verify-final.ts +118 -0
  140. package/src/adapter.ts +516 -0
  141. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  142. package/src/archetypes/derive-archetype.ts +249 -0
  143. package/src/archetypes/index.ts +22 -0
  144. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  145. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  146. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  147. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  148. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  149. package/src/benchmark/BenchmarkRunner.ts +685 -0
  150. package/src/benchmark/BenchmarkValidator.ts +204 -0
  151. package/src/benchmark/FastEvalRunner.ts +225 -0
  152. package/src/benchmark/MetricsValidator.ts +165 -0
  153. package/src/benchmark/MetricsVisualizer.ts +909 -0
  154. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  155. package/src/benchmark/ModelRegistry.ts +158 -0
  156. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  157. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  158. package/src/benchmark/SimulationEngine.ts +832 -0
  159. package/src/benchmark/TaskRunner.ts +94 -0
  160. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  161. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  162. package/src/benchmark/index.ts +91 -0
  163. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  164. package/src/benchmark/simulation-types.ts +78 -0
  165. package/src/dependencies.ts +475 -0
  166. package/src/generation/TrajectoryGenerator.ts +387 -0
  167. package/src/generation/index.ts +12 -0
  168. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  169. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  170. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  171. package/src/huggingface/index.ts +27 -0
  172. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  173. package/src/index.ts +102 -0
  174. package/src/init-training.ts +53 -0
  175. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  176. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  177. package/src/metrics/index.ts +8 -0
  178. package/src/metrics/types.ts +200 -0
  179. package/src/rubrics/__tests__/index.test.ts +184 -0
  180. package/src/rubrics/ass-kisser.ts +85 -0
  181. package/src/rubrics/degen.ts +80 -0
  182. package/src/rubrics/goody-twoshoes.ts +84 -0
  183. package/src/rubrics/index.ts +236 -0
  184. package/src/rubrics/information-trader.ts +84 -0
  185. package/src/rubrics/infosec.ts +101 -0
  186. package/src/rubrics/liar.ts +104 -0
  187. package/src/rubrics/perps-trader.ts +87 -0
  188. package/src/rubrics/researcher.ts +81 -0
  189. package/src/rubrics/scammer.ts +82 -0
  190. package/src/rubrics/social-butterfly.ts +73 -0
  191. package/src/rubrics/super-predictor.ts +97 -0
  192. package/src/rubrics/trader.ts +67 -0
  193. package/src/scoring/ArchetypeScoringService.ts +486 -0
  194. package/src/scoring/JudgePromptBuilder.ts +556 -0
  195. package/src/scoring/LLMJudgeCache.ts +401 -0
  196. package/src/scoring/index.ts +9 -0
  197. package/src/training/AutomationPipeline.ts +916 -0
  198. package/src/training/BenchmarkService.ts +518 -0
  199. package/src/training/ConfigValidator.ts +220 -0
  200. package/src/training/MarketOutcomesTracker.ts +187 -0
  201. package/src/training/ModelDeployer.ts +186 -0
  202. package/src/training/ModelFetcher.ts +76 -0
  203. package/src/training/ModelSelectionService.ts +341 -0
  204. package/src/training/ModelUsageVerifier.ts +160 -0
  205. package/src/training/MultiModelOrchestrator.ts +580 -0
  206. package/src/training/RLModelConfig.ts +407 -0
  207. package/src/training/RewardBackpropagationService.ts +149 -0
  208. package/src/training/RulerScoringService.ts +666 -0
  209. package/src/training/TrainingMonitor.ts +166 -0
  210. package/src/training/TrajectoryRecorder.ts +399 -0
  211. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  212. package/src/training/index.ts +100 -0
  213. package/src/training/logRLConfig.ts +34 -0
  214. package/src/training/pipeline.ts +129 -0
  215. package/src/training/storage/ModelStorageService.ts +279 -0
  216. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  217. package/src/training/storage/index.ts +17 -0
  218. package/src/training/types.ts +207 -0
  219. package/src/training/window-utils.ts +138 -0
  220. package/src/utils/index.ts +101 -0
  221. package/src/utils/logger.ts +59 -0
  222. package/src/utils/snowflake.ts +17 -0
  223. package/src/utils/synthetic-detector.ts +111 -0
  224. package/tsconfig.json +20 -0
@@ -0,0 +1,579 @@
1
+ """
2
+ Tests for Learning Rate Scheduling and AtroposTrainingConfig.
3
+
4
+ Tests cover:
5
+ - LR scheduler types (constant, linear, cosine)
6
+ - Warmup behavior
7
+ - Boundary conditions (step 0, step == total_steps)
8
+ - Minimum LR ratio enforcement
9
+ - Config validation
10
+ - Checkpoint resume logic
11
+ """
12
+
13
+ import math
14
+ import sys
15
+ from pathlib import Path
16
+ from unittest.mock import MagicMock
17
+
18
+ import pytest
19
+
20
+ try:
21
+ import torch
22
+ from torch.optim import AdamW
23
+ except ImportError:
24
+ pytest.skip("torch not installed", allow_module_level=True)
25
+
26
+ # Add src to path
27
+ sys.path.insert(0, str(Path(__file__).parent.parent))
28
+
29
+ from src.training.atropos_trainer import (
30
+ AtroposTrainingConfig,
31
+ AtroposTrainer,
32
+ LRSchedulerType,
33
+ get_lr_scheduler,
34
+ )
35
+
36
+
37
class TestLRSchedulerType:
    """Unit tests covering the LRSchedulerType enum surface."""

    def test_all_types_defined(self):
        """Each scheduler variant maps to its lowercase string value."""
        expected = {
            LRSchedulerType.CONSTANT: "constant",
            LRSchedulerType.LINEAR: "linear",
            LRSchedulerType.COSINE: "cosine",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_type_count(self):
        """The enum exposes exactly three scheduler variants."""
        assert len(LRSchedulerType) == 3

    def test_from_string(self):
        """Constructing from the string value yields the matching member."""
        for name in ("constant", "linear", "cosine"):
            assert LRSchedulerType(name) == LRSchedulerType[name.upper()]

    def test_invalid_type_raises(self):
        """An unknown string value is rejected with ValueError."""
        with pytest.raises(ValueError):
            LRSchedulerType("invalid")
60
+
61
+
62
class TestConstantScheduler:
    """Behavioral tests for the constant learning-rate schedule."""

    @pytest.fixture
    def optimizer(self):
        """AdamW over a throwaway linear layer, base LR 1e-4."""
        return AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    @staticmethod
    def _collect(scheduler, steps):
        """Record the LR observed before each of `steps` scheduler ticks."""
        observed = []
        for _ in range(steps):
            observed.append(scheduler.get_last_lr()[0])
            scheduler.step()
        return observed

    def test_constant_no_decay(self, optimizer):
        """Without warmup the LR stays pinned at its initial value."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.CONSTANT,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )
        history = self._collect(scheduler, 100)
        # Every observation matches the base LR within float tolerance.
        assert all(abs(lr - 1e-4) < 1e-10 for lr in history)

    def test_constant_with_warmup(self, optimizer):
        """A warmup ramp precedes the flat portion of the schedule."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.CONSTANT,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )
        ramp = self._collect(scheduler, 10)
        plateau = self._collect(scheduler, 90)

        # The ramp must rise during warmup.
        assert ramp[0] < ramp[-1]
        # After warmup the LR sits at the full base value.
        assert all(abs(lr - 1e-4) < 1e-10 for lr in plateau)
114
+
115
+
116
class TestLinearScheduler:
    """Behavioral tests for the linear-decay schedule."""

    @pytest.fixture
    def optimizer(self):
        """AdamW over a throwaway linear layer, base LR 1e-4."""
        return AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    def test_linear_decay(self, optimizer):
        """LR falls linearly from the base value toward the floor."""
        floor_ratio = 0.1
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=floor_ratio,
        )
        history = []
        for _ in range(100):
            history.append(scheduler.get_last_lr()[0])
            scheduler.step()

        # Starts at the configured base LR.
        assert abs(history[0] - 1e-4) < 1e-10

        # Finishes near the floor (relative tolerance for float error).
        floor = 1e-4 * floor_ratio
        assert abs(history[-1] - floor) / floor < 0.1  # 10% relative tolerance

        # Never increases between consecutive steps.
        for prev, curr in zip(history, history[1:]):
            assert curr <= prev + 1e-12  # Small tolerance for floating point

    def test_linear_with_warmup(self, optimizer):
        """Warmup ramps linearly up to the base LR before decay begins."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=20,
            min_lr_ratio=0.1,
        )
        for step in range(20):
            lr = scheduler.get_last_lr()[0]
            expected = 1e-4 * (step / 20)
            assert abs(lr - expected) < 1e-10, f"Step {step}: expected {expected}, got {lr}"
            scheduler.step()

        # Warmup complete: full base LR is in effect.
        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_linear_min_lr_respected(self, optimizer):
        """LR never drops below base_lr * min_lr_ratio, even past the end."""
        floor_ratio = 0.2
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=floor_ratio,
        )
        floor = 1e-4 * floor_ratio
        for _ in range(150):  # Deliberately overshoot num_training_steps.
            assert scheduler.get_last_lr()[0] >= floor - 1e-12
            scheduler.step()
188
+
189
+
190
class TestCosineScheduler:
    """Behavioral tests for the cosine-annealing schedule."""

    @pytest.fixture
    def optimizer(self):
        """AdamW over a throwaway linear layer, base LR 1e-4."""
        return AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    @staticmethod
    def _trace(scheduler, steps):
        """Return the LR observed before each of `steps` scheduler ticks."""
        seen = []
        for _ in range(steps):
            seen.append(scheduler.get_last_lr()[0])
            scheduler.step()
        return seen

    def test_cosine_decay(self, optimizer):
        """LR follows a cosine curve from the base LR down to the floor."""
        floor_ratio = 0.1
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=floor_ratio,
        )
        trace = self._trace(scheduler, 100)

        # Begins at the configured base LR.
        assert abs(trace[0] - 1e-4) < 1e-10

        # Ends near the floor (relative tolerance for float error).
        floor = 1e-4 * floor_ratio
        assert abs(trace[-1] - floor) / floor < 0.1  # 10% relative tolerance

        # Halfway through, a cosine sits at the midpoint of its range.
        midpoint = 0.5 * (1e-4 * floor_ratio + 1e-4)
        assert abs(trace[50] - midpoint) < 1e-6

    def test_cosine_with_warmup(self, optimizer):
        """A linear warmup ramp precedes the cosine portion."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )
        for step in range(10):
            lr = scheduler.get_last_lr()[0]
            expected = 1e-4 * (step / 10)
            assert abs(lr - expected) < 1e-10
            scheduler.step()

        # Warmup done: cosine decay starts from the full base LR.
        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_cosine_min_lr_respected(self, optimizer):
        """LR never dips below base_lr * min_lr_ratio, even past the end."""
        floor_ratio = 0.3
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=floor_ratio,
        )
        floor = 1e-4 * floor_ratio
        for _ in range(150):  # Deliberately overshoot num_training_steps.
            assert scheduler.get_last_lr()[0] >= floor - 1e-10
            scheduler.step()

    def test_cosine_smooth_transition(self, optimizer):
        """Consecutive LRs change gradually — no jumps or discontinuities."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )
        trace = self._trace(scheduler, 100)

        # Per-step movement stays under 5% of the initial LR.
        for prev, curr in zip(trace, trace[1:]):
            assert abs(curr - prev) < 1e-4 * 0.05
284
+
285
+
286
class TestWarmupBehavior:
    """Tests specifically for warmup behavior.

    Covers zero-length warmup, the LR at step 0, reaching the full LR,
    and degenerate configurations where warmup meets or exceeds the
    total step budget.
    """

    @pytest.fixture
    def optimizer(self):
        """AdamW over a throwaway linear layer, base LR 1e-4."""
        model = torch.nn.Linear(10, 10)
        return AdamW(model.parameters(), lr=1e-4)

    def test_zero_warmup_steps(self, optimizer):
        """With warmup_steps=0 the schedule starts at the full base LR."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        # Should start at full LR immediately.
        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_warmup_at_step_zero(self, optimizer):
        """With a warmup phase, the very first LR is exactly zero."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        assert scheduler.get_last_lr()[0] == 0.0

    def test_warmup_reaches_full_lr(self, optimizer):
        """After exactly warmup_steps ticks the LR equals the base LR."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=100,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        for _ in range(10):
            scheduler.step()

        assert abs(scheduler.get_last_lr()[0] - 1e-4) < 1e-10

    def test_warmup_equal_to_total_steps(self, optimizer):
        """Edge case warmup == total steps: stepping past the end is safe."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=10,
            warmup_steps=10,
            min_lr_ratio=0.1,
        )

        # Should complete warmup without crashing; every LR stays sane.
        for _ in range(15):
            lr = scheduler.get_last_lr()[0]
            assert math.isfinite(lr) and lr >= 0.0
            scheduler.step()

    def test_warmup_greater_than_total_steps(self, optimizer):
        """Edge case warmup > total steps: no crash and LRs remain valid.

        Bug fix: the original loop read the LR into a local variable but
        asserted nothing, leaving the check dead (the unused `step` index
        hinted at a lost assertion). Each observed LR is now verified to
        be a finite, non-negative number.
        """
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=10,
            warmup_steps=20,
            min_lr_ratio=0.1,
        )

        for _ in range(25):
            lr = scheduler.get_last_lr()[0]
            assert math.isfinite(lr) and lr >= 0.0
            scheduler.step()
365
+
366
+
367
class TestAtroposTrainingConfig:
    """Tests for AtroposTrainingConfig defaults and overrides."""

    def test_default_values(self):
        """Every documented default is present on a freshly built config."""
        config = AtroposTrainingConfig()

        expected = {
            "model_name": "Qwen/Qwen2.5-3B-Instruct",
            "learning_rate": 1e-5,
            "min_learning_rate": 1e-7,
            "training_steps": 100,
            "batch_size": 4,
            "gradient_accumulation_steps": 8,
            "seq_len": 4096,
            "max_grad_norm": 1.0,
            "lr_scheduler": LRSchedulerType.COSINE,
            "warmup_steps": 10,
            "vllm_port": 9001,
            "vllm_restart_interval": 5,
            "vllm_gpu_utilization": 0.45,
            "save_path": "./trained_models",
            "save_every_steps": 5,
            "keep_checkpoints": 3,
            "api_url": "http://localhost:8000",
            "log_file": "./logs/training_metrics.jsonl",
            "wandb_project": "eliza-training",
        }
        for attr, value in expected.items():
            assert getattr(config, attr) == value

        # Identity checks for None/bool defaults (stricter than ==).
        assert config.resume_from is None
        assert config.log_to_file is True
        assert config.use_wandb is True
        assert config.wandb_entity is None
        assert config.wandb_run_name is None

    def test_custom_values(self):
        """Constructor keyword arguments take precedence over defaults."""
        config = AtroposTrainingConfig(
            model_name="custom/model",
            learning_rate=5e-5,
            training_steps=50,
            lr_scheduler=LRSchedulerType.LINEAR,
            use_wandb=False,
        )

        assert config.model_name == "custom/model"
        assert config.learning_rate == 5e-5
        assert config.training_steps == 50
        assert config.lr_scheduler == LRSchedulerType.LINEAR
        assert config.use_wandb is False

    def test_min_lr_ratio_calculation(self):
        """min LR divided by base LR yields the expected decay floor ratio."""
        config = AtroposTrainingConfig(
            learning_rate=1e-4,
            min_learning_rate=1e-6,
        )
        ratio = config.min_learning_rate / config.learning_rate
        assert abs(ratio - 0.01) < 1e-10

    def test_device_auto_detection(self):
        """Device defaults to CUDA when available, otherwise CPU."""
        config = AtroposTrainingConfig()
        expected_device = "cuda" if torch.cuda.is_available() else "cpu"
        assert config.device == expected_device

    def test_device_override(self):
        """An explicit device argument wins over auto-detection."""
        config = AtroposTrainingConfig(device="cpu")
        assert config.device == "cpu"
438
+
439
+
440
class TestAtroposTrainer:
    """Tests for AtroposTrainer construction and path helpers."""

    def test_initialization(self):
        """A fresh trainer starts with empty state and a populated run_id."""
        config = AtroposTrainingConfig()
        trainer = AtroposTrainer(config)

        assert trainer.config == config
        # Lazily-built components are all unset before training starts.
        for attr in ("model", "tokenizer", "optimizer", "scheduler", "vllm_process"):
            assert getattr(trainer, attr) is None
        assert trainer.current_step == 0
        assert trainer._wandb_initialized is False
        assert trainer._checkpoint_history == []
        assert len(trainer.run_id) > 0

    def test_extract_step_from_path_valid(self):
        """Step numbers embedded as `step_N` are parsed out of paths."""
        trainer = AtroposTrainer(AtroposTrainingConfig())

        cases = [
            ("./models/step_50", 50),
            ("/path/to/step_100", 100),
            ("step_0", 0),
            ("step_999", 999),
        ]
        for path, expected in cases:
            assert trainer._extract_step_from_path(path) == expected

    def test_extract_step_from_path_invalid(self):
        """Paths without a `step_N` component fall back to step 0."""
        trainer = AtroposTrainer(AtroposTrainingConfig())

        for path in ("./models/final_model", "./models/checkpoint", "./step_abc", ""):
            assert trainer._extract_step_from_path(path) == 0

    def test_extract_step_from_path_edge_cases(self):
        """Bare `step_`, negative numbers, and leading zeros are handled."""
        trainer = AtroposTrainer(AtroposTrainingConfig())

        # Prefix present but no digits follow.
        assert trainer._extract_step_from_path("step_") == 0
        # A minus sign must not be treated as part of the number.
        assert trainer._extract_step_from_path("step_-5") == 0
        # Leading zeros parse as a plain integer.
        assert trainer._extract_step_from_path("step_007") == 7

    def test_run_id_format(self):
        """run_id follows the YYYYMMDD-HHMMSS timestamp layout."""
        trainer = AtroposTrainer(AtroposTrainingConfig())
        run_id = trainer.run_id

        assert len(run_id) == 15
        assert run_id[8] == '-'
        assert run_id[:8].isdigit()
        assert run_id[9:].isdigit()
504
+
505
+
506
class TestBoundaryConditions:
    """Tests for degenerate scheduler configurations."""

    @pytest.fixture
    def optimizer(self):
        """AdamW over a throwaway linear layer, base LR 1e-4."""
        return AdamW(torch.nn.Linear(10, 10).parameters(), lr=1e-4)

    def test_single_training_step(self, optimizer):
        """A one-step schedule can be read and advanced without error."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=1,
            warmup_steps=0,
            min_lr_ratio=0.1,
        )

        lr = scheduler.get_last_lr()[0]
        scheduler.step()
        assert lr >= 0

    def test_min_lr_ratio_zero(self, optimizer):
        """min_lr_ratio=0 lets the LR decay all the way to zero."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=0.0,
        )

        for _ in range(100):
            scheduler.step()

        # By the final step the LR has effectively hit zero.
        assert scheduler.get_last_lr()[0] < 1e-12

    def test_min_lr_ratio_one(self, optimizer):
        """min_lr_ratio=1 disables decay entirely."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.LINEAR,
            num_training_steps=100,
            warmup_steps=0,
            min_lr_ratio=1.0,
        )

        observed = []
        for _ in range(100):
            observed.append(scheduler.get_last_lr()[0])
            scheduler.step()

        # The LR never moves off the base value.
        assert all(abs(lr - 1e-4) < 1e-10 for lr in observed)

    def test_very_large_step_count(self, optimizer):
        """Huge step budgets never yield NaN or infinite LRs."""
        scheduler = get_lr_scheduler(
            optimizer=optimizer,
            scheduler_type=LRSchedulerType.COSINE,
            num_training_steps=1000000,
            warmup_steps=1000,
            min_lr_ratio=0.01,
        )

        # Sample the first thousand steps for numerical sanity only.
        for _ in range(1000):
            lr = scheduler.get_last_lr()[0]
            assert not math.isnan(lr)
            assert not math.isinf(lr)
            scheduler.step()