mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/training_client.py
@@ -0,0 +1,676 @@
"""TrainingClient SDK for MLXSmith.
|
|
2
|
+
|
|
3
|
+
Async client for training operations with futures-based API.
|
|
4
|
+
Provides methods for forward/backward passes, optimizer steps,
|
|
5
|
+
checkpoint management, and weight manipulation.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> from mlxsmith.sdk import TrainingClient
|
|
9
|
+
>>> client = TrainingClient(backend, pool)
|
|
10
|
+
>>>
|
|
11
|
+
>>> # Run training step
|
|
12
|
+
>>> future = client.forward_backward(batch)
|
|
13
|
+
>>> loss, grads = future.result()
|
|
14
|
+
>>>
|
|
15
|
+
>>> # Optimizer step
|
|
16
|
+
>>> client.optim_step(grads).result()
|
|
17
|
+
>>>
|
|
18
|
+
>>> # Save checkpoint
|
|
19
|
+
>>> client.save_state("checkpoint.pt").result()
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from dataclasses import dataclass
|
|
25
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
|
26
|
+
|
|
27
|
+
from .future import APIFuture, SdkFuturePool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class ForwardBackwardResult:
|
|
32
|
+
"""Result from a forward/backward pass."""
|
|
33
|
+
loss: float
|
|
34
|
+
grads: Any # Backend-specific gradient type
|
|
35
|
+
metrics: Dict[str, float]
|
|
36
|
+
batch_size: int = 1
|
|
37
|
+
has_grads: bool = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
|
|
41
|
+
class OptimizerStepResult:
|
|
42
|
+
"""Result from an optimizer step."""
|
|
43
|
+
step: int
|
|
44
|
+
learning_rate: float
|
|
45
|
+
grad_norm: Optional[float]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
|
|
49
|
+
class CheckpointResult:
|
|
50
|
+
"""Result from a checkpoint operation."""
|
|
51
|
+
path: str
|
|
52
|
+
success: bool
|
|
53
|
+
message: str
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
|
|
57
|
+
class WeightsResult:
|
|
58
|
+
"""Result for weight operations."""
|
|
59
|
+
weights: Dict[str, Any]
|
|
60
|
+
success: bool
|
|
61
|
+
message: str
|
|
62
|
+
num_tensors: int = 0
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class TrainingBatch:
|
|
66
|
+
"""A batch of training data.
|
|
67
|
+
|
|
68
|
+
Supports SFT, preference, and custom loss training.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
prompts: List[str],
|
|
74
|
+
responses: Optional[List[str]] = None,
|
|
75
|
+
rejected_responses: Optional[List[str]] = None,
|
|
76
|
+
advantages: Optional[List[float]] = None,
|
|
77
|
+
loss_type: str = "sft",
|
|
78
|
+
train_on_prompt: bool = False,
|
|
79
|
+
max_seq_len: Optional[int] = None,
|
|
80
|
+
extra: Optional[Dict[str, Any]] = None,
|
|
81
|
+
):
|
|
82
|
+
"""Initialize a training batch.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
prompts: List of prompt strings
|
|
86
|
+
responses: List of response strings (for SFT/positive in preference)
|
|
87
|
+
rejected_responses: List of rejected responses (for preference training)
|
|
88
|
+
advantages: List of advantage values (for RL training)
|
|
89
|
+
loss_type: Type of loss - "sft", "dpo", "orpo", "ppo", "custom"
|
|
90
|
+
train_on_prompt: Whether to compute loss on prompt tokens
|
|
91
|
+
max_seq_len: Maximum sequence length
|
|
92
|
+
extra: Additional batch metadata
|
|
93
|
+
"""
|
|
94
|
+
self.prompts = prompts
|
|
95
|
+
self.responses = responses
|
|
96
|
+
self.rejected_responses = rejected_responses
|
|
97
|
+
self.advantages = advantages
|
|
98
|
+
self.loss_type = loss_type
|
|
99
|
+
self.train_on_prompt = train_on_prompt
|
|
100
|
+
self.max_seq_len = max_seq_len
|
|
101
|
+
self.extra = extra or {}
|
|
102
|
+
self._size = len(prompts)
|
|
103
|
+
|
|
104
|
+
def __len__(self) -> int:
|
|
105
|
+
return self._size
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def is_preference(self) -> bool:
|
|
109
|
+
"""Check if this is a preference batch."""
|
|
110
|
+
return self.loss_type in ("dpo", "orpo", "ipo", "preference")
|
|
111
|
+
|
|
112
|
+
@property
|
|
113
|
+
def is_rl(self) -> bool:
|
|
114
|
+
"""Check if this is an RL batch."""
|
|
115
|
+
return self.loss_type in ("ppo", "grpo", "reinforce")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
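

# Editor's sketch (not part of the published file): the three batch flavours
# above in miniature. The `is_preference`/`is_rl` properties decide which
# code path TrainingClient.forward_backward() takes below.
def _example_batches() -> None:
    sft = TrainingBatch(prompts=["Q?"], responses=["A."])
    dpo = TrainingBatch(
        prompts=["Q?"],
        responses=["good answer"],          # chosen side
        rejected_responses=["bad answer"],  # rejected side
        loss_type="dpo",
        extra={"beta": 0.1},
    )
    rl = TrainingBatch(
        prompts=["Q?"], responses=["A."], advantages=[0.5], loss_type="grpo"
    )
    assert not sft.is_preference and dpo.is_preference and rl.is_rl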


class TrainingClient:
    """Async client for training operations.

    Provides a futures-based API for all training operations, enabling
    concurrent execution and flexible callback handling.

    Example:
        >>> client = TrainingClient(backend, pool)
        >>>
        >>> # Async training loop
        >>> for batch in dataloader:
        ...     fb_future = client.forward_backward(batch)
        ...
        ...     # Chain operations with callbacks
        ...     fb_future.then(lambda r: print(f"Loss: {r.loss}"))
        ...
        ...     # Get the result and continue
        ...     result = fb_future.result()
        ...     if result.has_grads:
        ...         client.optim_step(result.grads).result()
        >>>
        >>> # Save checkpoint
        >>> client.save_state("checkpoint.pt").result()
    """

    def __init__(
        self,
        backend: Any,
        pool: Optional[SdkFuturePool] = None,
        optimizer: Optional[Any] = None,
        step: int = 0,
    ):
        """Initialize TrainingClient.

        Args:
            backend: The LLM backend instance
            pool: Optional SdkFuturePool for async execution (creates a default if None)
            optimizer: Optional pre-created optimizer
            step: Initial training step counter
        """
        self.backend = backend
        self.pool = pool or SdkFuturePool(max_workers=1)
        self.optimizer = optimizer
        self._step = step
        self._training_state: Dict[str, Any] = {}
        self._checkpoint_handlers: Dict[str, Callable] = {}

    # ========================================================================
    # Core Training Operations
    # ========================================================================

    def forward_backward(self, batch: TrainingBatch) -> APIFuture[ForwardBackwardResult]:
        """Run forward and backward passes on a batch.

        Args:
            batch: TrainingBatch with prompts and responses

        Returns:
            APIFuture resolving to ForwardBackwardResult

        Example:
            >>> batch = TrainingBatch(
            ...     prompts=["What is 2+2?"],
            ...     responses=["The answer is 4."],
            ...     loss_type="sft"
            ... )
            >>> future = client.forward_backward(batch)
            >>> result = future.result()
            >>> print(f"Loss: {result.loss}")
        """
        def _run_forward_backward() -> ForwardBackwardResult:
            from . import sft_forward_backward, preference_forward_backward

            losses = []
            all_grads = []

            if batch.is_preference:
                # Preference training (DPO, ORPO, etc.)
                if batch.rejected_responses is None:
                    raise ValueError("Preference batch requires rejected_responses")

                for prompt, chosen, rejected in zip(
                    batch.prompts,
                    batch.responses or [],
                    batch.rejected_responses,
                ):
                    loss, grads = preference_forward_backward(
                        self.backend,
                        prompt,
                        chosen,
                        rejected,
                        algo=batch.loss_type,
                        beta=batch.extra.get("beta", 0.1),
                        reference_backend=batch.extra.get("reference_backend"),
                        kl_coeff=batch.extra.get("kl_coeff", 0.0),
                        train_on_prompt=batch.train_on_prompt,
                        max_seq_len=batch.max_seq_len,
                    )
                    losses.append(float(loss) if loss is not None else 0.0)
                    if grads is not None:
                        all_grads.append(grads)

            elif batch.is_rl:
                # RL training (PPO, etc.)
                # For now, fall back to SFT-style loss scaled by advantages.
                for prompt, response, advantage in zip(
                    batch.prompts,
                    batch.responses or [],
                    batch.advantages or [0.0] * len(batch.prompts),
                ):
                    loss, grads = sft_forward_backward(
                        self.backend,
                        prompt,
                        response,
                        train_on_prompt=batch.train_on_prompt,
                        max_seq_len=batch.max_seq_len,
                    )
                    # Scale by advantage
                    if loss is not None and advantage != 0.0:
                        loss = loss * advantage
                    losses.append(float(loss) if loss is not None else 0.0)
                    if grads is not None:
                        all_grads.append(grads)

            else:
                # Standard SFT
                for prompt, response in zip(batch.prompts, batch.responses or []):
                    loss, grads = sft_forward_backward(
                        self.backend,
                        prompt,
                        response,
                        train_on_prompt=batch.train_on_prompt,
                        max_seq_len=batch.max_seq_len,
                    )
                    losses.append(float(loss) if loss is not None else 0.0)
                    if grads is not None:
                        all_grads.append(grads)

            # Average gradients if there are several samples
            grads = self._aggregate_gradients(all_grads) if all_grads else None
            avg_loss = sum(losses) / len(losses) if losses else 0.0

            return ForwardBackwardResult(
                loss=avg_loss,
                grads=grads,
                batch_size=len(batch),
                has_grads=grads is not None,
                metrics={
                    "avg_loss": avg_loss,
                    "num_samples": len(losses),
                },
            )

        return self.pool.submit(_run_forward_backward)
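
    # Editor's sketch (not in the published file): pipelining steps through
    # the futures API. The default pool is single-threaded, so work still
    # executes in order, but the caller can keep submitting while a step
    # runs. Assumes `batches` is any iterable of TrainingBatch.
    #
    #     pending = None
    #     for batch in batches:
    #         fb = client.forward_backward(batch)   # enqueue fwd/bwd
    #         if pending is not None:
    #             pending.result()                  # drain previous optim step
    #         result = fb.result()
    #         if result.has_grads:
    #             pending = client.optim_step(result.grads)
    #     if pending is not None:
    #         pending.result()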

    def optim_step(self, grads: Optional[Any] = None) -> APIFuture[OptimizerStepResult]:
        """Execute an optimizer step with the given gradients.

        Args:
            grads: Gradients from forward/backward (required; the step
                fails with ValueError if None)

        Returns:
            APIFuture resolving to OptimizerStepResult

        Example:
            >>> # After forward_backward
            >>> grads = fb_result.grads
            >>> step_future = client.optim_step(grads)
            >>> step_info = step_future.result()
            >>> print(f"Step {step_info.step} completed")
        """
        def _run_optim_step() -> OptimizerStepResult:
            from . import optim_step as _optim_step

            if self.optimizer is None:
                raise RuntimeError("Optimizer not initialized. Call create_optimizer() first.")

            if grads is None:
                raise ValueError("No gradients provided for optimizer step")

            _optim_step(self.backend, self.optimizer, grads)
            self._step += 1

            # Best-effort gradient norm: only meaningful when grads is a flat
            # iterable of scalars; anything else falls through silently.
            grad_norm = None
            if hasattr(grads, '__iter__'):
                try:
                    import math
                    grad_norm = math.sqrt(sum(float(g**2) for g in grads if g is not None))
                except Exception:
                    pass

            # Get the current learning rate
            lr = 0.0
            if hasattr(self.optimizer, 'learning_rate'):
                lr = self.optimizer.learning_rate
            elif isinstance(self.optimizer, dict):
                lr = self.optimizer.get('learning_rate', 0.0)

            return OptimizerStepResult(
                step=self._step,
                learning_rate=lr,
                grad_norm=grad_norm,
            )

        return self.pool.submit(_run_optim_step)

    # ========================================================================
    # Checkpoint Management
    # ========================================================================

    def save_state(self, path: str, metadata: Optional[Dict[str, Any]] = None) -> APIFuture[CheckpointResult]:
        """Save a training checkpoint.

        Args:
            path: Path to save the checkpoint to
            metadata: Optional metadata to save with the checkpoint

        Returns:
            APIFuture resolving to CheckpointResult

        Example:
            >>> client.save_state("checkpoints/step_1000.pt").result()
            >>> # With metadata
            >>> client.save_state("checkpoint.pt", {"epoch": 5, "score": 0.95}).result()
        """
        def _run_save() -> CheckpointResult:
            try:
                from pathlib import Path

                save_path = Path(path)
                save_path.parent.mkdir(parents=True, exist_ok=True)

                # Save adapter weights along with the step and training state
                full_metadata = {
                    "step": self._step,
                    "training_state": self._training_state,
                    **(metadata or {}),
                }

                # Use the backend's save_adapter
                self.backend.save_adapter(str(save_path), metadata=full_metadata)

                return CheckpointResult(
                    path=str(save_path),
                    success=True,
                    message=f"Checkpoint saved to {save_path}",
                )
            except Exception as e:
                return CheckpointResult(
                    path=path,
                    success=False,
                    message=f"Failed to save checkpoint: {e}",
                )

        return self.pool.submit(_run_save)

    def load_state(self, path: str) -> APIFuture[CheckpointResult]:
        """Load a training checkpoint.

        Args:
            path: Path of the checkpoint to load

        Returns:
            APIFuture resolving to CheckpointResult

        Example:
            >>> result = client.load_state("checkpoints/step_1000.pt").result()
            >>> if result.success:
            ...     print(f"Loaded from {result.path}")
        """
        def _run_load() -> CheckpointResult:
            try:
                from pathlib import Path
                import json

                load_path = Path(path)
                if not load_path.exists():
                    return CheckpointResult(
                        path=path,
                        success=False,
                        message=f"Checkpoint not found: {path}",
                    )

                # Load adapter weights
                self.backend.apply_adapter(str(load_path))

                # Try to load metadata (present when the checkpoint is a directory)
                metadata_path = load_path / "adapter_metadata.json"
                if metadata_path.exists():
                    with open(metadata_path) as f:
                        metadata = json.load(f)
                    self._step = metadata.get("step", self._step)
                    self._training_state = metadata.get("training_state", {})

                return CheckpointResult(
                    path=str(load_path),
                    success=True,
                    message=f"Checkpoint loaded from {load_path}",
                )
            except Exception as e:
                return CheckpointResult(
                    path=path,
                    success=False,
                    message=f"Failed to load checkpoint: {e}",
                )

        return self.pool.submit(_run_load)
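
    # Editor's sketch (not in the published file): a checkpoint round-trip
    # using the two methods above. The directory name is arbitrary.
    #
    #     saved = client.save_state("checkpoints/step_100", {"epoch": 1}).result()
    #     if saved.success:
    #         restored = client.load_state(saved.path).result()
    #         print(restored.message, "at step", client.step)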

    # ========================================================================
    # Weight Management
    # ========================================================================

    def get_weights(self) -> APIFuture[WeightsResult]:
        """Get the current model weights.

        Returns:
            APIFuture resolving to WeightsResult with the weights dictionary

        Example:
            >>> weights_future = client.get_weights()
            >>> result = weights_future.result()
            >>> print(f"Retrieved {len(result.weights)} weight tensors")
        """
        def _run_get_weights() -> WeightsResult:
            try:
                weights = {}

                # Try to get the model parameters
                if hasattr(self.backend, 'model'):
                    model = self.backend.model
                    if hasattr(model, 'parameters'):
                        params = model.parameters()
                        if isinstance(params, dict):
                            weights = params
                        else:
                            weights = {"params": params}
                    elif hasattr(model, 'trainable_parameters'):
                        weights = model.trainable_parameters()

                return WeightsResult(
                    weights=weights,
                    success=True,
                    message=f"Retrieved {len(weights)} weight tensors",
                    num_tensors=len(weights),
                )
            except Exception as e:
                return WeightsResult(
                    weights={},
                    success=False,
                    message=f"Failed to get weights: {e}",
                    num_tensors=0,
                )

        return self.pool.submit(_run_get_weights)

    def set_weights(self, weights: Dict[str, Any]) -> APIFuture[WeightsResult]:
        """Set the model weights.

        Args:
            weights: Dictionary of weight tensors

        Returns:
            APIFuture resolving to WeightsResult

        Example:
            >>> client.set_weights(new_weights).result()
        """
        def _run_set_weights() -> WeightsResult:
            try:
                # Backend-specific: for MLX, update the model's arrays in
                # place. Note that if the model exposes neither hook, this is
                # a no-op that still reports success.
                if hasattr(self.backend, 'model'):
                    model = self.backend.model
                    if hasattr(model, 'update'):
                        model.update(weights)
                    elif hasattr(model, 'load_weights'):
                        model.load_weights(weights)

                return WeightsResult(
                    weights=weights,
                    success=True,
                    message=f"Set {len(weights)} weight tensors",
                    num_tensors=len(weights),
                )
            except Exception as e:
                return WeightsResult(
                    weights={},
                    success=False,
                    message=f"Failed to set weights: {e}",
                    num_tensors=0,
                )

        return self.pool.submit(_run_set_weights)
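
    # Editor's sketch (not in the published file): copying weights from one
    # client to another, e.g. to sync a sampler with the trainer.
    #
    #     snapshot = trainer.get_weights().result()
    #     if snapshot.success and snapshot.num_tensors:
    #         sampler.set_weights(snapshot.weights).result()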

    # ========================================================================
    # Utility Methods
    # ========================================================================

    def create_optimizer(self, lr: float = 1e-4, weight_decay: float = 0.0) -> APIFuture[Any]:
        """Create an optimizer for training.

        Args:
            lr: Learning rate
            weight_decay: Weight decay coefficient

        Returns:
            APIFuture resolving to the optimizer instance

        Example:
            >>> opt_future = client.create_optimizer(lr=1e-4, weight_decay=0.01)
            >>> client.optimizer = opt_future.result()
        """
        def _run_create_optimizer() -> Any:
            from . import create_optimizer as _create_optimizer

            self.optimizer, _ = _create_optimizer(
                self.backend,
                lr=lr,
                weight_decay=weight_decay,
            )
            return self.optimizer

        return self.pool.submit(_run_create_optimizer)

    def zero_grad(self) -> APIFuture[None]:
        """Zero out gradients.

        Returns:
            APIFuture that completes when the gradients are zeroed
        """
        def _run_zero_grad() -> None:
            if self.optimizer is not None and hasattr(self.optimizer, 'zero_grad'):
                self.optimizer.zero_grad()

        return self.pool.submit(_run_zero_grad)

    # ========================================================================
    # Properties
    # ========================================================================

    @property
    def step(self) -> int:
        """Current training step."""
        return self._step

    @property
    def training_state(self) -> Dict[str, Any]:
        """Get a copy of the training state dictionary."""
        return self._training_state.copy()

    def update_training_state(self, updates: Dict[str, Any]) -> None:
        """Update the training state."""
        self._training_state.update(updates)

    # ========================================================================
    # Internal Helpers
    # ========================================================================

    def _aggregate_gradients(self, grads_list: List[Any]) -> Any:
        """Aggregate gradients from multiple samples.

        Args:
            grads_list: List of gradient dictionaries/arrays

        Returns:
            Aggregated gradients
        """
        if not grads_list:
            return None
        if len(grads_list) == 1:
            return grads_list[0]

        # Averaging is backend-specific; for now return the first gradient
        # set. A full MLX implementation would average the arrays.
        return grads_list[0]

    def shutdown(self) -> None:
        """Shut down the client and its thread pool."""
        self.pool.shutdown(wait=True)
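

# Editor's sketch (not part of the published file): a minimal end-to-end SFT
# loop over the API above. Assumes `backend` is an already-initialized
# MLXSmith LLM backend; everything else is defined in this module.
def _example_sft_loop(backend: Any) -> None:
    client = TrainingClient(backend)
    client.create_optimizer(lr=1e-4, weight_decay=0.01).result()
    batch = TrainingBatch(
        prompts=["What is 2+2?"],
        responses=["The answer is 4."],
        loss_type="sft",
    )
    for _ in range(3):
        result = client.forward_backward(batch).result()
        if result.has_grads:
            info = client.optim_step(result.grads).result()
            print(f"step={info.step} loss={result.loss:.4f} lr={info.learning_rate}")
    client.save_state("checkpoints/example").result()
    client.shutdown()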


class DistillationTrainingClient(TrainingClient):
    """Extended TrainingClient for knowledge distillation.

    Adds support for teacher-model logprobs during training, enabling
    distillation from a larger teacher model.

    Example:
        >>> student = DistillationTrainingClient(student_backend, pool)
        >>> teacher = SamplingClient(teacher_backend_endpoint)
        >>>
        >>> # Get teacher logprobs for student samples
        >>> samples = student.sample_batch(prompts)
        >>> teacher_logprobs = teacher.get_logprobs_for_texts(prompts, samples.texts)
        >>>
        >>> # Distillation loss
        >>> batch = TrainingBatch(
        ...     prompts=prompts,
        ...     responses=samples.texts,
        ...     loss_type="distillation",
        ...     extra={"teacher_logprobs": teacher_logprobs}
        ... )
        >>> student.forward_backward(batch)
    """

    def __init__(
        self,
        backend: Any,
        pool: Optional[SdkFuturePool] = None,
        optimizer: Optional[Any] = None,
        step: int = 0,
        teacher_sampling_client: Optional['SamplingClient'] = None,  # type: ignore
    ):
        """Initialize DistillationTrainingClient.

        Args:
            backend: The LLM backend instance
            pool: Optional SdkFuturePool
            optimizer: Optional pre-created optimizer
            step: Initial training step
            teacher_sampling_client: Optional SamplingClient for the teacher model
        """
        super().__init__(backend, pool, optimizer, step)
        self.teacher_client = teacher_sampling_client

    def compute_distillation_loss(
        self,
        student_batch: TrainingBatch,
        teacher_logprobs: List[List[Dict[str, float]]],
        temperature: float = 2.0,
    ) -> APIFuture[ForwardBackwardResult]:
        """Compute the distillation loss with teacher logprobs.

        Args:
            student_batch: Training batch for the student
            teacher_logprobs: Top-k logprobs from the teacher for each token
            temperature: Distillation temperature

        Returns:
            APIFuture resolving to ForwardBackwardResult
        """
        # Attach the teacher signal to the batch, then delegate to the
        # standard forward/backward path (whose loss function would consume
        # teacher_logprobs in a full implementation). Delegating directly
        # avoids queuing a task that itself blocks on the single-worker pool.
        student_batch.extra["teacher_logprobs"] = teacher_logprobs
        student_batch.extra["distillation_temperature"] = temperature
        return self.forward_backward(student_batch)


# Imported at the end of the module to avoid a circular dependency.
from .sampling_client import SamplingClient
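

# Editor's sketch (not part of the published file): one distillation step.
# Assumes the caller already has teacher top-k logprobs per token, e.g. from
# a teacher SamplingClient.
def _example_distillation_step(
    student: DistillationTrainingClient,
    prompts: List[str],
    responses: List[str],
    teacher_logprobs: List[List[Dict[str, float]]],
) -> None:
    batch = TrainingBatch(prompts=prompts, responses=responses, loss_type="distillation")
    result = student.compute_distillation_loss(batch, teacher_logprobs, temperature=2.0).result()
    print(f"distillation loss: {result.loss:.4f}")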