mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,676 @@
+ """TrainingClient SDK for MLXSmith.
+
+ Async client for training operations with a futures-based API.
+ Provides methods for forward/backward passes, optimizer steps,
+ checkpoint management, and weight manipulation.
+
+ Example:
+     >>> from mlxsmith.sdk import TrainingClient
+     >>> client = TrainingClient(backend, pool)
+     >>>
+     >>> # Run a training step
+     >>> result = client.forward_backward(batch).result()
+     >>>
+     >>> # Optimizer step
+     >>> if result.has_grads:
+     ...     client.optim_step(result.grads).result()
+     >>>
+     >>> # Save a checkpoint
+     >>> client.save_state("checkpoint.pt").result()
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, List, Optional
+
+ from .future import APIFuture, SdkFuturePool
+
+
+ @dataclass
+ class ForwardBackwardResult:
+     """Result from a forward/backward pass."""
+     loss: float
+     grads: Any  # Backend-specific gradient type
+     metrics: Dict[str, float]
+     batch_size: int = 1
+     has_grads: bool = False
+
+
+ @dataclass
+ class OptimizerStepResult:
+     """Result from an optimizer step."""
+     step: int
+     learning_rate: float
+     grad_norm: Optional[float]
+
+
+ @dataclass
+ class CheckpointResult:
+     """Result from a checkpoint operation."""
+     path: str
+     success: bool
+     message: str
+
+
+ @dataclass
+ class WeightsResult:
+     """Result for weight operations."""
+     weights: Dict[str, Any]
+     success: bool
+     message: str
+     num_tensors: int = 0
+
+
+ class TrainingBatch:
+     """A batch of training data.
+
+     Supports SFT, preference, RL, and custom loss training.
+     """
+
+     def __init__(
+         self,
+         prompts: List[str],
+         responses: Optional[List[str]] = None,
+         rejected_responses: Optional[List[str]] = None,
+         advantages: Optional[List[float]] = None,
+         loss_type: str = "sft",
+         train_on_prompt: bool = False,
+         max_seq_len: Optional[int] = None,
+         extra: Optional[Dict[str, Any]] = None,
+     ):
+         """Initialize a training batch.
+
+         Args:
+             prompts: List of prompt strings
+             responses: List of response strings (for SFT; the chosen side in preference training)
+             rejected_responses: List of rejected responses (for preference training)
+             advantages: List of advantage values (for RL training)
+             loss_type: Loss family, e.g. "sft", "dpo", "orpo", "ipo", "ppo", "grpo", "reinforce", or "custom"
+             train_on_prompt: Whether to compute loss on prompt tokens
+             max_seq_len: Maximum sequence length
+             extra: Additional batch metadata
+         """
+         self.prompts = prompts
+         self.responses = responses
+         self.rejected_responses = rejected_responses
+         self.advantages = advantages
+         self.loss_type = loss_type
+         self.train_on_prompt = train_on_prompt
+         self.max_seq_len = max_seq_len
+         self.extra = extra or {}
+         self._size = len(prompts)
+
+     def __len__(self) -> int:
+         return self._size
+
+     @property
+     def is_preference(self) -> bool:
+         """Check if this is a preference batch."""
+         return self.loss_type in ("dpo", "orpo", "ipo", "preference")
+
+     @property
+     def is_rl(self) -> bool:
+         """Check if this is an RL batch."""
+         return self.loss_type in ("ppo", "grpo", "reinforce")
+
+
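For concreteness, a minimal sketch of constructing batches for two of these loss families; the prompt/response strings and the beta value are illustrative, not taken from the package:

    # Illustrative values only.
    sft_batch = TrainingBatch(
        prompts=["What is 2+2?"],
        responses=["The answer is 4."],
        loss_type="sft",
    )
    dpo_batch = TrainingBatch(
        prompts=["Summarize the article."],
        responses=["<chosen summary>"],
        rejected_responses=["<rejected summary>"],
        loss_type="dpo",
        extra={"beta": 0.1},  # forward_backward reads beta from `extra`
    )
    assert dpo_batch.is_preference and len(dpo_batch) == 1
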
+ class TrainingClient:
+     """Async client for training operations.
+
+     Provides a futures-based API for all training operations, enabling
+     concurrent execution and flexible callback handling.
+
+     Example:
+         >>> client = TrainingClient(backend, pool)
+         >>>
+         >>> # Async training loop
+         >>> for batch in dataloader:
+         ...     fb_future = client.forward_backward(batch)
+         ...
+         ...     # Chain operations with callbacks
+         ...     fb_future.then(lambda r: print(f"Loss: {r.loss}"))
+         ...
+         ...     # Get the result and continue
+         ...     result = fb_future.result()
+         ...     if result.has_grads:
+         ...         client.optim_step(result.grads).result()
+         >>>
+         >>> # Save a checkpoint
+         >>> client.save_state("checkpoint.pt").result()
+     """
+
+     def __init__(
+         self,
+         backend: Any,
+         pool: Optional[SdkFuturePool] = None,
+         optimizer: Optional[Any] = None,
+         step: int = 0,
+     ):
+         """Initialize TrainingClient.
+
+         Args:
+             backend: The LLM backend instance
+             pool: Optional SdkFuturePool for async execution (creates a default if None)
+             optimizer: Optional pre-created optimizer
+             step: Initial training step counter
+         """
+         self.backend = backend
+         self.pool = pool or SdkFuturePool(max_workers=1)
+         self.optimizer = optimizer
+         self._step = step
+         self._training_state: Dict[str, Any] = {}
+         self._checkpoint_handlers: Dict[str, Callable] = {}
+
+     # ========================================================================
+     # Core Training Operations
+     # ========================================================================
+
+     def forward_backward(self, batch: TrainingBatch) -> APIFuture[ForwardBackwardResult]:
+         """Run forward and backward passes on a batch.
+
+         Args:
+             batch: TrainingBatch with prompts and responses
+
+         Returns:
+             APIFuture resolving to ForwardBackwardResult
+
+         Example:
+             >>> batch = TrainingBatch(
+             ...     prompts=["What is 2+2?"],
+             ...     responses=["The answer is 4."],
+             ...     loss_type="sft"
+             ... )
+             >>> future = client.forward_backward(batch)
+             >>> result = future.result()
+             >>> print(f"Loss: {result.loss}")
+         """
+         def _run_forward_backward() -> ForwardBackwardResult:
+             from . import sft_forward_backward, preference_forward_backward
+
+             losses = []
+             all_grads = []
+
+             if batch.is_preference:
+                 # Preference training (DPO, ORPO, etc.)
+                 if batch.rejected_responses is None:
+                     raise ValueError("Preference batch requires rejected_responses")
+
+                 for prompt, chosen, rejected in zip(
+                     batch.prompts,
+                     batch.responses or [],
+                     batch.rejected_responses
+                 ):
+                     loss, grads = preference_forward_backward(
+                         self.backend,
+                         prompt,
+                         chosen,
+                         rejected,
+                         algo=batch.loss_type,
+                         beta=batch.extra.get("beta", 0.1),
+                         reference_backend=batch.extra.get("reference_backend"),
+                         kl_coeff=batch.extra.get("kl_coeff", 0.0),
+                         train_on_prompt=batch.train_on_prompt,
+                         max_seq_len=batch.max_seq_len,
+                     )
+                     losses.append(float(loss) if loss is not None else 0.0)
+                     if grads is not None:
+                         all_grads.append(grads)
+
+             elif batch.is_rl:
+                 # RL training (PPO, etc.)
+                 # For now, fall back to SFT-style updates with advantages.
+                 for prompt, response, advantage in zip(
+                     batch.prompts,
+                     batch.responses or [],
+                     batch.advantages or [0.0] * len(batch.prompts)
+                 ):
+                     # Use SFT forward/backward, then scale the reported loss
+                     loss, grads = sft_forward_backward(
+                         self.backend,
+                         prompt,
+                         response,
+                         train_on_prompt=batch.train_on_prompt,
+                         max_seq_len=batch.max_seq_len,
+                     )
+                     # Scale by advantage. Note: only the reported loss is
+                     # scaled here; the gradients returned by
+                     # sft_forward_backward are not advantage-weighted.
+                     if loss is not None and advantage != 0.0:
+                         loss = loss * advantage
+                     losses.append(float(loss) if loss is not None else 0.0)
+                     if grads is not None:
+                         all_grads.append(grads)
+
+             else:
+                 # Standard SFT
+                 for prompt, response in zip(batch.prompts, batch.responses or []):
+                     loss, grads = sft_forward_backward(
+                         self.backend,
+                         prompt,
+                         response,
+                         train_on_prompt=batch.train_on_prompt,
+                         max_seq_len=batch.max_seq_len,
+                     )
+                     losses.append(float(loss) if loss is not None else 0.0)
+                     if grads is not None:
+                         all_grads.append(grads)
+
+             # Average gradients if there are multiple samples
+             grads = self._aggregate_gradients(all_grads) if all_grads else None
+             avg_loss = sum(losses) / len(losses) if losses else 0.0
+
+             return ForwardBackwardResult(
+                 loss=avg_loss,
+                 grads=grads,
+                 batch_size=len(batch),
+                 has_grads=grads is not None,
+                 metrics={
+                     "avg_loss": avg_loss,
+                     "num_samples": len(losses),
+                 }
+             )
+
+         return self.pool.submit(_run_forward_backward)
+
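Putting forward_backward and optim_step together, a minimal synchronous loop might look like the sketch below; `backend` and `batches` are assumed to exist, and create_optimizer (defined further down) stores the optimizer on the client:

    client = TrainingClient(backend)
    client.create_optimizer(lr=1e-4).result()  # populates client.optimizer
    for batch in batches:
        result = client.forward_backward(batch).result()
        if result.has_grads:
            info = client.optim_step(result.grads).result()
            print(f"step {info.step}: loss={result.loss:.4f}")
    client.shutdown()
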
+     def optim_step(self, grads: Optional[Any] = None) -> APIFuture[OptimizerStepResult]:
+         """Execute an optimizer step with gradients.
+
+         Args:
+             grads: Gradients from forward/backward; required (a ValueError is raised if None)
+
+         Returns:
+             APIFuture resolving to OptimizerStepResult
+
+         Example:
+             >>> # After forward_backward
+             >>> grads = fb_result.grads
+             >>> step_future = client.optim_step(grads)
+             >>> step_info = step_future.result()
+             >>> print(f"Step {step_info.step} completed")
+         """
+         def _run_optim_step() -> OptimizerStepResult:
+             from . import optim_step as _optim_step
+
+             if self.optimizer is None:
+                 raise RuntimeError("Optimizer not initialized. Call create_optimizer() first.")
+
+             if grads is None:
+                 raise ValueError("No gradients provided for optimizer step")
+
+             _optim_step(self.backend, self.optimizer, grads)
+             self._step += 1
+
+             # Compute the gradient norm if possible. This only works for flat
+             # iterables of numeric gradients; nested dicts of arrays fall
+             # through to None via the except clause.
+             grad_norm = None
+             if hasattr(grads, '__iter__'):
+                 try:
+                     import math
+                     grad_norm = math.sqrt(sum(float(g**2) for g in grads if g is not None))
+                 except Exception:
+                     pass
+
+             # Get the current learning rate
+             lr = 0.0
+             if hasattr(self.optimizer, 'learning_rate'):
+                 lr = self.optimizer.learning_rate
+             elif isinstance(self.optimizer, dict):
+                 lr = self.optimizer.get('learning_rate', 0.0)
+
+             return OptimizerStepResult(
+                 step=self._step,
+                 learning_rate=lr,
+                 grad_norm=grad_norm,
+             )
+
+         return self.pool.submit(_run_optim_step)
+
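Since the in-place norm computation above gives up on nested gradient dicts, here is a hedged sketch of a norm over MLX-style gradient trees, assuming the mlx package is available (mlx_grad_norm is a hypothetical helper, not part of the SDK):

    import mlx.core as mx
    from mlx.utils import tree_flatten

    def mlx_grad_norm(grads) -> float:
        # Flatten the (possibly nested) gradient dict to [(dotted_name, array), ...]
        leaves = [leaf for _, leaf in tree_flatten(grads)]
        total = sum(mx.sum(g.astype(mx.float32) ** 2).item() for g in leaves)
        return total ** 0.5
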
+     # ========================================================================
+     # Checkpoint Management
+     # ========================================================================
+
+     def save_state(self, path: str, metadata: Optional[Dict[str, Any]] = None) -> APIFuture[CheckpointResult]:
+         """Save a training checkpoint.
+
+         Args:
+             path: Path to save the checkpoint to
+             metadata: Optional metadata to save with the checkpoint
+
+         Returns:
+             APIFuture resolving to CheckpointResult
+
+         Example:
+             >>> client.save_state("checkpoints/step_1000.pt").result()
+             >>> # With metadata
+             >>> client.save_state("checkpoint.pt", {"epoch": 5, "score": 0.95}).result()
+         """
+         def _run_save() -> CheckpointResult:
+             try:
+                 from pathlib import Path
+
+                 save_path = Path(path)
+                 save_path.parent.mkdir(parents=True, exist_ok=True)
+
+                 # Save adapter weights together with the training state
+                 full_metadata = {
+                     "step": self._step,
+                     "training_state": self._training_state,
+                     **(metadata or {}),
+                 }
+
+                 # Use the backend's save_adapter
+                 self.backend.save_adapter(str(save_path), metadata=full_metadata)
+
+                 return CheckpointResult(
+                     path=str(save_path),
+                     success=True,
+                     message=f"Checkpoint saved to {save_path}",
+                 )
+             except Exception as e:
+                 return CheckpointResult(
+                     path=path,
+                     success=False,
+                     message=f"Failed to save checkpoint: {e}",
+                 )
+
+         return self.pool.submit(_run_save)
+
+     def load_state(self, path: str) -> APIFuture[CheckpointResult]:
+         """Load a training checkpoint.
+
+         Args:
+             path: Path of the checkpoint to load
+
+         Returns:
+             APIFuture resolving to CheckpointResult
+
+         Example:
+             >>> result = client.load_state("checkpoints/step_1000.pt").result()
+             >>> if result.success:
+             ...     print(f"Loaded from {result.path}")
+         """
+         def _run_load() -> CheckpointResult:
+             try:
+                 from pathlib import Path
+                 import json
+
+                 load_path = Path(path)
+                 if not load_path.exists():
+                     return CheckpointResult(
+                         path=path,
+                         success=False,
+                         message=f"Checkpoint not found: {path}",
+                     )
+
+                 # Load adapter weights
+                 self.backend.apply_adapter(str(load_path))
+
+                 # Try to load metadata (present when the checkpoint path is a directory)
+                 metadata_path = load_path / "adapter_metadata.json"
+                 if metadata_path.exists():
+                     with open(metadata_path) as f:
+                         metadata = json.load(f)
+                     self._step = metadata.get("step", self._step)
+                     self._training_state = metadata.get("training_state", {})
+
+                 return CheckpointResult(
+                     path=str(load_path),
+                     success=True,
+                     message=f"Checkpoint loaded from {load_path}",
+                 )
+             except Exception as e:
+                 return CheckpointResult(
+                     path=path,
+                     success=False,
+                     message=f"Failed to load checkpoint: {e}",
+                 )
+
+         return self.pool.submit(_run_load)
+
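A save/load round trip as a sketch; the checkpoint path and metadata are illustrative, and whether step/training_state are restored depends on the backend writing adapter_metadata.json next to the weights:

    save = client.save_state("checkpoints/step_100", {"epoch": 1}).result()
    if save.success:
        load = client.load_state(save.path).result()
        print(load.message, "| step =", client.step)
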
+     # ========================================================================
+     # Weight Management
+     # ========================================================================
+
+     def get_weights(self) -> APIFuture[WeightsResult]:
+         """Get the current model weights.
+
+         Returns:
+             APIFuture resolving to WeightsResult with a weights dictionary
+
+         Example:
+             >>> weights_future = client.get_weights()
+             >>> result = weights_future.result()
+             >>> print(f"Loaded {len(result.weights)} weight tensors")
+         """
+         def _run_get_weights() -> WeightsResult:
+             try:
+                 weights = {}
+
+                 # Try to get the model parameters
+                 if hasattr(self.backend, 'model'):
+                     model = self.backend.model
+                     if hasattr(model, 'parameters'):
+                         params = model.parameters()
+                         if isinstance(params, dict):
+                             weights = params
+                         else:
+                             weights = {"params": params}
+                     elif hasattr(model, 'trainable_parameters'):
+                         weights = model.trainable_parameters()
+
+                 return WeightsResult(
+                     weights=weights,
+                     success=True,
+                     message=f"Retrieved {len(weights)} weight tensors",
+                     num_tensors=len(weights),
+                 )
+             except Exception as e:
+                 return WeightsResult(
+                     weights={},
+                     success=False,
+                     message=f"Failed to get weights: {e}",
+                     num_tensors=0,
+                 )
+
+         return self.pool.submit(_run_get_weights)
+
+     def set_weights(self, weights: Dict[str, Any]) -> APIFuture[WeightsResult]:
+         """Set the model weights.
+
+         Args:
+             weights: Dictionary of weight tensors
+
+         Returns:
+             APIFuture resolving to WeightsResult
+
+         Example:
+             >>> client.set_weights(new_weights).result()
+         """
+         def _run_set_weights() -> WeightsResult:
+             try:
+                 # This is backend-specific; for MLX we need to update the arrays
+                 if hasattr(self.backend, 'model'):
+                     model = self.backend.model
+                     if hasattr(model, 'update'):
+                         model.update(weights)
+                     elif hasattr(model, 'load_weights'):
+                         model.load_weights(weights)
+
+                 return WeightsResult(
+                     weights=weights,
+                     success=True,
+                     message=f"Set {len(weights)} weight tensors",
+                     num_tensors=len(weights),
+                 )
+             except Exception as e:
+                 return WeightsResult(
+                     weights={},
+                     success=False,
+                     message=f"Failed to set weights: {e}",
+                     num_tensors=0,
+                 )
+
+         return self.pool.submit(_run_set_weights)
+
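Note that num_tensors counts top-level keys, while an MLX module's parameters() is typically a nested dict. A sketch of a leaf-level count, assuming mlx is installed:

    from mlx.utils import tree_flatten

    result = client.get_weights().result()
    if result.success:
        flat = tree_flatten(result.weights)  # [(dotted_name, array), ...]
        print(f"{len(flat)} leaf tensors in {result.num_tensors} top-level entries")
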
+     # ========================================================================
+     # Utility Methods
+     # ========================================================================
+
+     def create_optimizer(self, lr: float = 1e-4, weight_decay: float = 0.0) -> APIFuture[Any]:
+         """Create an optimizer for training.
+
+         Args:
+             lr: Learning rate
+             weight_decay: Weight decay coefficient
+
+         Returns:
+             APIFuture resolving to the optimizer instance
+
+         Example:
+             >>> opt_future = client.create_optimizer(lr=1e-4, weight_decay=0.01)
+             >>> client.optimizer = opt_future.result()
+         """
+         def _run_create_optimizer() -> Any:
+             from . import create_optimizer as _create_optimizer
+
+             self.optimizer, _ = _create_optimizer(
+                 self.backend,
+                 lr=lr,
+                 weight_decay=weight_decay,
+             )
+             return self.optimizer
+
+         return self.pool.submit(_run_create_optimizer)
+
+     def zero_grad(self) -> APIFuture[None]:
+         """Zero out gradients, if the optimizer supports it.
+
+         Returns:
+             APIFuture that completes when gradients are zeroed
+         """
+         def _run_zero_grad() -> None:
+             if self.optimizer is not None and hasattr(self.optimizer, 'zero_grad'):
+                 self.optimizer.zero_grad()
+
+         return self.pool.submit(_run_zero_grad)
+
+     # ========================================================================
+     # Properties
+     # ========================================================================
+
+     @property
+     def step(self) -> int:
+         """Current training step."""
+         return self._step
+
+     @property
+     def training_state(self) -> Dict[str, Any]:
+         """Get a copy of the training state dictionary."""
+         return self._training_state.copy()
+
+     def update_training_state(self, updates: Dict[str, Any]) -> None:
+         """Update the training state."""
+         self._training_state.update(updates)
+
+     # ========================================================================
+     # Internal Helpers
+     # ========================================================================
+
+     def _aggregate_gradients(self, grads_list: List[Any]) -> Any:
+         """Aggregate gradients from multiple samples.
+
+         Args:
+             grads_list: List of gradient dictionaries/arrays
+
+         Returns:
+             Aggregated gradients
+         """
+         if not grads_list:
+             return None
+         if len(grads_list) == 1:
+             return grads_list[0]
+
+         # Averaging is backend-specific; for now return the first gradient
+         # set. In practice, MLX would average the arrays (see the sketch
+         # after this class).
+         return grads_list[0]
+
+     def shutdown(self) -> None:
+         """Shut down the client and its thread pool."""
+         self.pool.shutdown(wait=True)
+
+
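As the comment in _aggregate_gradients notes, true averaging is backend-specific. A hedged sketch of what it could look like for MLX gradient trees, assuming each entry in grads_list is a nested dict of mx.array leaves with identical structure (this is not the shipped implementation):

    from mlx.utils import tree_map

    def average_grad_trees(grads_list):
        # Element-wise sum across the trees, then divide by the count.
        n = len(grads_list)
        summed = grads_list[0]
        for g in grads_list[1:]:
            summed = tree_map(lambda a, b: a + b, summed, g)
        return tree_map(lambda a: a / n, summed)
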
+ class DistillationTrainingClient(TrainingClient):
+     """Extended TrainingClient for knowledge distillation.
+
+     Adds support for teacher-model logprobs during training,
+     enabling distillation from a larger teacher model.
+
+     Example:
+         >>> student = DistillationTrainingClient(student_backend, pool)
+         >>> teacher = SamplingClient(teacher_backend_endpoint)
+         >>>
+         >>> # Get teacher logprobs for student samples
+         >>> samples = student.sample_batch(prompts)
+         >>> teacher_logprobs = teacher.get_logprobs_for_texts(prompts, samples.texts)
+         >>>
+         >>> # Distillation loss
+         >>> batch = TrainingBatch(
+         ...     prompts=prompts,
+         ...     responses=samples.texts,
+         ...     loss_type="distillation",
+         ...     extra={"teacher_logprobs": teacher_logprobs}
+         ... )
+         >>> student.forward_backward(batch)
+     """
+
+     def __init__(
+         self,
+         backend: Any,
+         pool: Optional[SdkFuturePool] = None,
+         optimizer: Optional[Any] = None,
+         step: int = 0,
+         teacher_sampling_client: Optional['SamplingClient'] = None,  # type: ignore
+     ):
+         """Initialize DistillationTrainingClient.
+
+         Args:
+             backend: The LLM backend instance
+             pool: Optional SdkFuturePool
+             optimizer: Optional pre-created optimizer
+             step: Initial training step
+             teacher_sampling_client: Optional SamplingClient for the teacher model
+         """
+         super().__init__(backend, pool, optimizer, step)
+         self.teacher_client = teacher_sampling_client
+
+     def compute_distillation_loss(
+         self,
+         student_batch: TrainingBatch,
+         teacher_logprobs: List[List[Dict[str, float]]],
+         temperature: float = 2.0,
+     ) -> APIFuture[ForwardBackwardResult]:
+         """Compute the distillation loss with teacher logprobs.
+
+         Args:
+             student_batch: Training batch for the student
+             teacher_logprobs: Top-k logprobs from the teacher for each token
+             temperature: Distillation temperature
+
+         Returns:
+             APIFuture resolving to ForwardBackwardResult
+         """
+         # Store the teacher signals in the batch for use during loss
+         # computation, then reuse the standard forward/backward path (in
+         # practice the loss function would consume teacher_logprobs).
+         # Returning that future directly avoids submitting to the pool from
+         # inside a pool task and blocking on it, which could deadlock a
+         # single-worker pool.
+         student_batch.extra["teacher_logprobs"] = teacher_logprobs
+         student_batch.extra["distillation_temperature"] = temperature
+         return self.forward_backward(student_batch)
+
+
+ # Import at the end to avoid a circular dependency
+ from .sampling_client import SamplingClient