mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/sampling_client.py
@@ -0,0 +1,729 @@
+ """SamplingClient SDK for MLXSmith.
+
+ Client for text sampling with logprobs support, including top-k logprobs
+ per token. Designed for distillation workflows where teacher model logprobs
+ are needed for student training.
+
+ Example:
+     >>> from mlxsmith.sdk import SamplingClient
+     >>>
+     >>> # Local backend
+     >>> client = SamplingClient(backend=loaded_model.backend)
+     >>>
+     >>> # Single sample with logprobs
+     >>> result = client.sample("What is 2+2?", logprobs_k=5)
+     >>> print(result.text)
+     >>> for token_lp in result.top_k_logprobs:
+     ...     print(token_lp)  # {"token": logprob, ...}
+     >>>
+     >>> # Batch sampling
+     >>> results = client.sample_batch(
+     ...     ["Q1", "Q2", "Q3"],
+     ...     max_tokens=100,
+     ...     logprobs_k=5
+     ... )
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Sequence, Union, Callable
+ import concurrent.futures
+
+ from .future import APIFuture, SdkFuturePool
+
+
+ @dataclass
+ class SampleResult:
+     """Result from a sampling operation.
+
+     Attributes:
+         text: Generated text (completion only, prompt excluded)
+         token_ids: Full sequence of token IDs (prompt + completion)
+         prompt_len: Length of prompt in tokens
+         logprobs: Log probability of each generated token
+         top_k_logprobs: Top-k logprobs per token (list of {token: logprob} dicts)
+         prompt_logprobs: Per-token logprobs for the prompt (if requested)
+         prompt_top_k_logprobs: Top-k logprobs per prompt token (if requested)
+         finish_reason: Why generation stopped ("stop", "length", etc.)
+         metrics: Additional metrics (perplexity, etc.)
+     """
+     text: str
+     token_ids: List[int]
+     prompt_len: int
+     logprobs: List[float] = field(default_factory=list)
+     top_k_logprobs: Optional[List[Dict[str, float]]] = None
+     prompt_logprobs: Optional[List[float]] = None
+     prompt_top_k_logprobs: Optional[List[Dict[str, float]]] = None
+     finish_reason: str = "stop"
+     metrics: Dict[str, float] = field(default_factory=dict)
+
+     @property
+     def completion_token_ids(self) -> List[int]:
+         """Get token IDs for the completion only."""
+         return self.token_ids[self.prompt_len:]
+
+     @property
+     def avg_logprob(self) -> float:
+         """Average log probability of generated tokens."""
+         if not self.logprobs:
+             return 0.0
+         return sum(self.logprobs) / len(self.logprobs)
+
+     @property
+     def perplexity(self) -> float:
+         """Perplexity of the completion."""
+         import math
+         avg_lp = self.avg_logprob
+         return math.exp(-avg_lp) if avg_lp != 0 else 1.0
+
+
+ @dataclass
+ class SampleBatchResult:
+     """Result from a batch sampling operation."""
+     results: List[SampleResult]
+     total_tokens: int = 0
+
+     def __len__(self) -> int:
+         return len(self.results)
+
+     def __getitem__(self, idx: int) -> SampleResult:
+         return self.results[idx]
+
+     @property
+     def texts(self) -> List[str]:
+         """Get all completion texts."""
+         return [r.text for r in self.results]
+
+     @property
+     def all_token_ids(self) -> List[List[int]]:
+         """Get all token ID sequences."""
+         return [r.token_ids for r in self.results]
+
+     @property
+     def all_top_k_logprobs(self) -> List[List[Dict[str, float]]]:
+         """Get top-k logprobs for each result that has them (results without top-k logprobs are skipped)."""
+         return [r.top_k_logprobs for r in self.results if r.top_k_logprobs is not None]
+
+
+ class SamplingClient:
+     """Client for sampling with logprobs support.
+
+     Can be used with a local backend or an API endpoint for distributed
+     sampling (e.g., teacher model running on a different machine).
+
+     Example:
+         >>> # Local sampling
+         >>> client = SamplingClient(backend=backend)
+         >>> result = client.sample("Hello", max_tokens=10, logprobs_k=5)
+         >>>
+         >>> # API-based sampling (for remote teacher model)
+         >>> api_client = SamplingClient(
+         ...     api_endpoint="http://teacher:8000",
+         ...     api_key="secret"
+         ... )
+         >>> result = api_client.sample("Hello", max_tokens=10, logprobs_k=5)
+     """
+
+     def __init__(
+         self,
+         backend: Any = None,
+         pool: Optional[SdkFuturePool] = None,
+         api_endpoint: Optional[str] = None,
+         api_key: Optional[str] = None,
+         timeout: float = 60.0,
+     ):
+         """Initialize SamplingClient.
+
+         Args:
+             backend: Local LLM backend instance
+             pool: Optional SdkFuturePool for async execution
+             api_endpoint: Optional API endpoint for remote sampling
+             api_key: API key for authentication
+             timeout: Default timeout for operations
+
+         Raises:
+             ValueError: If neither backend nor api_endpoint is provided
+         """
+         if backend is None and api_endpoint is None:
+             raise ValueError("Either backend or api_endpoint must be provided")
+
+         self.backend = backend
+         self.pool = pool or SdkFuturePool(max_workers=4)
+         self.api_endpoint = api_endpoint
+         self.api_key = api_key
+         self.timeout = timeout
+         self._session = None
+
+     def sample(
+         self,
+         prompt: str,
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: Optional[int] = None,
+         seed: Optional[int] = None,
+         stop: Optional[Sequence[str]] = None,
+         logprobs_k: int = 0,
+         include_prompt_logprobs: bool = False,
+         prompt_logprobs_k: int = 0,
+     ) -> SampleResult:
+         """Sample a single completion.
+
+         Args:
+             prompt: Input prompt text
+             max_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             top_k: Top-k sampling parameter
+             seed: Random seed for reproducibility
+             stop: Stop sequences
+             logprobs_k: Number of top logprobs to return per token (0 = none)
+             include_prompt_logprobs: Whether to also return per-token logprobs for the prompt
+             prompt_logprobs_k: Number of top logprobs to return per prompt token (0 = none)
+
+         Returns:
+             SampleResult with text, token IDs, and logprobs
+
+         Example:
+             >>> result = client.sample(
+             ...     "What is the capital of France?",
+             ...     max_tokens=50,
+             ...     temperature=0.7,
+             ...     logprobs_k=5
+             ... )
+             >>> print(result.text)
+             >>> print(f"Perplexity: {result.perplexity}")
+         """
+         if self.api_endpoint:
+             return self._sample_api(
+                 prompt,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 seed,
+                 stop,
+                 logprobs_k,
+                 include_prompt_logprobs,
+                 prompt_logprobs_k,
+             )
+         else:
+             return self._sample_local(
+                 prompt,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 seed,
+                 stop,
+                 logprobs_k,
+                 include_prompt_logprobs,
+                 prompt_logprobs_k,
+             )
+
+     def sample_async(
+         self,
+         prompt: str,
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: Optional[int] = None,
+         seed: Optional[int] = None,
+         stop: Optional[Sequence[str]] = None,
+         logprobs_k: int = 0,
+         include_prompt_logprobs: bool = False,
+         prompt_logprobs_k: int = 0,
+     ) -> APIFuture[SampleResult]:
+         """Async version of sample().
+
+         Returns:
+             APIFuture that resolves to SampleResult
+
+         Example:
+             >>> future = client.sample_async("Hello", max_tokens=10)
+             >>> future.then(lambda r: print(r.text))
+             >>> result = future.result()
+         """
+         def _do_sample():
+             return self.sample(
+                 prompt,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 seed,
+                 stop,
+                 logprobs_k,
+                 include_prompt_logprobs,
+                 prompt_logprobs_k,
+             )
+
+         return self.pool.submit(_do_sample)
+
+     def sample_batch(
+         self,
+         prompts: Sequence[str],
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: Optional[int] = None,
+         seed: Optional[int] = None,
+         stop: Optional[Sequence[str]] = None,
+         logprobs_k: int = 0,
+         include_prompt_logprobs: bool = False,
+         prompt_logprobs_k: int = 0,
+         max_workers: Optional[int] = None,
+     ) -> SampleBatchResult:
+         """Sample completions for multiple prompts.
+
+         Args:
+             prompts: List of prompt strings
+             max_tokens: Maximum tokens per completion
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             top_k: Top-k sampling parameter
+             seed: Random seed (different for each prompt if provided)
+             stop: Stop sequences
+             logprobs_k: Number of top logprobs per token
+             include_prompt_logprobs: Whether to also return per-token logprobs for each prompt
+             prompt_logprobs_k: Number of top logprobs per prompt token (0 = none)
+             max_workers: Number of parallel workers (currently unused; concurrency follows the client's pool)
+
+         Returns:
+             SampleBatchResult with all results
+
+         Example:
+             >>> prompts = ["Q: 2+2=", "Q: 3*4=", "Q: 10/2="]
+             >>> results = client.sample_batch(prompts, max_tokens=10, logprobs_k=5)
+             >>> for prompt, result in zip(prompts, results):
+             ...     print(f"{prompt} {result.text}")
+         """
+         # Use different seeds for each prompt if seed provided
+         seeds = [seed + i if seed is not None else None for i in range(len(prompts))]
+
+         # Submit all samples to thread pool
+         futures = [
+             self.sample_async(
+                 prompt=p,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 seed=s,
+                 stop=stop,
+                 logprobs_k=logprobs_k,
+                 include_prompt_logprobs=include_prompt_logprobs,
+                 prompt_logprobs_k=prompt_logprobs_k,
+             )
+             for p, s in zip(prompts, seeds)
+         ]
+
+         # Collect results
+         results = [f.result(timeout=self.timeout) for f in futures]
+         total_tokens = sum(len(r.token_ids) for r in results)
+
+         return SampleBatchResult(results=results, total_tokens=total_tokens)
+
+     def sample_batch_async(
+         self,
+         prompts: Sequence[str],
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: Optional[int] = None,
+         seed: Optional[int] = None,
+         stop: Optional[Sequence[str]] = None,
+         logprobs_k: int = 0,
+         include_prompt_logprobs: bool = False,
+         prompt_logprobs_k: int = 0,
+     ) -> APIFuture[SampleBatchResult]:
+         """Async version of sample_batch().
+
+         Returns:
+             APIFuture that resolves to SampleBatchResult
+         """
+         def _do_batch():
+             return self.sample_batch(
+                 prompts,
+                 max_tokens,
+                 temperature,
+                 top_p,
+                 top_k,
+                 seed,
+                 stop,
+                 logprobs_k,
+                 include_prompt_logprobs,
+                 prompt_logprobs_k,
+             )
+
+         return self.pool.submit(_do_batch)
+
+     def get_logprobs_for_texts(
+         self,
+         prompts: Sequence[str],
+         completions: Sequence[str],
+         top_k: int = 0,
+     ) -> List[List[Dict[str, float]]]:
+         """Get logprobs for given prompt-completion pairs.
+
+         This is useful for distillation where you have student-generated
+         texts and want teacher logprobs for them.
+
+         Args:
+             prompts: List of prompts
+             completions: List of completions (parallel to prompts)
+             top_k: Number of top logprobs per token
+
+         Returns:
+             List of top-k logprobs for each completion
+
+         Example:
+             >>> prompts = ["Q: What is 2+2?"]
+             >>> completions = ["The answer is 4."]
+             >>> logprobs = client.get_logprobs_for_texts(prompts, completions, top_k=5)
+             >>> # logprobs[0] is list of {token: logprob} for each token
+         """
+         results = []
+         for prompt, completion in zip(prompts, completions):
+             # Encode the prompt and the full prompt+completion sequence
+             prompt_ids = self.backend.encode(prompt)
+             full_ids = self.backend.encode(prompt + completion)
+
+             if hasattr(self.backend, "token_logprobs"):
+                 # Score the provided completion directly (teacher forcing); this assumes
+                 # include_prompt=False restricts the returned positions to the completion.
+                 _, top_k_logprobs = self.backend.token_logprobs(
+                     full_ids,
+                     prompt_len=len(prompt_ids),
+                     top_k=top_k,
+                     include_prompt=False,
+                 )
+                 results.append(top_k_logprobs or [])
+             elif hasattr(self.backend, 'generate_with_logprobs'):
+                 # Fallback: sample fresh tokens of matching length from the backend.
+                 gen = self.backend.generate_with_logprobs(
+                     prompt,
+                     max_new_tokens=len(full_ids) - len(prompt_ids),
+                     logprobs=top_k,
+                 )
+                 results.append(gen.top_k_logprobs or [])
+             else:
+                 results.append([])
+
+         return results
+
+     def get_prompt_token_logprobs(
+         self,
+         prompts: Sequence[str],
+     ) -> List[List[float]]:
+         """Get per-token logprobs for prompt tokens.
+
+         Returns a list of logprob lists (one per prompt). The first token is
+         omitted because it has no previous context.
+         """
+         results: List[List[float]] = []
+         for prompt in prompts:
+             if not hasattr(self.backend, "token_logprobs"):
+                 results.append([])
+                 continue
+             prompt_ids = self.backend.encode(prompt)
+             logprobs, _ = self.backend.token_logprobs(
+                 prompt_ids,
+                 prompt_len=len(prompt_ids),
+                 top_k=0,
+                 include_prompt=True,
+             )
+             results.append(logprobs)
+         return results
+
+     def get_prompt_top_k_logprobs(
+         self,
+         prompts: Sequence[str],
+         top_k: int = 5,
+     ) -> List[List[Dict[str, float]]]:
+         """Get top-k logprobs for each prompt token.
+
+         Returns a list of per-token top-k dicts for each prompt.
+         """
+         results: List[List[Dict[str, float]]] = []
+         for prompt in prompts:
+             if not hasattr(self.backend, "token_logprobs"):
+                 results.append([])
+                 continue
+             prompt_ids = self.backend.encode(prompt)
+             _, top_k_logprobs = self.backend.token_logprobs(
+                 prompt_ids,
+                 prompt_len=len(prompt_ids),
+                 top_k=top_k,
+                 include_prompt=True,
+             )
+             results.append(top_k_logprobs or [])
+         return results
+
+     def compute_sequence_logprobs(
+         self,
+         prompts: Sequence[str],
+         completions: Sequence[str],
+     ) -> List[float]:
+         """Compute total logprob for each prompt-completion pair.
+
+         Args:
+             prompts: List of prompts
+             completions: List of completions
+
+         Returns:
+             List of total logprobs
+         """
+         results = []
+         for prompt, completion in zip(prompts, completions):
+             prompt_ids = self.backend.encode(prompt)
+             full_ids = self.backend.encode(prompt + completion)
+
+             if hasattr(self.backend, 'sequence_logprob'):
+                 logprob = self.backend.sequence_logprob(full_ids, prompt_len=len(prompt_ids))
+                 results.append(float(logprob))
+             else:
+                 results.append(0.0)
+
+         return results
+
+     # ========================================================================
+     # Internal methods
+     # ========================================================================
+
+     def _sample_local(
+         self,
+         prompt: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         top_k: Optional[int],
+         seed: Optional[int],
+         stop: Optional[Sequence[str]],
+         logprobs_k: int,
+         include_prompt_logprobs: bool,
+         prompt_logprobs_k: int,
+     ) -> SampleResult:
+         """Sample using local backend."""
+         if logprobs_k > 0:
+             gen = self.backend.generate_with_logprobs(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k_sampling=top_k,
+                 seed=seed,
+                 logprobs=logprobs_k,
+             )
+         else:
+             gen = self.backend.generate(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 seed=seed,
+             )
+
+         # Extract completion text (remove prompt prefix if present)
+         completion_text = gen.text
+         if completion_text.startswith(prompt):
+             completion_text = completion_text[len(prompt):]
+
+         # Apply stop sequences (truncate at the first match)
+         finish_reason = "stop"
+         truncated_at_stop = False
+         if stop:
+             for stop_seq in stop:
+                 if stop_seq in completion_text:
+                     completion_text = completion_text[:completion_text.index(stop_seq)]
+                     truncated_at_stop = True
+                     break
+
+         # Report "length" only if generation hit max_tokens without a stop match
+         if not truncated_at_stop and len(gen.token_ids) - gen.prompt_len >= max_tokens:
+             finish_reason = "length"
+
+         prompt_logprobs = None
+         prompt_top_k_logprobs = None
+         if (include_prompt_logprobs or prompt_logprobs_k > 0) and hasattr(self.backend, "token_logprobs"):
+             prompt_ids = self.backend.encode(prompt)
+             try:
+                 plogps, ptopk = self.backend.token_logprobs(
+                     prompt_ids,
+                     prompt_len=len(prompt_ids),
+                     top_k=prompt_logprobs_k if prompt_logprobs_k > 0 else 0,
+                     include_prompt=True,
+                 )
+                 if include_prompt_logprobs:
+                     prompt_logprobs = plogps
+                 if prompt_logprobs_k > 0:
+                     prompt_top_k_logprobs = ptopk or []
+             except Exception:
+                 prompt_logprobs = None
+                 prompt_top_k_logprobs = None
+
+         return SampleResult(
+             text=completion_text,
+             token_ids=gen.token_ids,
+             prompt_len=gen.prompt_len,
+             logprobs=gen.logprobs or [],
+             top_k_logprobs=gen.top_k_logprobs,
+             prompt_logprobs=prompt_logprobs,
+             prompt_top_k_logprobs=prompt_top_k_logprobs,
+             finish_reason=finish_reason,
+         )
+
+     def _sample_api(
+         self,
+         prompt: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         top_k: Optional[int],
+         seed: Optional[int],
+         stop: Optional[Sequence[str]],
+         logprobs_k: int,
+         include_prompt_logprobs: bool,
+         prompt_logprobs_k: int,
+     ) -> SampleResult:
+         """Sample using remote API endpoint."""
+         import urllib.request
+         import urllib.error
+         import json
+
+         url = f"{self.api_endpoint}/internal/rollout"
+
+         payload = {
+             "prompt": prompt,
+             "max_tokens": max_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "seed": seed,
+             "include_tokens": True,
+             "include_logprobs": True,
+             "include_top_k_logprobs": logprobs_k if logprobs_k > 0 else None,
+             "include_prompt_logprobs": bool(include_prompt_logprobs),
+             "include_prompt_top_k_logprobs": prompt_logprobs_k if prompt_logprobs_k > 0 else None,
+             "include_text": True,
+         }
+
+         headers = {"Content-Type": "application/json"}
+         if self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+
+         req = urllib.request.Request(
+             url,
+             data=json.dumps(payload).encode(),
+             headers=headers,
+             method="POST",
+         )
+
+         try:
+             with urllib.request.urlopen(req, timeout=self.timeout) as response:
+                 data = json.loads(response.read().decode())
+
+             completion_text = data.get("completion", "")
+             # Stop sequences are not forwarded to the endpoint, so apply them client-side
+             if stop:
+                 for stop_seq in stop:
+                     if stop_seq in completion_text:
+                         completion_text = completion_text[:completion_text.index(stop_seq)]
+                         break
+
+             return SampleResult(
+                 text=completion_text,
+                 token_ids=data.get("token_ids", []),
+                 prompt_len=data.get("prompt_len", 0),
+                 logprobs=data.get("logprobs", []),
+                 top_k_logprobs=data.get("top_k_logprobs"),
+                 prompt_logprobs=data.get("prompt_logprobs"),
+                 prompt_top_k_logprobs=data.get("prompt_top_k_logprobs"),
+                 finish_reason="stop",
+             )
+         except urllib.error.HTTPError as e:
+             raise RuntimeError(f"API error: {e.code} - {e.read().decode()}") from e
+         except Exception as e:
+             raise RuntimeError(f"Failed to sample from API: {e}") from e
+
+     def shutdown(self) -> None:
+         """Shutdown the client and its thread pool."""
+         self.pool.shutdown(wait=True)
+
+
+ class DistillationSampler:
+     """Helper for knowledge distillation workflows.
+
+     Combines a student training client with a teacher sampling client
+     to provide teacher logprobs for student-generated samples.
+
+     Example:
+         >>> sampler = DistillationSampler(
+         ...     teacher_client=teacher_sampling_client,
+         ...     student_client=student_training_client,
+         ... )
+         >>>
+         >>> # Generate samples from student and get teacher logprobs
+         >>> prompts = ["Q: What is 2+2?"]
+         >>> samples, teacher_logprobs = sampler.sample_with_teacher_logprobs(
+         ...     prompts,
+         ...     max_tokens=50,
+         ...     logprobs_k=5
+         ... )
+     """
+
+     def __init__(
+         self,
+         teacher_client: SamplingClient,
+         student_client: Optional[SamplingClient] = None,
+     ):
+         """Initialize DistillationSampler.
+
+         Args:
+             teacher_client: SamplingClient for the teacher model
+             student_client: Optional SamplingClient for the student model
+                 (if None, uses teacher for sampling too)
+         """
+         self.teacher = teacher_client
+         self.student = student_client or teacher_client
+
+     def sample_with_teacher_logprobs(
+         self,
+         prompts: Sequence[str],
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         logprobs_k: int = 5,
+     ) -> tuple[SampleBatchResult, List[List[Dict[str, float]]]]:
+         """Sample from student and get teacher logprobs for those samples.
+
+         This is the key operation for distillation: the student generates
+         samples, then the teacher evaluates the log probability of each
+         token in those samples.
+
+         Args:
+             prompts: List of prompts
+             max_tokens: Maximum tokens per sample
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             logprobs_k: Number of top logprobs to get from teacher
+
+         Returns:
+             Tuple of (student samples, teacher top-k logprobs)
+
+         Example:
+             >>> prompts = ["Question: What is Python?"]
+             >>> samples, teacher_lps = sampler.sample_with_teacher_logprobs(
+             ...     prompts, max_tokens=100, logprobs_k=5
+             ... )
+             >>> # Use samples and teacher_lps for distillation training
+         """
+         # Sample from student
+         student_samples = self.student.sample_batch(
+             prompts,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             logprobs_k=0,  # Don't need student logprobs
+         )
+
+         # Get teacher logprobs for the student-generated completions
+         completions = student_samples.texts
+         teacher_logprobs = self.teacher.get_logprobs_for_texts(
+             prompts, completions, top_k=logprobs_k
+         )
+
+         return student_samples, teacher_logprobs
+
+     def sample_with_teacher_logprobs_async(
+         self,
+         prompts: Sequence[str],
+         max_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         logprobs_k: int = 5,
+     ) -> APIFuture[tuple[SampleBatchResult, List[List[Dict[str, float]]]]]:
+         """Async version of sample_with_teacher_logprobs()."""
+         def _do_sample():
+             return self.sample_with_teacher_logprobs(
+                 prompts, max_tokens, temperature, top_p, logprobs_k
+             )
+
+         return self.student.pool.submit(_do_sample)
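
For orientation, here is a minimal end-to-end sketch of how the classes added above fit together for distillation. It is a sketch, not part of the package: it assumes teacher_backend and student_backend are already-loaded mlxsmith LLM backend instances, and imports DistillationSampler from this module directly in case it is not re-exported from mlxsmith.sdk.

    from mlxsmith.sdk import SamplingClient
    from mlxsmith.sdk.sampling_client import DistillationSampler

    # Assumed to exist: two loaded backend objects (teacher and student).
    teacher = SamplingClient(backend=teacher_backend)
    student = SamplingClient(backend=student_backend)
    sampler = DistillationSampler(teacher_client=teacher, student_client=student)

    prompts = ["Q: What is 2+2?", "Q: Name a prime number."]
    # Student generates completions; teacher returns top-k logprobs for them.
    samples, teacher_lps = sampler.sample_with_teacher_logprobs(
        prompts, max_tokens=64, logprobs_k=5
    )
    for text, lps in zip(samples.texts, teacher_lps):
        print(text, lps[:2])  # completion text plus its first two per-token top-k dicts

    teacher.shutdown()
    student.shutdown()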