@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,501 @@
1
+ """
2
+ Tests for Tokenization Utilities
3
+
4
+ Tests cover:
5
+ - Proper prompt/completion masking with -100/token_id format
6
+ - Multi-turn conversation masking
7
+ - Mask validation
8
+ - Historical mask fixing
9
+
10
+ MASK FORMAT:
11
+ - mask = -100: Prompt token, ignored in loss calculation
12
+ - mask = token_id: Completion token, trained on
13
+ """
14
+
15
+ import pytest
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ from src.training.tokenization_utils import (
19
+ TokenizationResult,
20
+ tokenize_for_trainer,
21
+ tokenize_conversation_for_trainer,
22
+ validate_masks,
23
+ create_masks_from_response_start,
24
+ fix_historical_masks,
25
+ )
26
+
27
+
28
+ # =============================================================================
29
+ # Mock Tokenizer
30
+ # =============================================================================
31
+
32
+
33
class MockTokenizer:
    """Minimal word-level stand-in tokenizer used by the tests in this module.

    Mimics the small slice of the Hugging Face tokenizer interface the
    tokenization utilities rely on: ``encode``, ``decode`` and
    ``apply_chat_template``.
    """

    def __init__(self):
        # Fixed word -> id table; the special role/end tokens occupy ids 1-4.
        self.vocab = {
            "<|system|>": 1,
            "<|user|>": 2,
            "<|assistant|>": 3,
            "<|end|>": 4,
            "hello": 10,
            "world": 11,
            "how": 12,
            "are": 13,
            "you": 14,
            "i": 15,
            "am": 16,
            "fine": 17,
            "thanks": 18,
            "for": 19,
            "asking": 20,
        }
        # Inverse table for decode().
        self.reverse_vocab = {tok_id: word for word, tok_id in self.vocab.items()}

    def encode(self, text: str, add_special_tokens: bool = True) -> list:
        """Simple word-level encoding; unknown words map to ``100 + len(word)``."""
        # Pad the special-token delimiters with spaces so split() isolates them.
        normalized = text.lower().replace("<|", " <|").replace("|>", "|> ")
        ids = []
        for piece in normalized.split():
            piece = piece.strip()
            ids.append(self.vocab.get(piece, 100 + len(piece)))
        return ids

    def decode(self, tokens: list) -> str:
        """Inverse of encode(); ids outside the vocab render as ``[id]``."""
        return " ".join(self.reverse_vocab.get(t, f"[{t}]") for t in tokens)

    def apply_chat_template(
        self,
        messages: list,
        return_tensors=None,
        add_generation_prompt: bool = False,
    ) -> list:
        """Flatten messages into [role_token, *content_tokens, end_token] runs."""
        role_ids = {"system": 1, "user": 2, "assistant": 3}
        out = []

        for message in messages:
            role = message.get("role", "user")
            # Unknown roles get no role marker, matching the lenient original.
            if role in role_ids:
                out.append(role_ids[role])

            out.extend(self.encode(message.get("content", ""), add_special_tokens=False))
            out.append(4)  # <|end|>

        if add_generation_prompt:
            out.append(3)  # open a fresh assistant turn

        return out
109
+
110
+
111
+ # =============================================================================
112
+ # TokenizationResult Tests
113
+ # =============================================================================
114
+
115
+
116
class TestTokenizationResult:
    """Tests for the TokenizationResult dataclass."""

    def test_creation(self):
        # Mask convention: -100 marks prompt positions (ignored by the loss);
        # completion positions carry their own token id.
        result = TokenizationResult(
            tokens=[1, 2, 3, 4, 5],
            masks=[-100, -100, 3, 4, 5],
            prompt_length=2,
            completion_length=3,
            total_length=5,
        )

        assert len(result.tokens) == 5
        assert result.prompt_length == 2
        assert result.completion_length == 3
134
+
135
+
136
+ # =============================================================================
137
+ # tokenize_for_trainer Tests
138
+ # =============================================================================
139
+
140
+
141
class TestTokenizeForTrainer:
    """Behavioural tests for tokenize_for_trainer."""

    def test_empty_messages(self):
        result = tokenize_for_trainer(MockTokenizer(), [])

        assert result.tokens == []
        assert result.masks == []
        assert result.total_length == 0

    def test_prompt_only(self):
        conversation = [{"role": "user", "content": "hello world"}]

        result = tokenize_for_trainer(
            MockTokenizer(), conversation, add_generation_prompt=True
        )

        # No assistant reply: every position is a prompt token (-100).
        assert all(mask == -100 for mask in result.masks)
        assert result.completion_length == 0

    def test_with_assistant_response(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        result = tokenize_for_trainer(MockTokenizer(), conversation)

        # Prompt positions carry -100; completion positions carry token ids.
        assert any(mask == -100 for mask in result.masks)
        assert any(mask != -100 for mask in result.masks)
        assert len(result.masks) == len(result.tokens)

        # Every unmasked position must echo its own token id.
        for pos, (tok, mask) in enumerate(zip(result.tokens, result.masks)):
            if mask != -100:
                assert mask == tok, f"Mask at pos {pos} should equal token for completion"

    def test_with_system_prompt(self):
        conversation = [
            {"role": "system", "content": "you are helpful"},
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        result = tokenize_for_trainer(MockTokenizer(), conversation)

        # System + user form the prompt; the assistant reply is the completion.
        assert result.prompt_length > 0
        assert result.completion_length > 0
        assert result.prompt_length + result.completion_length == result.total_length

    def test_multiple_turns(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "how are you"},
            {"role": "assistant", "content": "fine thanks"},
        ]

        result = tokenize_for_trainer(MockTokenizer(), conversation)

        # Only the final assistant turn should be trained on.
        assert result.completion_length > 0
        assert len(result.tokens) > 0
214
+
215
+
216
+ # =============================================================================
217
+ # tokenize_conversation_for_trainer Tests
218
+ # =============================================================================
219
+
220
+
221
class TestTokenizeConversationForTrainer:
    """Tests for tokenize_conversation_for_trainer."""

    def test_empty_messages(self):
        result = tokenize_conversation_for_trainer(MockTokenizer(), [])

        assert result.tokens == []
        assert result.masks == []

    def test_single_turn(self):
        result = tokenize_conversation_for_trainer(
            MockTokenizer(),
            [
                {"role": "user", "content": "hello"},
                {"role": "assistant", "content": "hi"},
            ],
        )

        # User turn is prompt (-100); assistant turn is completion (token ids).
        assert result.prompt_length > 0
        assert result.completion_length > 0

    def test_multi_turn_all_assistants_unmasked(self):
        result = tokenize_conversation_for_trainer(
            MockTokenizer(),
            [
                {"role": "user", "content": "hello"},
                {"role": "assistant", "content": "hi"},
                {"role": "user", "content": "how"},
                {"role": "assistant", "content": "fine"},
            ],
        )

        # Unlike tokenize_for_trainer, every assistant turn contributes
        # completion tokens here.
        assert result.completion_length > 0
258
+
259
+
260
+ # =============================================================================
261
+ # validate_masks Tests
262
+ # =============================================================================
263
+
264
+
265
class TestValidateMasks:
    """Tests for validate_masks under the -100/token_id mask convention."""

    def test_valid_masks(self):
        token_ids = [1, 2, 3, 4, 5]
        mask_row = [-100, -100, 3, 4, 5]  # two prompt positions, three completion

        ok, problems = validate_masks(token_ids, mask_row, MockTokenizer())

        assert ok is True
        assert problems == []

    def test_length_mismatch(self):
        # Mask list shorter than the token list.
        ok, problems = validate_masks(
            [1, 2, 3, 4, 5], [-100, -100, 3], MockTokenizer()
        )

        assert ok is False
        assert any("Length mismatch" in p for p in problems)

    def test_legacy_format_detected(self):
        # Old-style 0/1 masks must be rejected loudly.
        ok, problems = validate_masks(
            [1, 2, 3, 4, 5], [0, 0, 1, 1, 1], MockTokenizer()
        )

        assert ok is False
        assert any("LEGACY MASK FORMAT" in p for p in problems)

    def test_all_masked(self):
        # Entirely prompt: nothing to train on.
        ok, problems = validate_masks([1, 2, 3, 4, 5], [-100] * 5, MockTokenizer())

        assert ok is False
        assert any("No unmasked tokens" in p for p in problems)

    def test_all_unmasked(self):
        # Every mask equals its token: no prompt region at all.
        ok, problems = validate_masks(
            [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], MockTokenizer()
        )

        assert ok is False
        assert any("No masked tokens" in p for p in problems)

    def test_mask_token_mismatch(self):
        # Position 2 holds token 3, but the mask claims 99.
        ok, problems = validate_masks(
            [1, 2, 3, 4, 5], [-100, -100, 99, 4, 5], MockTokenizer()
        )

        assert ok is False
        assert any("Mask mismatch" in p for p in problems)
330
+
331
+
332
+ # =============================================================================
333
+ # create_masks_from_response_start Tests
334
+ # =============================================================================
335
+
336
+
337
class TestCreateMasksFromResponseStart:
    """Tests for create_masks_from_response_start (-100/token_id output)."""

    def test_normal_case(self):
        out = create_masks_from_response_start([1, 2, 3, 4, 5], 3)

        # Positions before the response start become -100; the rest echo
        # their token ids.
        assert out == [-100, -100, -100, 4, 5]

    def test_start_at_beginning(self):
        out = create_masks_from_response_start([1, 2, 3, 4, 5], 0)

        # Everything counts as completion.
        assert out == [1, 2, 3, 4, 5]

    def test_start_at_end(self):
        out = create_masks_from_response_start([1, 2, 3, 4, 5], 5)

        # Everything counts as prompt.
        assert out == [-100, -100, -100, -100, -100]

    def test_negative_start_clamps(self):
        out = create_masks_from_response_start([1, 2, 3, 4, 5], -10)

        # Negative offsets clamp up to 0, so all completion.
        assert out == [1, 2, 3, 4, 5]

    def test_beyond_end_clamps(self):
        out = create_masks_from_response_start([1, 2, 3, 4, 5], 100)

        # Oversized offsets clamp down to len(tokens), so all prompt.
        assert out == [-100, -100, -100, -100, -100]
384
+
385
+
386
+ # =============================================================================
387
+ # fix_historical_masks Tests
388
+ # =============================================================================
389
+
390
+
391
class TestFixHistoricalMasks:
    """Tests for fix_historical_masks."""

    def test_all_ones_detected_and_fixed(self):
        conversation = [
            {"role": "system", "content": "hello"},
            {"role": "user", "content": "world"},
            {"role": "assistant", "content": "how"},
        ]
        # system hello | user world | assistant how
        token_ids = [1, 10, 4, 2, 11, 4, 3, 12, 4]

        repaired = fix_historical_masks(
            token_ids, [1] * len(token_ids), MockTokenizer(), conversation
        )

        # The prompt region must now be -100-masked.
        assert any(m == -100 for m in repaired)

    def test_legacy_zeros_ones_fixed(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        repaired = fix_historical_masks(
            [2, 10, 4, 3, 11, 4], [0, 0, 0, 1, 1, 1], MockTokenizer(), conversation
        )

        # Converted into the -100/token_id convention.
        assert any(m == -100 for m in repaired)
        assert any(m != -100 and m > 0 for m in repaired)

    def test_already_valid_unchanged(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]
        good_masks = [-100, -100, -100, 3, 11, 4]  # already correct format

        repaired = fix_historical_masks(
            [2, 10, 4, 3, 11, 4], good_masks, MockTokenizer(), conversation
        )

        # A correct mask row passes through untouched.
        assert repaired == good_masks
437
+
438
+
439
+ # =============================================================================
440
+ # Integration Tests
441
+ # =============================================================================
442
+
443
+
444
class TestIntegration:
    """End-to-end flows combining the tokenization utilities."""

    def test_tokenize_validate_flow(self):
        tokenizer = MockTokenizer()
        conversation = [
            {"role": "system", "content": "you are helpful"},
            {"role": "user", "content": "hello world"},
            {"role": "assistant", "content": "hi there"},
        ]

        # Tokenize, then immediately validate the produced masks.
        result = tokenize_for_trainer(tokenizer, conversation)
        ok, problems = validate_masks(result.tokens, result.masks, tokenizer)

        assert len(result.tokens) > 0
        assert len(result.masks) == len(result.tokens)
        assert ok is True, f"Validation failed: {problems}"

    def test_fix_and_validate_flow(self):
        tokenizer = MockTokenizer()
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        # Start from legacy all-1s masks and repair them.
        token_ids = tokenizer.apply_chat_template(conversation)
        repaired = fix_historical_masks(
            token_ids, [1] * len(token_ids), tokenizer, conversation
        )

        # Prompt tokens must now be -100-masked.
        assert any(m == -100 for m in repaired)

        # The repaired masks must pass validation.
        ok, problems = validate_masks(token_ids, repaired, tokenizer)
        assert ok is True, f"Fixed masks should be valid: {problems}"

    def test_completion_tokens_match_in_masks(self):
        """Verify that completion masks contain actual token IDs"""
        tokenizer = MockTokenizer()
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        result = tokenize_for_trainer(tokenizer, conversation)

        # Every unmasked position must carry its own token id.
        for pos, (tok, mask) in enumerate(zip(result.tokens, result.masks)):
            if mask != -100:
                assert mask == tok, \
                    f"Position {pos}: mask {mask} should equal token {tok}"