@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for Tokenization Utilities
|
|
3
|
+
|
|
4
|
+
Tests cover:
|
|
5
|
+
- Proper prompt/completion masking with -100/token_id format
|
|
6
|
+
- Multi-turn conversation masking
|
|
7
|
+
- Mask validation
|
|
8
|
+
- Historical mask fixing
|
|
9
|
+
|
|
10
|
+
MASK FORMAT:
|
|
11
|
+
- mask = -100: Prompt token, ignored in loss calculation
|
|
12
|
+
- mask = token_id: Completion token, trained on
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import pytest
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
from src.training.tokenization_utils import (
|
|
19
|
+
TokenizationResult,
|
|
20
|
+
tokenize_for_trainer,
|
|
21
|
+
tokenize_conversation_for_trainer,
|
|
22
|
+
validate_masks,
|
|
23
|
+
create_masks_from_response_start,
|
|
24
|
+
fix_historical_masks,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# =============================================================================
|
|
29
|
+
# Mock Tokenizer
|
|
30
|
+
# =============================================================================
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MockTokenizer:
    """Minimal word-level tokenizer stand-in used by the tests in this module.

    It exposes the three methods the tests rely on: ``encode``, ``decode``
    and ``apply_chat_template``. The vocabulary is a small fixed table; any
    word outside it is encoded as ``100 + len(word)``.
    """

    # Token emitted at the start of a message for each recognised role.
    # These values match the "<|system|>", "<|user|>" and "<|assistant|>"
    # entries of the vocabulary built in __init__.
    _ROLE_TOKENS = {"system": 1, "user": 2, "assistant": 3}

    # Token appended after every message; matches "<|end|>" in the vocabulary.
    _END_TOKEN = 4

    def __init__(self):
        self.vocab = {
            "<|system|>": 1,
            "<|user|>": 2,
            "<|assistant|>": 3,
            "<|end|>": 4,
            "hello": 10,
            "world": 11,
            "how": 12,
            "are": 13,
            "you": 14,
            "i": 15,
            "am": 16,
            "fine": 17,
            "thanks": 18,
            "for": 19,
            "asking": 20,
        }
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str, add_special_tokens: bool = True) -> list:
        """Simple word-level encoding.

        The text is lower-cased and split on whitespace. Spaces are inserted
        around ``<|...|>`` markers first so they become separate words even
        when glued to neighbouring text.

        Args:
            text: Text to encode.
            add_special_tokens: Accepted for signature compatibility; this
                mock never adds special tokens, so the value is ignored.

        Returns:
            One token ID per word: the vocabulary ID for known words,
            ``100 + len(word)`` for unknown ones.
        """
        words = text.lower().replace("<|", " <|").replace("|>", "|> ").split()
        tokens = []
        for word in words:
            if word in self.vocab:
                tokens.append(self.vocab[word])
            else:
                tokens.append(100 + len(word))  # Unknown token
        return tokens

    def decode(self, tokens: list) -> str:
        """Simple decoding.

        Args:
            tokens: Token IDs to decode.

        Returns:
            The space-joined words; IDs missing from the vocabulary are
            rendered as ``[<id>]``.
        """
        words = []
        for t in tokens:
            if t in self.reverse_vocab:
                words.append(self.reverse_vocab[t])
            else:
                words.append(f"[{t}]")
        return " ".join(words)

    def apply_chat_template(
        self,
        messages: list,
        return_tensors=None,
        add_generation_prompt: bool = False,
    ) -> list:
        """Mock chat template application.

        Each message becomes ``<role token> <content tokens> <end token>``.
        A message without a ``"role"`` key is treated as a user message; a
        role other than system/user/assistant contributes no role token.

        Args:
            messages: Dicts with optional ``"role"`` and ``"content"`` keys.
            return_tensors: Accepted for signature compatibility; ignored.
            add_generation_prompt: When true, a trailing assistant start
                token is appended after the last message.

        Returns:
            The flat list of token IDs for the whole conversation.
        """
        tokens = []

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            # Add role token (unknown roles add nothing)
            role_token = self._ROLE_TOKENS.get(role)
            if role_token is not None:
                tokens.append(role_token)

            # Add content tokens
            tokens.extend(self.encode(content, add_special_tokens=False))

            # Add end token
            tokens.append(self._END_TOKEN)

        if add_generation_prompt:
            tokens.append(self._ROLE_TOKENS["assistant"])  # Assistant start token

        return tokens
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# =============================================================================
|
|
112
|
+
# TokenizationResult Tests
|
|
113
|
+
# =============================================================================
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TestTokenizationResult:
    """Tests for TokenizationResult dataclass"""

    def test_creation(self):
        """Build a result by keyword and read its length fields back."""
        # New format: -100 for prompt, actual token IDs for completion
        tokens = [1, 2, 3, 4, 5]
        masks = [-100, -100, 3, 4, 5]  # First 2 prompt, last 3 completion
        result = TokenizationResult(
            tokens=tokens,
            masks=masks,
            prompt_length=2,
            completion_length=3,
            total_length=5,
        )

        assert len(result.tokens) == 5
        assert result.prompt_length == 2
        assert result.completion_length == 3
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# =============================================================================
|
|
137
|
+
# tokenize_for_trainer Tests
|
|
138
|
+
# =============================================================================
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class TestTokenizeForTrainer:
    """Tests for tokenize_for_trainer"""

    def test_empty_messages(self):
        """An empty message list yields empty tokens/masks and zero length."""
        tokenizer = MockTokenizer()

        result = tokenize_for_trainer(tokenizer, [])

        assert result.tokens == []
        assert result.masks == []
        assert result.total_length == 0

    def test_prompt_only(self):
        """With no assistant message, every mask entry is -100."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello world"},
        ]

        result = tokenize_for_trainer(tokenizer, messages, add_generation_prompt=True)

        # All should be masked (no assistant response) - all -100
        assert all(m == -100 for m in result.masks)
        assert result.completion_length == 0

    def test_with_assistant_response(self):
        """A user/assistant pair produces both masked and unmasked positions."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        result = tokenize_for_trainer(tokenizer, messages)

        # Should have both prompt and completion masks
        # Prompt: -100, Completion: actual token IDs
        assert any(m == -100 for m in result.masks)  # Prompt masked with -100
        assert any(m != -100 for m in result.masks)  # Completion has token IDs
        assert len(result.masks) == len(result.tokens)

        # Verify completion tokens match actual tokens
        for i, (token, mask) in enumerate(zip(result.tokens, result.masks)):
            if mask != -100:
                assert mask == token, f"Mask at pos {i} should equal token for completion"

    def test_with_system_prompt(self):
        """Prompt and completion lengths are non-zero and sum to the total."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "system", "content": "you are helpful"},
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        result = tokenize_for_trainer(tokenizer, messages)

        # System and user should be masked (-100), assistant unmasked (token IDs)
        assert result.prompt_length > 0
        assert result.completion_length > 0
        assert result.prompt_length + result.completion_length == result.total_length

    def test_multiple_turns(self):
        """A four-message conversation produces a non-empty completion.

        NOTE: the intent is that only the last assistant turn is unmasked,
        but this test only checks that some completion and some tokens
        exist; it does not verify that the earlier assistant turn is masked.
        """
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "how are you"},
            {"role": "assistant", "content": "fine thanks"},
        ]

        result = tokenize_for_trainer(tokenizer, messages)

        # Only last assistant should be unmasked
        assert result.completion_length > 0
        assert len(result.tokens) > 0
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# =============================================================================
|
|
217
|
+
# tokenize_conversation_for_trainer Tests
|
|
218
|
+
# =============================================================================
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class TestTokenizeConversationForTrainer:
    """Tests for tokenize_conversation_for_trainer"""

    def test_empty_messages(self):
        """An empty message list yields empty tokens and masks."""
        tokenizer = MockTokenizer()

        result = tokenize_conversation_for_trainer(tokenizer, [])

        assert result.tokens == []
        assert result.masks == []

    def test_single_turn(self):
        """One user/assistant exchange has non-zero prompt and completion."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        result = tokenize_conversation_for_trainer(tokenizer, messages)

        # User masked (-100), assistant unmasked (token IDs)
        assert result.prompt_length > 0
        assert result.completion_length > 0

    def test_multi_turn_all_assistants_unmasked(self):
        """Two exchanges produce a non-empty completion.

        NOTE: the intent is that both assistant turns are unmasked, but the
        only assertion is ``completion_length > 0``; per-turn masking is
        not checked here.
        """
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
            {"role": "user", "content": "how"},
            {"role": "assistant", "content": "fine"},
        ]

        result = tokenize_conversation_for_trainer(tokenizer, messages)

        # Should have unmasked tokens for both assistant turns
        assert result.completion_length > 0
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# =============================================================================
|
|
261
|
+
# validate_masks Tests
|
|
262
|
+
# =============================================================================
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class TestValidateMasks:
    """Tests for validate_masks with new -100/token_id format"""

    def test_valid_masks(self):
        """A -100 prefix followed by matching token IDs validates cleanly."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        # New format: -100 for prompt, actual token IDs for completion
        masks = [-100, -100, 3, 4, 5]  # Prompt then completion

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is True
        assert issues == []

    def test_length_mismatch(self):
        """Masks shorter than tokens are reported as a length mismatch."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        masks = [-100, -100, 3]  # Too short

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is False
        assert any("Length mismatch" in issue for issue in issues)

    def test_legacy_format_detected(self):
        """A 0/1 mask list is flagged as the legacy format."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        masks = [0, 0, 1, 1, 1]  # Legacy 0/1 format - should be flagged

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is False
        assert any("LEGACY MASK FORMAT" in issue for issue in issues)

    def test_all_masked(self):
        """An all -100 mask list is rejected for having no unmasked tokens."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        masks = [-100, -100, -100, -100, -100]  # All masked (no completion)

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is False
        assert any("No unmasked tokens" in issue for issue in issues)

    def test_all_unmasked(self):
        """Masks identical to the tokens are rejected for having no prompt."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        # All tokens match their positions (all unmasked, no prompt)
        masks = [1, 2, 3, 4, 5]

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is False
        assert any("No masked tokens" in issue for issue in issues)

    def test_mask_token_mismatch(self):
        """An unmasked entry that differs from its token is reported."""
        tokenizer = MockTokenizer()
        tokens = [1, 2, 3, 4, 5]
        # Token at position 2 is 3, but mask says 99
        masks = [-100, -100, 99, 4, 5]

        is_valid, issues = validate_masks(tokens, masks, tokenizer)

        assert is_valid is False
        assert any("Mask mismatch" in issue for issue in issues)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
# =============================================================================
|
|
333
|
+
# create_masks_from_response_start Tests
|
|
334
|
+
# =============================================================================
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class TestCreateMasksFromResponseStart:
    """Tests for create_masks_from_response_start with new format"""

    def test_normal_case(self):
        """Positions before the start index are -100; the rest keep their IDs."""
        tokens = [1, 2, 3, 4, 5]
        response_start = 3

        masks = create_masks_from_response_start(tokens, response_start)

        # -100 for prompt, actual token IDs for completion
        assert masks == [-100, -100, -100, 4, 5]

    def test_start_at_beginning(self):
        """A start index of 0 leaves every position unmasked."""
        tokens = [1, 2, 3, 4, 5]
        response_start = 0

        masks = create_masks_from_response_start(tokens, response_start)

        # All completion (all token IDs)
        assert masks == [1, 2, 3, 4, 5]

    def test_start_at_end(self):
        """A start index equal to the length masks every position."""
        tokens = [1, 2, 3, 4, 5]
        response_start = 5

        masks = create_masks_from_response_start(tokens, response_start)

        # All prompt (all -100)
        assert masks == [-100, -100, -100, -100, -100]

    def test_negative_start_clamps(self):
        """A negative start index behaves like 0."""
        tokens = [1, 2, 3, 4, 5]
        response_start = -10

        masks = create_masks_from_response_start(tokens, response_start)

        # Clamps to 0, so all completion
        assert masks == [1, 2, 3, 4, 5]

    def test_beyond_end_clamps(self):
        """A start index past the end behaves like the length."""
        tokens = [1, 2, 3, 4, 5]
        response_start = 100

        masks = create_masks_from_response_start(tokens, response_start)

        # Clamps to end, so all prompt
        assert masks == [-100, -100, -100, -100, -100]
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
# =============================================================================
|
|
387
|
+
# fix_historical_masks Tests
|
|
388
|
+
# =============================================================================
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class TestFixHistoricalMasks:
    """Tests for fix_historical_masks"""

    def test_all_ones_detected_and_fixed(self):
        """An all-1s mask list comes back containing at least one -100."""
        tokenizer = MockTokenizer()
        tokens = [1, 10, 4, 2, 11, 4, 3, 12, 4]  # system hello, user world, assistant how
        masks = [1, 1, 1, 1, 1, 1, 1, 1, 1]  # All 1s (legacy incorrect format)
        messages = [
            {"role": "system", "content": "hello"},
            {"role": "user", "content": "world"},
            {"role": "assistant", "content": "how"},
        ]

        fixed = fix_historical_masks(tokens, masks, tokenizer, messages)

        # Should have -100 for prompt now
        assert any(m == -100 for m in fixed)

    def test_legacy_zeros_ones_fixed(self):
        """A 0/1 mask list comes back with both -100 and positive entries."""
        tokenizer = MockTokenizer()
        tokens = [2, 10, 4, 3, 11, 4]  # user hello, assistant world
        masks = [0, 0, 0, 1, 1, 1]  # Legacy 0/1 format
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        fixed = fix_historical_masks(tokens, masks, tokenizer, messages)

        # Should be converted to -100/token_id format
        assert any(m == -100 for m in fixed)
        assert any(m != -100 and m > 0 for m in fixed)

    def test_already_valid_unchanged(self):
        """Masks already in -100/token_id format are returned unchanged."""
        tokenizer = MockTokenizer()
        tokens = [2, 10, 4, 3, 11, 4]  # user hello, assistant world
        # Already correct format: -100 for prompt, token IDs for completion
        masks = [-100, -100, -100, 3, 11, 4]
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        fixed = fix_historical_masks(tokens, masks, tokenizer, messages)

        assert fixed == masks
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
# =============================================================================
|
|
440
|
+
# Integration Tests
|
|
441
|
+
# =============================================================================
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class TestIntegration:
    """Integration tests combining multiple utilities"""

    def test_tokenize_validate_flow(self):
        """Output of tokenize_for_trainer passes validate_masks."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "system", "content": "you are helpful"},
            {"role": "user", "content": "hello world"},
            {"role": "assistant", "content": "hi there"},
        ]

        # Tokenize
        result = tokenize_for_trainer(tokenizer, messages)

        # Validate
        is_valid, issues = validate_masks(result.tokens, result.masks, tokenizer)

        # Should be valid
        assert len(result.tokens) > 0
        assert len(result.masks) == len(result.tokens)
        assert is_valid is True, f"Validation failed: {issues}"

    def test_fix_and_validate_flow(self):
        """All-1s masks repaired by fix_historical_masks pass validate_masks."""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        # Simulate broken historical masks (legacy all-1s format)
        tokens = tokenizer.apply_chat_template(messages)
        broken_masks = [1] * len(tokens)

        # Fix
        fixed_masks = fix_historical_masks(tokens, broken_masks, tokenizer, messages)

        # Should have -100 for prompt tokens now
        assert any(m == -100 for m in fixed_masks)

        # Validate the fixed masks
        is_valid, issues = validate_masks(tokens, fixed_masks, tokenizer)
        assert is_valid is True, f"Fixed masks should be valid: {issues}"

    def test_completion_tokens_match_in_masks(self):
        """Verify that completion masks contain actual token IDs"""
        tokenizer = MockTokenizer()
        messages = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ]

        result = tokenize_for_trainer(tokenizer, messages)

        # For completion tokens, mask should equal token
        for i, (token, mask) in enumerate(zip(result.tokens, result.masks)):
            if mask != -100:
                assert mask == token, \
                    f"Position {i}: mask {mask} should equal token {token}"
|