@elizaos/training 2.0.0-alpha.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Dockerfile +75 -0
- package/Makefile +374 -0
- package/README.md +346 -0
- package/config/rubrics.json +137 -0
- package/data/.gitkeep +0 -0
- package/data/degen/.gitkeep +2 -0
- package/data/trader/.gitkeep +2 -0
- package/docker-compose.test.yml +57 -0
- package/package.json +58 -0
- package/python/config/babylon_atropos.yaml +90 -0
- package/python/config/profiles/12gb.json +11 -0
- package/python/config/profiles/16gb.json +10 -0
- package/python/config/profiles/24gb.json +10 -0
- package/python/config/profiles/48gb.json +10 -0
- package/python/config/profiles/cpu.json +11 -0
- package/python/config/profiles/l40-2gpu-safe.json +20 -0
- package/python/config/profiles/l40-2gpu.json +22 -0
- package/python/config/profiles/l40-4gpu.json +21 -0
- package/python/config/profiles/l40.json +17 -0
- package/python/config/tinker_training.yaml +143 -0
- package/python/curriculum_state.json +165 -0
- package/python/env.template +86 -0
- package/python/env.training.template +46 -0
- package/python/pyproject.toml +41 -0
- package/python/requirements-ci.txt +31 -0
- package/python/requirements.txt +87 -0
- package/python/scripts/__init__.py +4 -0
- package/python/scripts/import_json_trajectories.py +412 -0
- package/python/scripts/local-finetune/README.md +63 -0
- package/python/scripts/local-finetune/ingest_and_score.py +139 -0
- package/python/scripts/local-finetune/merge_model.py +32 -0
- package/python/scripts/local-finetune/test_adapter.py +91 -0
- package/python/scripts/local-finetune/train_from_csv.py +132 -0
- package/python/scripts/merge_trajectories.py +318 -0
- package/python/scripts/run_ab_test.py +143 -0
- package/python/scripts/run_full_pipeline.py +544 -0
- package/python/scripts/run_tinker_training.py +192 -0
- package/python/scripts/run_training.py +914 -0
- package/python/scripts/test_judge.py +155 -0
- package/python/scripts/test_pipeline.py +356 -0
- package/python/scripts/test_trained_model.py +380 -0
- package/python/scripts/train_local.py +528 -0
- package/python/setup.py +20 -0
- package/python/src/__init__.py +190 -0
- package/python/src/data_bridge/__init__.py +24 -0
- package/python/src/data_bridge/converter.py +435 -0
- package/python/src/data_bridge/reader.py +393 -0
- package/python/src/models.py +283 -0
- package/python/src/training/__init__.py +605 -0
- package/python/src/training/ab_testing.py +404 -0
- package/python/src/training/action_executor.py +621 -0
- package/python/src/training/archetype_trainer.py +347 -0
- package/python/src/training/atropos_trainer.py +980 -0
- package/python/src/training/babylon_env.py +1254 -0
- package/python/src/training/error_recovery.py +647 -0
- package/python/src/training/evaluation.py +856 -0
- package/python/src/training/fast_simulator.py +880 -0
- package/python/src/training/format_validator.py +584 -0
- package/python/src/training/hybrid_env.py +522 -0
- package/python/src/training/kl_controller.py +628 -0
- package/python/src/training/multi_prompt_dataset.py +883 -0
- package/python/src/training/multi_turn.py +656 -0
- package/python/src/training/online_env.py +1084 -0
- package/python/src/training/quality_scorer.py +391 -0
- package/python/src/training/quality_utils.py +633 -0
- package/python/src/training/rewards.py +1344 -0
- package/python/src/training/rlaif_env.py +17 -0
- package/python/src/training/rollout_generator.py +502 -0
- package/python/src/training/rubric_loader.py +198 -0
- package/python/src/training/scenario_pool.py +1072 -0
- package/python/src/training/schemas.py +481 -0
- package/python/src/training/service_manager.py +552 -0
- package/python/src/training/simulation_bridge.py +535 -0
- package/python/src/training/tick_reward_attribution.py +399 -0
- package/python/src/training/tinker_client.py +575 -0
- package/python/src/training/tinker_trainer.py +646 -0
- package/python/src/training/tokenization_utils.py +402 -0
- package/python/tests/e2e/__init__.py +13 -0
- package/python/tests/e2e/conftest.py +258 -0
- package/python/tests/e2e/test_full_pipeline.py +643 -0
- package/python/tests/e2e/test_online_training_e2e.py +365 -0
- package/python/tests/integration/__init__.py +12 -0
- package/python/tests/integration/conftest.py +383 -0
- package/python/tests/integration/test_db_integration.py +649 -0
- package/python/tests/integration/test_json_mode_integration.py +554 -0
- package/python/tests/test_action_executor.py +594 -0
- package/python/tests/test_archetype_scoring.py +1027 -0
- package/python/tests/test_atropos_integration.py +360 -0
- package/python/tests/test_evaluation.py +727 -0
- package/python/tests/test_format_validator.py +486 -0
- package/python/tests/test_kl_controller.py +432 -0
- package/python/tests/test_lr_scheduler.py +579 -0
- package/python/tests/test_multi_turn.py +590 -0
- package/python/tests/test_online_env.py +519 -0
- package/python/tests/test_quality_scorer.py +474 -0
- package/python/tests/test_scenario_pool.py +735 -0
- package/python/tests/test_service_manager.py +585 -0
- package/python/tests/test_simulation_rollout.py +581 -0
- package/python/tests/test_tokenization_utils.py +501 -0
- package/python/tests/test_training_orchestrator.py +497 -0
- package/python/tests/test_training_output_structure.py +661 -0
- package/research-output/training-runs/training-run-1770772042899.json +26 -0
- package/research-output/training-runs/training-run-1770930079670.json +32 -0
- package/research-output/training-runs/training-run-1770930143700.json +44 -0
- package/research-output/training-runs/training-run-1770930183638.json +38 -0
- package/research-output/training-runs/training-run-1770930442049.json +38 -0
- package/research-output/training-runs/training-run-1770930793243.json +38 -0
- package/scripts/assess-training-data.ts +422 -0
- package/scripts/e2e-training-test.ts +550 -0
- package/scripts/export-rubrics.ts +64 -0
- package/scripts/generate-research-report.ts +1523 -0
- package/scripts/generate_dataset.sh +173 -0
- package/scripts/json-mode-benchmark.ts +399 -0
- package/scripts/real-archetype-benchmark.ts +210 -0
- package/scripts/run-baseline-comparison.ts +116 -0
- package/scripts/run-full-pipeline.ts +272 -0
- package/scripts/runpod_setup.sh +137 -0
- package/scripts/runpod_validate.sh +147 -0
- package/scripts/test-model-in-game.ts +955 -0
- package/scripts/test-scoring.ts +73 -0
- package/scripts/test-trained-model.ts +209 -0
- package/scripts/train-and-test.ts +824 -0
- package/scripts/verify-final.ts +118 -0
- package/src/adapter.ts +516 -0
- package/src/archetypes/ArchetypeConfigService.ts +626 -0
- package/src/archetypes/derive-archetype.ts +249 -0
- package/src/archetypes/index.ts +22 -0
- package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
- package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
- package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
- package/src/benchmark/BenchmarkDataViewer.ts +324 -0
- package/src/benchmark/BenchmarkHistoryService.ts +221 -0
- package/src/benchmark/BenchmarkRunner.ts +685 -0
- package/src/benchmark/BenchmarkValidator.ts +206 -0
- package/src/benchmark/FastEvalRunner.ts +225 -0
- package/src/benchmark/MetricsValidator.ts +165 -0
- package/src/benchmark/MetricsVisualizer.ts +909 -0
- package/src/benchmark/ModelBenchmarkService.ts +611 -0
- package/src/benchmark/ModelRegistry.ts +158 -0
- package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
- package/src/benchmark/SimulationA2AInterface.ts +1169 -0
- package/src/benchmark/SimulationEngine.ts +832 -0
- package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
- package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
- package/src/benchmark/index.ts +89 -0
- package/src/benchmark/parseSimulationMetrics.ts +124 -0
- package/src/benchmark/simulation-types.ts +78 -0
- package/src/dependencies.ts +439 -0
- package/src/generation/TrajectoryGenerator.ts +387 -0
- package/src/generation/index.ts +12 -0
- package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
- package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
- package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
- package/src/huggingface/index.ts +27 -0
- package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
- package/src/index.ts +102 -0
- package/src/init-training.ts +53 -0
- package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
- package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
- package/src/metrics/index.ts +8 -0
- package/src/metrics/types.ts +200 -0
- package/src/rubrics/__tests__/index.test.ts +184 -0
- package/src/rubrics/ass-kisser.ts +85 -0
- package/src/rubrics/degen.ts +80 -0
- package/src/rubrics/goody-twoshoes.ts +84 -0
- package/src/rubrics/index.ts +236 -0
- package/src/rubrics/information-trader.ts +84 -0
- package/src/rubrics/infosec.ts +101 -0
- package/src/rubrics/liar.ts +104 -0
- package/src/rubrics/perps-trader.ts +87 -0
- package/src/rubrics/researcher.ts +81 -0
- package/src/rubrics/scammer.ts +82 -0
- package/src/rubrics/social-butterfly.ts +73 -0
- package/src/rubrics/super-predictor.ts +97 -0
- package/src/rubrics/trader.ts +67 -0
- package/src/scoring/ArchetypeScoringService.ts +486 -0
- package/src/scoring/JudgePromptBuilder.ts +556 -0
- package/src/scoring/LLMJudgeCache.ts +401 -0
- package/src/scoring/index.ts +9 -0
- package/src/training/AutomationPipeline.ts +916 -0
- package/src/training/BenchmarkService.ts +518 -0
- package/src/training/ConfigValidator.ts +220 -0
- package/src/training/MarketOutcomesTracker.ts +187 -0
- package/src/training/ModelDeployer.ts +186 -0
- package/src/training/ModelFetcher.ts +76 -0
- package/src/training/ModelSelectionService.ts +341 -0
- package/src/training/ModelUsageVerifier.ts +160 -0
- package/src/training/MultiModelOrchestrator.ts +580 -0
- package/src/training/RLModelConfig.ts +407 -0
- package/src/training/RewardBackpropagationService.ts +149 -0
- package/src/training/RulerScoringService.ts +666 -0
- package/src/training/TrainingMonitor.ts +166 -0
- package/src/training/TrajectoryRecorder.ts +399 -0
- package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
- package/src/training/index.ts +100 -0
- package/src/training/logRLConfig.ts +34 -0
- package/src/training/pipeline.ts +129 -0
- package/src/training/storage/ModelStorageService.ts +279 -0
- package/src/training/storage/TrainingDataArchiver.ts +197 -0
- package/src/training/storage/index.ts +17 -0
- package/src/training/types.ts +207 -0
- package/src/training/window-utils.ts +138 -0
- package/src/utils/index.ts +101 -0
- package/src/utils/logger.ts +59 -0
- package/src/utils/snowflake.ts +17 -0
- package/src/utils/synthetic-detector.ts +111 -0
- package/tsconfig.json +20 -0
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tinker client for RL training.
|
|
3
|
+
|
|
4
|
+
Replaces local vLLM + PyTorch training with Tinker's cloud API.
|
|
5
|
+
This provides a unified interface for both training and inference.
|
|
6
|
+
|
|
7
|
+
Based on: https://tinker-docs.thinkingmachines.ai/training-sampling
|
|
8
|
+
Integration pattern from: tinker-atropos (Nous Research)
|
|
9
|
+
|
|
10
|
+
Key features:
|
|
11
|
+
- TrainingClient for forward_backward + optim_step
|
|
12
|
+
- SamplingClient for inference during rollouts
|
|
13
|
+
- Weight synchronization between training and sampling
|
|
14
|
+
- Automatic tokenization and format conversion
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from typing import List, Literal, Sequence
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Lazy import tinker to allow graceful degradation
|
|
27
|
+
try:
|
|
28
|
+
import tinker
|
|
29
|
+
from tinker import types as tinker_types
|
|
30
|
+
|
|
31
|
+
TINKER_AVAILABLE = True
|
|
32
|
+
except ImportError:
|
|
33
|
+
TINKER_AVAILABLE = False
|
|
34
|
+
tinker = None # type: ignore
|
|
35
|
+
tinker_types = None # type: ignore
|
|
36
|
+
logger.warning("Tinker not installed. Install with: pip install tinker")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class TinkerConfig:
    """Configuration for the Tinker client.

    Fields are grouped by concern: model selection, optimizer
    hyperparameters (consumed by ``TinkerClient.train_step``), sampling
    defaults (fallbacks for ``TinkerClient.sample``), and checkpoint
    naming used when syncing weights.
    """

    # Model settings
    # base_model must appear in Tinker's supported-model list; setup() only
    # warns (does not fail) if it is missing from get_server_capabilities().
    base_model: str = "Qwen/Qwen3-30B-A3B-Instruct"
    # LoRA adapter rank passed to create_lora_training_client().
    lora_rank: int = 32

    # Training hyperparameters (Adam parameters forwarded to optim_step)
    learning_rate: float = 4e-5
    beta1: float = 0.9
    beta2: float = 0.95
    epsilon: float = 1e-8

    # Sampling settings — used by sample() when the caller omits them.
    default_max_tokens: int = 512
    default_temperature: float = 0.7
    stop_sequences: List[str] = field(
        default_factory=lambda: ["\n\n", "<|endoftext|>", "<|im_end|>"]
    )

    # Weight sync settings — prefix for auto-generated checkpoint names
    # (e.g. "eliza-initial", "eliza-step-N").
    checkpoint_name_prefix: str = "eliza"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class TinkerDatum:
    """Lightweight stand-in for a Tinker ``Datum``.

    Holds tokenized training data (inputs, shifted targets, per-token
    loss weights) without touching ``tinker_types``, so the module keeps
    importing and constructing data even when tinker is not installed.
    The real SDK object is built lazily — and memoized — by
    :meth:`to_tinker`.
    """

    def __init__(
        self,
        input_tokens: List[int],
        target_tokens: List[int],
        weights: List[float],
    ):
        self.input_tokens = input_tokens
        self.target_tokens = target_tokens
        self.weights = weights
        # Memoized SDK Datum; stays None until to_tinker() is first called.
        self._tinker_datum: object = None

    def to_tinker(self) -> object:
        """Return the real Tinker ``Datum``, constructing it on first use.

        Raises:
            RuntimeError: if the tinker SDK is not installed.
        """
        if not TINKER_AVAILABLE:
            raise RuntimeError("Tinker not installed")

        # Fast path: already converted once.
        if self._tinker_datum is not None:
            return self._tinker_datum

        model_input = tinker_types.ModelInput.from_ints(tokens=self.input_tokens)
        self._tinker_datum = tinker_types.Datum(
            model_input=model_input,
            loss_fn_inputs=dict(
                weights=self.weights,
                target_tokens=self.target_tokens,
            ),
        )
        return self._tinker_datum
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class TrainStepResult:
    """Metrics returned by ``TinkerClient.train_step``.

    ``loss`` is computed client-side from the forward/backward outputs;
    the remaining fields are diagnostic summaries of the step's batch.
    """

    # Advantage-weighted NLL: -dot(logprobs, weights) / sum(|weights|);
    # 0.0 when the batch is empty or total weight is ~0.
    loss: float
    # Number of TinkerDatum items trained on in this step.
    num_samples: int
    # Mean token logprob across the whole batch (0.0 when weight sum ~0).
    logprobs_mean: float = 0.0
    # Mean of scores > 0 for this step (0.0 if there were none).
    pos_advantage_mean: float = 0.0
    # Mean of scores <= 0 for this step (0.0 if there were none).
    neg_advantage_mean: float = 0.0
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class SampleResult:
    """Decoded output of ``TinkerClient.sample``."""

    # One decoded string per sampled sequence.
    completions: List[str]
    # Populated only when include_logprobs=True; note these come from the
    # prompt logprobs duplicated per sample (see sample() for details).
    logprobs: List[List[float]] = field(default_factory=list)
    # Per-sequence finish reason; defaults to "stop" when the SDK result
    # does not expose one.
    finish_reasons: List[str] = field(default_factory=list)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class TinkerClient:
    """
    Unified Tinker client for training and inference.

    This replaces local vLLM + PyTorch training with Tinker's cloud API:
    - No local GPU required for training
    - Training happens in Tinker cloud
    - Fast weight sync between training and sampling
    - Automatic format conversion

    Usage:
        client = TinkerClient(config)
        client.setup()

        # Training
        data = [client.prepare_datum(messages, completion) for ...]
        result = client.train_step(data, scores)

        # Inference
        completions = client.sample(messages)

        # Sync weights after training
        client.sync_weights("checkpoint-name")
    """

    def __init__(self, config: TinkerConfig | None = None):
        # Fail fast: every method below depends on the tinker SDK.
        if not TINKER_AVAILABLE:
            raise RuntimeError(
                "Tinker not installed. Install with: pip install tinker"
            )

        self.config = config or TinkerConfig()
        # SDK handles are created lazily (service_client property) or by
        # setup(); typed as object because the concrete tinker classes are
        # only importable when the SDK is present.
        self._service_client: object = None
        self._training_client: object = None
        self._sampling_client: object = None
        self._tokenizer: object = None
        self._initialized = False
        # Incremented once per train_step(); used to auto-name checkpoints.
        self._current_step = 0

    @property
    def service_client(self) -> object:
        """Lazily create and cache the Tinker service client."""
        if self._service_client is None:
            self._service_client = tinker.ServiceClient()
        return self._service_client

    @property
    def training_client(self) -> object:
        """Get training client (must call setup first)."""
        if self._training_client is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._training_client

    @property
    def sampling_client(self) -> object:
        """Get sampling client (must call setup first)."""
        if self._sampling_client is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._sampling_client

    @property
    def tokenizer(self) -> object:
        """Get tokenizer (must call setup first)."""
        if self._tokenizer is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._tokenizer

    def setup(self) -> None:
        """
        Initialize training client, sampling client, and tokenizer.

        Must be called before any training or sampling operations.
        Idempotent: a second call is a no-op.

        Raises:
            ValueError: if the TINKER_API_KEY environment variable is unset.
        """
        if self._initialized:
            logger.info("Client already initialized")
            return

        logger.info(f"Initializing Tinker client with model: {self.config.base_model}")

        # Verify API key is set before making any network calls.
        if not os.environ.get("TINKER_API_KEY"):
            raise ValueError(
                "TINKER_API_KEY environment variable not set. "
                "Get your API key from Thinking Machines."
            )

        # Check model availability. Deliberately warn instead of failing so
        # a model missing from the advertised list can still be attempted.
        capabilities = self.service_client.get_server_capabilities()
        available_models = [m.model_name for m in capabilities.supported_models]

        if self.config.base_model not in available_models:
            logger.warning(
                f"Model {self.config.base_model} not in available models. "
                f"Available: {available_models[:5]}..."
            )

        # Create training client with LoRA
        self._training_client = self.service_client.create_lora_training_client(
            base_model=self.config.base_model,
            lora_rank=self.config.lora_rank,
        )

        # Get tokenizer
        self._tokenizer = self._training_client.get_tokenizer()

        # Create initial sampling client by saving the untrained adapter
        # weights, so sampling works before the first train_step().
        initial_name = f"{self.config.checkpoint_name_prefix}-initial"
        self._sampling_client = self._training_client.save_weights_and_get_sampling_client(
            name=initial_name
        )

        self._initialized = True
        logger.info("Tinker client initialized successfully")

    def prepare_datum(
        self,
        messages: List[dict],
        completion: str,
    ) -> TinkerDatum:
        """
        Convert chat messages + completion to Tinker Datum.

        Loss is masked to the completion: prompt tokens get weight 0.0,
        completion tokens weight 1.0.

        Args:
            messages: List of chat messages (role/content dicts)
            completion: The assistant completion to train on

        Returns:
            TinkerDatum ready for training
        """
        # Render messages to prompt using chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize prompt (no loss on prompt tokens)
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_weights = [0.0] * len(prompt_tokens)

        # Tokenize completion (loss on these tokens)
        completion_tokens = self.tokenizer.encode(completion, add_special_tokens=False)
        completion_weights = [1.0] * len(completion_tokens)

        # Combine
        all_tokens = prompt_tokens + completion_tokens
        all_weights = prompt_weights + completion_weights

        # Shift for next-token prediction: input[i] predicts target[i], so
        # targets (and their weights) are offset by one position.
        input_tokens = all_tokens[:-1]
        target_tokens = all_tokens[1:]
        weights = all_weights[1:]

        return TinkerDatum(
            input_tokens=input_tokens,
            target_tokens=target_tokens,
            weights=weights,
        )

    def prepare_datum_from_tokens(
        self,
        tokens: List[int],
        masks: List[int],
    ) -> TinkerDatum:
        """
        Create Datum from pre-tokenized data (e.g., from Atropos).

        Args:
            tokens: Token IDs
            masks: Mask values (-100 for no loss, token_id for loss)

        Returns:
            TinkerDatum ready for training
        """
        # Convert masks to weights (0 for -100, 1 otherwise) — same
        # ignore-index convention as HF/PyTorch cross-entropy.
        weights = [0.0 if m == -100 else 1.0 for m in masks]

        # Shift for next-token prediction
        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        weights = weights[1:]

        return TinkerDatum(
            input_tokens=input_tokens,
            target_tokens=target_tokens,
            weights=weights,
        )

    def train_step(
        self,
        data: Sequence[TinkerDatum],
        scores: List[float],
        loss_fn: Literal["cross_entropy", "importance_sampling"] = "importance_sampling",
    ) -> TrainStepResult:
        """
        Execute one training step with Tinker.

        Args:
            data: List of TinkerDatum objects
            scores: Advantage scores for each datum (should be centered at 0).
                NOTE(review): zipped with ``data`` — a length mismatch is
                silently truncated; callers must keep them aligned.
            loss_fn: Loss function to use

        Returns:
            TrainStepResult with loss and metrics
        """
        if not data:
            return TrainStepResult(loss=0.0, num_samples=0)

        # Convert to Tinker format and apply advantage weights
        tinker_data = []
        for datum, score in zip(data, scores):
            tinker_datum = datum.to_tinker()

            # Scale weights by advantage for GRPO/IS
            # Positive advantage = learn this behavior
            # Negative advantage = unlearn this behavior
            # NOTE(review): this mutates the memoized Datum inside
            # TinkerDatum. Safe for repeated calls because scaled_weights is
            # recomputed from the wrapper's unscaled datum.weights each time.
            scaled_weights = [w * score for w in datum.weights]
            tinker_datum.loss_fn_inputs["weights"] = scaled_weights

            tinker_data.append(tinker_datum)

        # Forward-backward pass (async submission)
        fwdbwd_future = self.training_client.forward_backward(tinker_data, loss_fn)

        # Optimizer step (async submission)
        optim_future = self.training_client.optim_step(
            tinker_types.AdamParams(
                learning_rate=self.config.learning_rate,
                beta1=self.config.beta1,
                beta2=self.config.beta2,
                epsilon=self.config.epsilon,
            )
        )

        # Wait for results
        fwdbwd_result = fwdbwd_future.result()
        _ = optim_future.result()  # Just wait for completion

        # Compute metrics by flattening per-datum token logprobs/weights.
        all_logprobs = []
        all_weights = []
        for output, datum in zip(fwdbwd_result.loss_fn_outputs, tinker_data):
            logprobs = output["logprobs"].tolist()
            weights = datum.loss_fn_inputs["weights"]
            all_logprobs.extend(logprobs)
            # weights may come back as a list or an array-like; normalize.
            all_weights.extend(weights if isinstance(weights, list) else weights.tolist())

        # Compute weighted loss
        logprobs_arr = np.array(all_logprobs)
        weights_arr = np.array(all_weights)

        # Normalize by |weights| so negative advantages don't cancel the
        # denominator; guard against an all-zero-weight batch.
        weight_sum = np.sum(np.abs(weights_arr))
        if weight_sum > 1e-8:
            loss = float(-np.dot(logprobs_arr, weights_arr) / weight_sum)
            logprobs_mean = float(np.mean(logprobs_arr))
        else:
            loss = 0.0
            logprobs_mean = 0.0

        # Compute advantage statistics
        scores_arr = np.array(scores)
        pos_mask = scores_arr > 0
        neg_mask = scores_arr <= 0

        pos_advantage_mean = float(np.mean(scores_arr[pos_mask])) if np.any(pos_mask) else 0.0
        neg_advantage_mean = float(np.mean(scores_arr[neg_mask])) if np.any(neg_mask) else 0.0

        self._current_step += 1

        return TrainStepResult(
            loss=loss,
            num_samples=len(data),
            logprobs_mean=logprobs_mean,
            pos_advantage_mean=pos_advantage_mean,
            neg_advantage_mean=neg_advantage_mean,
        )

    def sync_weights(self, name: str | None = None) -> None:
        """
        Sync training weights to sampling client.

        This updates the sampling client to use the latest trained weights.
        Should be called periodically during training.

        Args:
            name: Checkpoint name (auto-generated from the current step
                count if not provided)
        """
        if name is None:
            name = f"{self.config.checkpoint_name_prefix}-step-{self._current_step}"

        logger.info(f"Syncing weights to sampling client: {name}")

        # Replaces the previous sampling client wholesale.
        self._sampling_client = self.training_client.save_weights_and_get_sampling_client(
            name=name
        )

    def sample(
        self,
        messages: List[dict],
        max_tokens: int | None = None,
        temperature: float | None = None,
        n: int = 1,
        stop: List[str] | None = None,
        include_logprobs: bool = False,
    ) -> SampleResult:
        """
        Sample completions from current model.

        Args:
            messages: Chat messages to complete
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            n: Number of completions to generate
            stop: Stop sequences
            include_logprobs: Whether to include logprobs

        Returns:
            SampleResult with completions and optional logprobs
        """
        # Fall back to config defaults. Note `or` means explicit 0 / []
        # also fall through to the defaults; temperature=0.0 is honored.
        max_tokens = max_tokens or self.config.default_max_tokens
        temperature = temperature if temperature is not None else self.config.default_temperature
        stop = stop or self.config.stop_sequences

        # Render prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        prompt_tokens = tinker_types.ModelInput.from_ints(
            self.tokenizer.encode(prompt)
        )

        # Sampling params
        params = tinker_types.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop,
        )

        # Sample (blocking on the returned future)
        result = self.sampling_client.sample(
            prompt=prompt_tokens,
            sampling_params=params,
            num_samples=n,
            include_prompt_logprobs=include_logprobs,
        ).result()

        # Decode completions
        completions = [
            self.tokenizer.decode(seq.tokens)
            for seq in result.sequences
        ]

        # Extract logprobs if requested.
        # NOTE(review): these are the PROMPT logprobs duplicated n times,
        # not per-completion token logprobs — confirm callers expect that.
        logprobs = []
        if include_logprobs and hasattr(result, "prompt_logprobs"):
            logprobs = [result.prompt_logprobs] * n

        # Extract finish reasons (default "stop" when the SDK omits it)
        finish_reasons = [
            getattr(seq, "finish_reason", "stop")
            for seq in result.sequences
        ]

        return SampleResult(
            completions=completions,
            logprobs=logprobs,
            finish_reasons=finish_reasons,
        )

    def compute_logprobs(
        self,
        messages: List[dict],
        completion: str,
    ) -> List[float]:
        """
        Compute logprobs for a specific completion.

        Useful for importance sampling and evaluation.

        Args:
            messages: Chat messages
            completion: Completion to compute logprobs for

        Returns:
            List of logprobs for each token of the full (prompt +
            completion) sequence; the leading None is mapped to 0.0.
        """
        # Build full sequence
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        full_text = prompt + completion

        prompt_tokens = tinker_types.ModelInput.from_ints(
            self.tokenizer.encode(full_text)
        )

        # Compute logprobs via prefill: generate a single throwaway token
        # just to get prompt_logprobs back for the full sequence.
        result = self.sampling_client.sample(
            prompt=prompt_tokens,
            num_samples=1,
            sampling_params=tinker_types.SamplingParams(max_tokens=1),
            include_prompt_logprobs=True,
        ).result()

        # Return logprobs (first is None for first token)
        logprobs = result.prompt_logprobs or []
        return [lp if lp is not None else 0.0 for lp in logprobs]

    def save_weights(self, name: str) -> str:
        """
        Save current weights to Tinker storage.

        Args:
            name: Name for the saved weights

        Returns:
            Weight identifier
        """
        logger.info(f"Saving weights: {name}")
        return self.training_client.save_weights(name=name)

    def load_weights(self, name: str) -> None:
        """
        Load weights from Tinker storage.

        Also refreshes the sampling client so sampling reflects the
        loaded weights (saved under "<name>-loaded").

        Args:
            name: Name of weights to load
        """
        logger.info(f"Loading weights: {name}")
        self.training_client.load_weights(name=name)

        # Update sampling client with loaded weights
        self.sync_weights(name=f"{name}-loaded")

    def get_available_models(self) -> List[str]:
        """Get list of available base models from Tinker"""
        capabilities = self.service_client.get_server_capabilities()
        return [m.model_name for m in capabilities.supported_models]

    @property
    def current_step(self) -> int:
        """Get current training step (count of completed train_step calls)"""
        return self._current_step

    @property
    def is_initialized(self) -> bool:
        """Check if client is initialized (setup() has completed)"""
        return self._initialized
|
572
|
+
|
|
573
|
+
|
|
574
|
+
# Backward-compatibility alias while imports migrate.
# TODO(cleanup): remove once all call sites import TinkerClient directly.
BabylonTinkerClient = TinkerClient
|