@elizaos/training 2.0.0-alpha.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. package/Dockerfile +75 -0
  2. package/Makefile +374 -0
  3. package/README.md +346 -0
  4. package/config/rubrics.json +137 -0
  5. package/data/.gitkeep +0 -0
  6. package/data/degen/.gitkeep +2 -0
  7. package/data/trader/.gitkeep +2 -0
  8. package/docker-compose.test.yml +57 -0
  9. package/package.json +58 -0
  10. package/python/config/babylon_atropos.yaml +90 -0
  11. package/python/config/profiles/12gb.json +11 -0
  12. package/python/config/profiles/16gb.json +10 -0
  13. package/python/config/profiles/24gb.json +10 -0
  14. package/python/config/profiles/48gb.json +10 -0
  15. package/python/config/profiles/cpu.json +11 -0
  16. package/python/config/profiles/l40-2gpu-safe.json +20 -0
  17. package/python/config/profiles/l40-2gpu.json +22 -0
  18. package/python/config/profiles/l40-4gpu.json +21 -0
  19. package/python/config/profiles/l40.json +17 -0
  20. package/python/config/tinker_training.yaml +143 -0
  21. package/python/curriculum_state.json +165 -0
  22. package/python/env.template +86 -0
  23. package/python/env.training.template +46 -0
  24. package/python/pyproject.toml +41 -0
  25. package/python/requirements-ci.txt +31 -0
  26. package/python/requirements.txt +87 -0
  27. package/python/scripts/__init__.py +4 -0
  28. package/python/scripts/import_json_trajectories.py +412 -0
  29. package/python/scripts/local-finetune/README.md +63 -0
  30. package/python/scripts/local-finetune/ingest_and_score.py +139 -0
  31. package/python/scripts/local-finetune/merge_model.py +32 -0
  32. package/python/scripts/local-finetune/test_adapter.py +91 -0
  33. package/python/scripts/local-finetune/train_from_csv.py +132 -0
  34. package/python/scripts/merge_trajectories.py +318 -0
  35. package/python/scripts/run_ab_test.py +143 -0
  36. package/python/scripts/run_full_pipeline.py +544 -0
  37. package/python/scripts/run_tinker_training.py +192 -0
  38. package/python/scripts/run_training.py +914 -0
  39. package/python/scripts/test_judge.py +155 -0
  40. package/python/scripts/test_pipeline.py +356 -0
  41. package/python/scripts/test_trained_model.py +380 -0
  42. package/python/scripts/train_local.py +528 -0
  43. package/python/setup.py +20 -0
  44. package/python/src/__init__.py +190 -0
  45. package/python/src/data_bridge/__init__.py +24 -0
  46. package/python/src/data_bridge/converter.py +435 -0
  47. package/python/src/data_bridge/reader.py +393 -0
  48. package/python/src/models.py +283 -0
  49. package/python/src/training/__init__.py +605 -0
  50. package/python/src/training/ab_testing.py +404 -0
  51. package/python/src/training/action_executor.py +621 -0
  52. package/python/src/training/archetype_trainer.py +347 -0
  53. package/python/src/training/atropos_trainer.py +980 -0
  54. package/python/src/training/babylon_env.py +1254 -0
  55. package/python/src/training/error_recovery.py +647 -0
  56. package/python/src/training/evaluation.py +856 -0
  57. package/python/src/training/fast_simulator.py +880 -0
  58. package/python/src/training/format_validator.py +584 -0
  59. package/python/src/training/hybrid_env.py +522 -0
  60. package/python/src/training/kl_controller.py +628 -0
  61. package/python/src/training/multi_prompt_dataset.py +883 -0
  62. package/python/src/training/multi_turn.py +656 -0
  63. package/python/src/training/online_env.py +1084 -0
  64. package/python/src/training/quality_scorer.py +391 -0
  65. package/python/src/training/quality_utils.py +633 -0
  66. package/python/src/training/rewards.py +1344 -0
  67. package/python/src/training/rlaif_env.py +17 -0
  68. package/python/src/training/rollout_generator.py +502 -0
  69. package/python/src/training/rubric_loader.py +198 -0
  70. package/python/src/training/scenario_pool.py +1072 -0
  71. package/python/src/training/schemas.py +481 -0
  72. package/python/src/training/service_manager.py +552 -0
  73. package/python/src/training/simulation_bridge.py +535 -0
  74. package/python/src/training/tick_reward_attribution.py +399 -0
  75. package/python/src/training/tinker_client.py +575 -0
  76. package/python/src/training/tinker_trainer.py +646 -0
  77. package/python/src/training/tokenization_utils.py +402 -0
  78. package/python/tests/e2e/__init__.py +13 -0
  79. package/python/tests/e2e/conftest.py +258 -0
  80. package/python/tests/e2e/test_full_pipeline.py +643 -0
  81. package/python/tests/e2e/test_online_training_e2e.py +365 -0
  82. package/python/tests/integration/__init__.py +12 -0
  83. package/python/tests/integration/conftest.py +383 -0
  84. package/python/tests/integration/test_db_integration.py +649 -0
  85. package/python/tests/integration/test_json_mode_integration.py +554 -0
  86. package/python/tests/test_action_executor.py +594 -0
  87. package/python/tests/test_archetype_scoring.py +1027 -0
  88. package/python/tests/test_atropos_integration.py +360 -0
  89. package/python/tests/test_evaluation.py +727 -0
  90. package/python/tests/test_format_validator.py +486 -0
  91. package/python/tests/test_kl_controller.py +432 -0
  92. package/python/tests/test_lr_scheduler.py +579 -0
  93. package/python/tests/test_multi_turn.py +590 -0
  94. package/python/tests/test_online_env.py +519 -0
  95. package/python/tests/test_quality_scorer.py +474 -0
  96. package/python/tests/test_scenario_pool.py +735 -0
  97. package/python/tests/test_service_manager.py +585 -0
  98. package/python/tests/test_simulation_rollout.py +581 -0
  99. package/python/tests/test_tokenization_utils.py +501 -0
  100. package/python/tests/test_training_orchestrator.py +497 -0
  101. package/python/tests/test_training_output_structure.py +661 -0
  102. package/research-output/training-runs/training-run-1770772042899.json +26 -0
  103. package/research-output/training-runs/training-run-1770930079670.json +32 -0
  104. package/research-output/training-runs/training-run-1770930143700.json +44 -0
  105. package/research-output/training-runs/training-run-1770930183638.json +38 -0
  106. package/research-output/training-runs/training-run-1770930442049.json +38 -0
  107. package/research-output/training-runs/training-run-1770930793243.json +38 -0
  108. package/scripts/assess-training-data.ts +422 -0
  109. package/scripts/e2e-training-test.ts +550 -0
  110. package/scripts/export-rubrics.ts +64 -0
  111. package/scripts/generate-research-report.ts +1523 -0
  112. package/scripts/generate_dataset.sh +173 -0
  113. package/scripts/json-mode-benchmark.ts +399 -0
  114. package/scripts/real-archetype-benchmark.ts +210 -0
  115. package/scripts/run-baseline-comparison.ts +116 -0
  116. package/scripts/run-full-pipeline.ts +272 -0
  117. package/scripts/runpod_setup.sh +137 -0
  118. package/scripts/runpod_validate.sh +147 -0
  119. package/scripts/test-model-in-game.ts +955 -0
  120. package/scripts/test-scoring.ts +73 -0
  121. package/scripts/test-trained-model.ts +209 -0
  122. package/scripts/train-and-test.ts +824 -0
  123. package/scripts/verify-final.ts +118 -0
  124. package/src/adapter.ts +516 -0
  125. package/src/archetypes/ArchetypeConfigService.ts +626 -0
  126. package/src/archetypes/derive-archetype.ts +249 -0
  127. package/src/archetypes/index.ts +22 -0
  128. package/src/benchmark/ArchetypeMatchupBenchmark.ts +825 -0
  129. package/src/benchmark/BenchmarkChartGenerator.ts +748 -0
  130. package/src/benchmark/BenchmarkDataGenerator.ts +1288 -0
  131. package/src/benchmark/BenchmarkDataViewer.ts +324 -0
  132. package/src/benchmark/BenchmarkHistoryService.ts +221 -0
  133. package/src/benchmark/BenchmarkRunner.ts +685 -0
  134. package/src/benchmark/BenchmarkValidator.ts +206 -0
  135. package/src/benchmark/FastEvalRunner.ts +225 -0
  136. package/src/benchmark/MetricsValidator.ts +165 -0
  137. package/src/benchmark/MetricsVisualizer.ts +909 -0
  138. package/src/benchmark/ModelBenchmarkService.ts +611 -0
  139. package/src/benchmark/ModelRegistry.ts +158 -0
  140. package/src/benchmark/RulerBenchmarkIntegration.ts +235 -0
  141. package/src/benchmark/SimulationA2AInterface.ts +1169 -0
  142. package/src/benchmark/SimulationEngine.ts +832 -0
  143. package/src/benchmark/__tests__/BenchmarkRunner.test.ts +534 -0
  144. package/src/benchmark/__tests__/HeadToHead.test.ts +126 -0
  145. package/src/benchmark/index.ts +89 -0
  146. package/src/benchmark/parseSimulationMetrics.ts +124 -0
  147. package/src/benchmark/simulation-types.ts +78 -0
  148. package/src/dependencies.ts +439 -0
  149. package/src/generation/TrajectoryGenerator.ts +387 -0
  150. package/src/generation/index.ts +12 -0
  151. package/src/huggingface/HuggingFaceDatasetUploader.ts +636 -0
  152. package/src/huggingface/HuggingFaceIntegrationService.ts +426 -0
  153. package/src/huggingface/HuggingFaceModelUploader.ts +532 -0
  154. package/src/huggingface/index.ts +27 -0
  155. package/src/huggingface/shared/HuggingFaceUploadUtil.ts +206 -0
  156. package/src/index.ts +102 -0
  157. package/src/init-training.ts +53 -0
  158. package/src/metrics/TrajectoryMetricsExtractor.ts +653 -0
  159. package/src/metrics/__tests__/TrajectoryMetricsExtractor.test.ts +759 -0
  160. package/src/metrics/index.ts +8 -0
  161. package/src/metrics/types.ts +200 -0
  162. package/src/rubrics/__tests__/index.test.ts +184 -0
  163. package/src/rubrics/ass-kisser.ts +85 -0
  164. package/src/rubrics/degen.ts +80 -0
  165. package/src/rubrics/goody-twoshoes.ts +84 -0
  166. package/src/rubrics/index.ts +236 -0
  167. package/src/rubrics/information-trader.ts +84 -0
  168. package/src/rubrics/infosec.ts +101 -0
  169. package/src/rubrics/liar.ts +104 -0
  170. package/src/rubrics/perps-trader.ts +87 -0
  171. package/src/rubrics/researcher.ts +81 -0
  172. package/src/rubrics/scammer.ts +82 -0
  173. package/src/rubrics/social-butterfly.ts +73 -0
  174. package/src/rubrics/super-predictor.ts +97 -0
  175. package/src/rubrics/trader.ts +67 -0
  176. package/src/scoring/ArchetypeScoringService.ts +486 -0
  177. package/src/scoring/JudgePromptBuilder.ts +556 -0
  178. package/src/scoring/LLMJudgeCache.ts +401 -0
  179. package/src/scoring/index.ts +9 -0
  180. package/src/training/AutomationPipeline.ts +916 -0
  181. package/src/training/BenchmarkService.ts +518 -0
  182. package/src/training/ConfigValidator.ts +220 -0
  183. package/src/training/MarketOutcomesTracker.ts +187 -0
  184. package/src/training/ModelDeployer.ts +186 -0
  185. package/src/training/ModelFetcher.ts +76 -0
  186. package/src/training/ModelSelectionService.ts +341 -0
  187. package/src/training/ModelUsageVerifier.ts +160 -0
  188. package/src/training/MultiModelOrchestrator.ts +580 -0
  189. package/src/training/RLModelConfig.ts +407 -0
  190. package/src/training/RewardBackpropagationService.ts +149 -0
  191. package/src/training/RulerScoringService.ts +666 -0
  192. package/src/training/TrainingMonitor.ts +166 -0
  193. package/src/training/TrajectoryRecorder.ts +399 -0
  194. package/src/training/__tests__/TrajectoryRecorder.test.ts +472 -0
  195. package/src/training/index.ts +100 -0
  196. package/src/training/logRLConfig.ts +34 -0
  197. package/src/training/pipeline.ts +129 -0
  198. package/src/training/storage/ModelStorageService.ts +279 -0
  199. package/src/training/storage/TrainingDataArchiver.ts +197 -0
  200. package/src/training/storage/index.ts +17 -0
  201. package/src/training/types.ts +207 -0
  202. package/src/training/window-utils.ts +138 -0
  203. package/src/utils/index.ts +101 -0
  204. package/src/utils/logger.ts +59 -0
  205. package/src/utils/snowflake.ts +17 -0
  206. package/src/utils/synthetic-detector.ts +111 -0
  207. package/tsconfig.json +20 -0
@@ -0,0 +1,575 @@
1
+ """
2
+ Tinker client for RL training.
3
+
4
+ Replaces local vLLM + PyTorch training with Tinker's cloud API.
5
+ This provides a unified interface for both training and inference.
6
+
7
+ Based on: https://tinker-docs.thinkingmachines.ai/training-sampling
8
+ Integration pattern from: tinker-atropos (Nous Research)
9
+
10
+ Key features:
11
+ - TrainingClient for forward_backward + optim_step
12
+ - SamplingClient for inference during rollouts
13
+ - Weight synchronization between training and sampling
14
+ - Automatic tokenization and format conversion
15
+ """
16
+
17
+ import logging
18
+ import os
19
+ from dataclasses import dataclass, field
20
+ from typing import List, Literal, Sequence
21
+
22
+ import numpy as np
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
# Lazy import tinker to allow graceful degradation when the optional
# dependency is absent. TINKER_AVAILABLE gates every tinker usage below:
# TinkerClient.__init__ raises if it is False, and TinkerDatum.to_tinker
# raises RuntimeError rather than failing with a NameError.
try:
    import tinker
    from tinker import types as tinker_types

    TINKER_AVAILABLE = True
except ImportError:
    TINKER_AVAILABLE = False
    tinker = None  # type: ignore
    tinker_types = None  # type: ignore
    logger.warning("Tinker not installed. Install with: pip install tinker")
37
+
38
+
39
@dataclass
class TinkerConfig:
    """Configuration for Tinker client"""

    # Model settings
    # Base model served by Tinker; setup() warns (but does not fail) if it is
    # not in the server's supported-model list.
    base_model: str = "Qwen/Qwen3-30B-A3B-Instruct"
    # LoRA adapter rank passed to create_lora_training_client().
    lora_rank: int = 32

    # Training hyperparameters, forwarded to tinker AdamParams on each
    # optim_step in TinkerClient.train_step().
    learning_rate: float = 4e-5
    beta1: float = 0.9
    beta2: float = 0.95
    epsilon: float = 1e-8

    # Sampling settings used as fallbacks when sample() arguments are omitted.
    default_max_tokens: int = 512
    default_temperature: float = 0.7
    # Generation halts when any of these sequences is produced.
    stop_sequences: List[str] = field(
        default_factory=lambda: ["\n\n", "<|endoftext|>", "<|im_end|>"]
    )

    # Weight sync settings: checkpoints are named
    # "<prefix>-initial" at setup and "<prefix>-step-<n>" on sync_weights().
    checkpoint_name_prefix: str = "eliza"
62
+
63
+
64
class TinkerDatum:
    """
    Wrapper for Tinker Datum to avoid direct tinker_types dependency.

    This allows code to work even when tinker is not installed: construction
    is pure Python, and the real ``tinker_types.Datum`` is only built (and
    cached) on the first ``to_tinker()`` call.
    """

    def __init__(
        self,
        input_tokens: List[int],
        target_tokens: List[int],
        weights: List[float],
    ):
        # Already shifted for next-token prediction by the caller:
        # input_tokens[i] predicts target_tokens[i], weighted by weights[i].
        self.input_tokens = input_tokens
        self.target_tokens = target_tokens
        self.weights = weights
        # Lazily-built native Datum; None until to_tinker() first runs.
        self._tinker_datum: object = None

    def to_tinker(self) -> object:
        """Convert to actual Tinker Datum"""
        if not TINKER_AVAILABLE:
            raise RuntimeError("Tinker not installed")

        # Return the cached conversion when one exists.
        cached = self._tinker_datum
        if cached is not None:
            return cached

        built = tinker_types.Datum(
            model_input=tinker_types.ModelInput.from_ints(tokens=self.input_tokens),
            loss_fn_inputs=dict(
                weights=self.weights,
                target_tokens=self.target_tokens,
            ),
        )
        self._tinker_datum = built
        return built
96
+
97
+
98
@dataclass
class TrainStepResult:
    """Result from a training step"""

    # Weighted negative mean logprob over all tokens in the batch
    # (0.0 when the batch was empty or all weights were ~0).
    loss: float
    # Number of TinkerDatum objects submitted to this step.
    num_samples: int
    # Unweighted mean of per-token logprobs across the batch.
    logprobs_mean: float = 0.0
    # Mean of strictly-positive advantage scores (0.0 if none).
    pos_advantage_mean: float = 0.0
    # Mean of non-positive advantage scores (0.0 if none).
    neg_advantage_mean: float = 0.0
107
+
108
+
109
@dataclass
class SampleResult:
    """Result from sampling"""

    # Decoded completion text, one entry per requested sample.
    completions: List[str]
    # Per-sample logprob lists; empty unless logprobs were requested.
    logprobs: List[List[float]] = field(default_factory=list)
    # Per-sample finish reasons (defaults to "stop" when not reported).
    finish_reasons: List[str] = field(default_factory=list)
116
+
117
+
118
class TinkerClient:
    """
    Unified Tinker client for training and inference.

    This replaces local vLLM + PyTorch training with Tinker's cloud API:
    - No local GPU required for training
    - Training happens in Tinker cloud
    - Fast weight sync between training and sampling
    - Automatic format conversion

    Usage:
        client = TinkerClient(config)
        client.setup()

        # Training
        data = [client.prepare_datum(messages, completion) for ...]
        result = client.train_step(data, scores)

        # Inference
        completions = client.sample(messages)

        # Sync weights after training
        client.sync_weights("checkpoint-name")
    """

    def __init__(self, config: TinkerConfig | None = None):
        """Store configuration; fails fast when the tinker SDK is absent.

        Raises:
            RuntimeError: if the optional ``tinker`` package is not installed.
        """
        if not TINKER_AVAILABLE:
            raise RuntimeError(
                "Tinker not installed. Install with: pip install tinker"
            )

        self.config = config or TinkerConfig()
        # All remote handles are created lazily in setup(); None until then.
        self._service_client: object = None
        self._training_client: object = None
        self._sampling_client: object = None
        self._tokenizer: object = None
        self._initialized = False
        # Monotonic counter of completed train_step() calls; used to name
        # auto-generated checkpoints in sync_weights().
        self._current_step = 0

    @property
    def service_client(self) -> object:
        """Lazily initialize service client"""
        if self._service_client is None:
            self._service_client = tinker.ServiceClient()
        return self._service_client

    @property
    def training_client(self) -> object:
        """Get training client (must call setup first)"""
        if self._training_client is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._training_client

    @property
    def sampling_client(self) -> object:
        """Get sampling client (must call setup first)"""
        if self._sampling_client is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._sampling_client

    @property
    def tokenizer(self) -> object:
        """Get tokenizer (must call setup first)"""
        if self._tokenizer is None:
            raise RuntimeError("Client not initialized. Call setup() first.")
        return self._tokenizer

    def setup(self) -> None:
        """
        Initialize training client, sampling client, and tokenizer.

        Must be called before any training or sampling operations.

        Raises:
            ValueError: if TINKER_API_KEY is not set in the environment.
        """
        # Idempotent: repeated calls are a no-op after first success.
        if self._initialized:
            logger.info("Client already initialized")
            return

        logger.info(f"Initializing Tinker client with model: {self.config.base_model}")

        # Verify API key is set
        if not os.environ.get("TINKER_API_KEY"):
            raise ValueError(
                "TINKER_API_KEY environment variable not set. "
                "Get your API key from Thinking Machines."
            )

        # Check model availability (warn only — server may still accept it).
        capabilities = self.service_client.get_server_capabilities()
        available_models = [m.model_name for m in capabilities.supported_models]

        if self.config.base_model not in available_models:
            logger.warning(
                f"Model {self.config.base_model} not in available models. "
                f"Available: {available_models[:5]}..."
            )

        # Create training client with LoRA
        self._training_client = self.service_client.create_lora_training_client(
            base_model=self.config.base_model,
            lora_rank=self.config.lora_rank,
        )

        # Get tokenizer
        self._tokenizer = self._training_client.get_tokenizer()

        # Create initial sampling client from the untrained adapter so
        # rollouts can start before the first train step.
        initial_name = f"{self.config.checkpoint_name_prefix}-initial"
        self._sampling_client = self._training_client.save_weights_and_get_sampling_client(
            name=initial_name
        )

        self._initialized = True
        logger.info("Tinker client initialized successfully")

    def prepare_datum(
        self,
        messages: List[dict],
        completion: str,
    ) -> TinkerDatum:
        """
        Convert chat messages + completion to Tinker Datum.

        Args:
            messages: List of chat messages (role/content dicts)
            completion: The assistant completion to train on

        Returns:
            TinkerDatum ready for training (loss masked to completion tokens)
        """
        # Render messages to prompt using chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize prompt (no loss on prompt tokens)
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        prompt_weights = [0.0] * len(prompt_tokens)

        # Tokenize completion (loss on these tokens)
        completion_tokens = self.tokenizer.encode(completion, add_special_tokens=False)
        completion_weights = [1.0] * len(completion_tokens)

        # Combine
        all_tokens = prompt_tokens + completion_tokens
        all_weights = prompt_weights + completion_weights

        # Shift for next-token prediction: input[i] predicts target[i],
        # so weights are shifted to align with targets.
        input_tokens = all_tokens[:-1]
        target_tokens = all_tokens[1:]
        weights = all_weights[1:]

        return TinkerDatum(
            input_tokens=input_tokens,
            target_tokens=target_tokens,
            weights=weights,
        )

    def prepare_datum_from_tokens(
        self,
        tokens: List[int],
        masks: List[int],
    ) -> TinkerDatum:
        """
        Create Datum from pre-tokenized data (e.g., from Atropos).

        Args:
            tokens: Token IDs
            masks: Mask values (-100 for no loss, token_id for loss)

        Returns:
            TinkerDatum ready for training
        """
        # Convert masks to weights (0 for -100, 1 otherwise); -100 is the
        # conventional "ignore" label used by HF-style training data.
        weights = [0.0 if m == -100 else 1.0 for m in masks]

        # Shift for next-token prediction
        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        weights = weights[1:]

        return TinkerDatum(
            input_tokens=input_tokens,
            target_tokens=target_tokens,
            weights=weights,
        )

    def train_step(
        self,
        data: Sequence[TinkerDatum],
        scores: List[float],
        loss_fn: Literal["cross_entropy", "importance_sampling"] = "importance_sampling",
    ) -> TrainStepResult:
        """
        Execute one training step with Tinker.

        Args:
            data: List of TinkerDatum objects
            scores: Advantage scores for each datum (should be centered at 0)
            loss_fn: Loss function to use

        Returns:
            TrainStepResult with loss and metrics
        """
        if not data:
            return TrainStepResult(loss=0.0, num_samples=0)

        # Convert to Tinker format and apply advantage weights.
        # NOTE(review): zip() silently truncates if len(scores) != len(data);
        # consider validating lengths upstream.
        tinker_data = []
        for datum, score in zip(data, scores):
            tinker_datum = datum.to_tinker()

            # Scale weights by advantage for GRPO/IS
            # Positive advantage = learn this behavior
            # Negative advantage = unlearn this behavior
            # NOTE(review): this mutates the Datum cached inside TinkerDatum;
            # safe on re-use only because scaling always starts from the
            # unscaled datum.weights list.
            scaled_weights = [w * score for w in datum.weights]
            tinker_datum.loss_fn_inputs["weights"] = scaled_weights

            tinker_data.append(tinker_datum)

        # Forward-backward pass (async submission)
        fwdbwd_future = self.training_client.forward_backward(tinker_data, loss_fn)

        # Optimizer step (async submission)
        optim_future = self.training_client.optim_step(
            tinker_types.AdamParams(
                learning_rate=self.config.learning_rate,
                beta1=self.config.beta1,
                beta2=self.config.beta2,
                epsilon=self.config.epsilon,
            )
        )

        # Wait for results
        fwdbwd_result = fwdbwd_future.result()
        _ = optim_future.result()  # Just wait for completion

        # Compute metrics: flatten per-token logprobs and their (scaled)
        # weights across the whole batch.
        all_logprobs = []
        all_weights = []
        for output, datum in zip(fwdbwd_result.loss_fn_outputs, tinker_data):
            logprobs = output["logprobs"].tolist()
            weights = datum.loss_fn_inputs["weights"]
            all_logprobs.extend(logprobs)
            all_weights.extend(weights if isinstance(weights, list) else weights.tolist())

        # Compute weighted loss: -sum(logprob * weight) / sum(|weight|).
        logprobs_arr = np.array(all_logprobs)
        weights_arr = np.array(all_weights)

        weight_sum = np.sum(np.abs(weights_arr))
        if weight_sum > 1e-8:
            loss = float(-np.dot(logprobs_arr, weights_arr) / weight_sum)
            logprobs_mean = float(np.mean(logprobs_arr))
        else:
            # Degenerate batch (all weights ~0): report zeros instead of NaN.
            loss = 0.0
            logprobs_mean = 0.0

        # Compute advantage statistics (zero scores count as "negative").
        scores_arr = np.array(scores)
        pos_mask = scores_arr > 0
        neg_mask = scores_arr <= 0

        pos_advantage_mean = float(np.mean(scores_arr[pos_mask])) if np.any(pos_mask) else 0.0
        neg_advantage_mean = float(np.mean(scores_arr[neg_mask])) if np.any(neg_mask) else 0.0

        self._current_step += 1

        return TrainStepResult(
            loss=loss,
            num_samples=len(data),
            logprobs_mean=logprobs_mean,
            pos_advantage_mean=pos_advantage_mean,
            neg_advantage_mean=neg_advantage_mean,
        )

    def sync_weights(self, name: str | None = None) -> None:
        """
        Sync training weights to sampling client.

        This updates the sampling client to use the latest trained weights.
        Should be called periodically during training.

        Args:
            name: Checkpoint name (auto-generated if not provided)
        """
        if name is None:
            name = f"{self.config.checkpoint_name_prefix}-step-{self._current_step}"

        logger.info(f"Syncing weights to sampling client: {name}")

        # Replaces the previous sampling client; in-flight samples on the old
        # client are unaffected.
        self._sampling_client = self.training_client.save_weights_and_get_sampling_client(
            name=name
        )

    def sample(
        self,
        messages: List[dict],
        max_tokens: int | None = None,
        temperature: float | None = None,
        n: int = 1,
        stop: List[str] | None = None,
        include_logprobs: bool = False,
    ) -> SampleResult:
        """
        Sample completions from current model.

        Args:
            messages: Chat messages to complete
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            n: Number of completions to generate
            stop: Stop sequences
            include_logprobs: Whether to include logprobs

        Returns:
            SampleResult with completions and optional logprobs
        """
        # Fall back to config defaults for any omitted argument.
        # (temperature=0.0 is a valid value, hence the explicit None check.)
        max_tokens = max_tokens or self.config.default_max_tokens
        temperature = temperature if temperature is not None else self.config.default_temperature
        stop = stop or self.config.stop_sequences

        # Render prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        prompt_tokens = tinker_types.ModelInput.from_ints(
            self.tokenizer.encode(prompt)
        )

        # Sampling params
        params = tinker_types.SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            stop=stop,
        )

        # Sample (blocking on the returned future)
        result = self.sampling_client.sample(
            prompt=prompt_tokens,
            sampling_params=params,
            num_samples=n,
            include_prompt_logprobs=include_logprobs,
        ).result()

        # Decode completions
        completions = [
            self.tokenizer.decode(seq.tokens)
            for seq in result.sequences
        ]

        # Extract logprobs if requested.
        # NOTE(review): these are PROMPT logprobs replicated n times, not
        # per-completion logprobs — confirm this matches caller expectations.
        logprobs = []
        if include_logprobs and hasattr(result, "prompt_logprobs"):
            logprobs = [result.prompt_logprobs] * n

        # Extract finish reasons (default "stop" when the field is missing)
        finish_reasons = [
            getattr(seq, "finish_reason", "stop")
            for seq in result.sequences
        ]

        return SampleResult(
            completions=completions,
            logprobs=logprobs,
            finish_reasons=finish_reasons,
        )

    def compute_logprobs(
        self,
        messages: List[dict],
        completion: str,
    ) -> List[float]:
        """
        Compute logprobs for a specific completion.

        Useful for importance sampling and evaluation.

        Args:
            messages: Chat messages
            completion: Completion to compute logprobs for

        Returns:
            List of logprobs for each token.
            NOTE(review): this covers ALL tokens of prompt + completion, not
            just the completion — callers wanting completion-only logprobs
            must slice off the prompt length themselves.
        """
        # Build full sequence (prompt rendered via chat template + completion)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        full_text = prompt + completion

        prompt_tokens = tinker_types.ModelInput.from_ints(
            self.tokenizer.encode(full_text)
        )

        # Compute logprobs via prefill: sample a single throwaway token with
        # prompt logprobs enabled.
        result = self.sampling_client.sample(
            prompt=prompt_tokens,
            num_samples=1,
            sampling_params=tinker_types.SamplingParams(max_tokens=1),
            include_prompt_logprobs=True,
        ).result()

        # Return logprobs (first is None for first token — mapped to 0.0)
        logprobs = result.prompt_logprobs or []
        return [lp if lp is not None else 0.0 for lp in logprobs]

    def save_weights(self, name: str) -> str:
        """
        Save current weights to Tinker storage.

        Args:
            name: Name for the saved weights

        Returns:
            Weight identifier
        """
        logger.info(f"Saving weights: {name}")
        return self.training_client.save_weights(name=name)

    def load_weights(self, name: str) -> None:
        """
        Load weights from Tinker storage.

        Args:
            name: Name of weights to load
        """
        logger.info(f"Loading weights: {name}")
        self.training_client.load_weights(name=name)

        # Update sampling client with loaded weights
        self.sync_weights(name=f"{name}-loaded")

    def get_available_models(self) -> List[str]:
        """Get list of available base models from Tinker"""
        capabilities = self.service_client.get_server_capabilities()
        return [m.model_name for m in capabilities.supported_models]

    @property
    def current_step(self) -> int:
        """Get current training step"""
        return self._current_step

    @property
    def is_initialized(self) -> bool:
        """Check if client is initialized"""
        return self._initialized
572
+
573
+
574
# Backward compatibility alias while imports migrate; prefer TinkerClient
# in new code.
BabylonTinkerClient = TinkerClient