mlxsmith-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0

mlxsmith/sdk/sampling_client.py
@@ -0,0 +1,729 @@
"""SamplingClient SDK for MLXSmith.

Client for text sampling with logprobs support, including top-k logprobs
per token. Designed for distillation workflows where teacher model logprobs
are needed for student training.

Example:
    >>> from mlxsmith.sdk import SamplingClient
    >>>
    >>> # Local backend
    >>> client = SamplingClient(backend=loaded_model.backend)
    >>>
    >>> # Single sample with logprobs
    >>> result = client.sample("What is 2+2?", logprobs_k=5)
    >>> print(result.text)
    >>> for token_lp in result.top_k_logprobs:
    ...     print(token_lp)  # {"token": logprob, ...}
    >>>
    >>> # Batch sampling
    >>> results = client.sample_batch(
    ...     ["Q1", "Q2", "Q3"],
    ...     max_tokens=100,
    ...     logprobs_k=5
    ... )
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union, Callable
import concurrent.futures

from .future import APIFuture, SdkFuturePool


@dataclass
class SampleResult:
    """Result from a sampling operation.

    Attributes:
        text: Generated text (completion only, prompt excluded)
        token_ids: Full sequence of token IDs (prompt + completion)
        prompt_len: Length of prompt in tokens
        logprobs: Log probability of each generated token
        top_k_logprobs: Top-k logprobs per token (list of {token: logprob} dicts)
        finish_reason: Why generation stopped ("stop", "length", etc.)
        metrics: Additional metrics (perplexity, etc.)
    """
    text: str
    token_ids: List[int]
    prompt_len: int
    logprobs: List[float] = field(default_factory=list)
    top_k_logprobs: Optional[List[Dict[str, float]]] = None
    prompt_logprobs: Optional[List[float]] = None
    prompt_top_k_logprobs: Optional[List[Dict[str, float]]] = None
    finish_reason: str = "stop"
    metrics: Dict[str, float] = field(default_factory=dict)

    @property
    def completion_token_ids(self) -> List[int]:
        """Get token IDs for the completion only."""
        return self.token_ids[self.prompt_len:]

    @property
    def avg_logprob(self) -> float:
        """Average log probability of generated tokens."""
        if not self.logprobs:
            return 0.0
        return sum(self.logprobs) / len(self.logprobs)

    @property
    def perplexity(self) -> float:
        """Perplexity of the completion."""
        import math
        avg_lp = self.avg_logprob
        return math.exp(-avg_lp) if avg_lp != 0 else 1.0


@dataclass
class SampleBatchResult:
    """Result from a batch sampling operation."""
    results: List[SampleResult]
    total_tokens: int = 0

    def __len__(self) -> int:
        return len(self.results)

    def __getitem__(self, idx: int) -> SampleResult:
        return self.results[idx]

    @property
    def texts(self) -> List[str]:
        """Get all completion texts."""
        return [r.text for r in self.results]

    @property
    def all_token_ids(self) -> List[List[int]]:
        """Get all token ID sequences."""
        return [r.token_ids for r in self.results]

    @property
    def all_top_k_logprobs(self) -> List[List[Dict[str, float]]]:
        """Get all top-k logprobs."""
        return [r.top_k_logprobs for r in self.results if r.top_k_logprobs is not None]


class SamplingClient:
    """Client for sampling with logprobs support.

    Can be used with a local backend or an API endpoint for distributed
    sampling (e.g., teacher model running on a different machine).

    Example:
        >>> # Local sampling
        >>> client = SamplingClient(backend=backend)
        >>> result = client.sample("Hello", max_tokens=10, logprobs_k=5)
        >>>
        >>> # API-based sampling (for remote teacher model)
        >>> api_client = SamplingClient(
        ...     api_endpoint="http://teacher:8000",
        ...     api_key="secret"
        ... )
        >>> result = api_client.sample("Hello", max_tokens=10, logprobs_k=5)
    """

    def __init__(
        self,
        backend: Any = None,
        pool: Optional[SdkFuturePool] = None,
        api_endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: float = 60.0,
    ):
        """Initialize SamplingClient.

        Args:
            backend: Local LLM backend instance
            pool: Optional SdkFuturePool for async execution
            api_endpoint: Optional API endpoint for remote sampling
            api_key: API key for authentication
            timeout: Default timeout for operations

        Raises:
            ValueError: If neither backend nor api_endpoint is provided
        """
        if backend is None and api_endpoint is None:
            raise ValueError("Either backend or api_endpoint must be provided")

        self.backend = backend
        self.pool = pool or SdkFuturePool(max_workers=4)
        self.api_endpoint = api_endpoint
        self.api_key = api_key
        self.timeout = timeout
        self._session = None

    def sample(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        logprobs_k: int = 0,
        include_prompt_logprobs: bool = False,
        prompt_logprobs_k: int = 0,
    ) -> SampleResult:
        """Sample a single completion.

        Args:
            prompt: Input prompt text
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            seed: Random seed for reproducibility
            stop: Stop sequences
            logprobs_k: Number of top logprobs to return per token (0 = none)

        Returns:
            SampleResult with text, token IDs, and logprobs

        Example:
            >>> result = client.sample(
            ...     "What is the capital of France?",
            ...     max_tokens=50,
            ...     temperature=0.7,
            ...     logprobs_k=5
            ... )
            >>> print(result.text)
            >>> print(f"Perplexity: {result.perplexity}")
        """
        if self.api_endpoint:
            return self._sample_api(
                prompt,
                max_tokens,
                temperature,
                top_p,
                top_k,
                seed,
                stop,
                logprobs_k,
                include_prompt_logprobs,
                prompt_logprobs_k,
            )
        else:
            return self._sample_local(
                prompt,
                max_tokens,
                temperature,
                top_p,
                top_k,
                seed,
                stop,
                logprobs_k,
                include_prompt_logprobs,
                prompt_logprobs_k,
            )

    def sample_async(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        logprobs_k: int = 0,
        include_prompt_logprobs: bool = False,
        prompt_logprobs_k: int = 0,
    ) -> APIFuture[SampleResult]:
        """Async version of sample().

        Returns:
            APIFuture that resolves to SampleResult

        Example:
            >>> future = client.sample_async("Hello", max_tokens=10)
            >>> future.then(lambda r: print(r.text))
            >>> result = future.result()
        """
        def _do_sample():
            return self.sample(
                prompt,
                max_tokens,
                temperature,
                top_p,
                top_k,
                seed,
                stop,
                logprobs_k,
                include_prompt_logprobs,
                prompt_logprobs_k,
            )

        return self.pool.submit(_do_sample)

    def sample_batch(
        self,
        prompts: Sequence[str],
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        logprobs_k: int = 0,
        include_prompt_logprobs: bool = False,
        prompt_logprobs_k: int = 0,
        max_workers: Optional[int] = None,
    ) -> SampleBatchResult:
        """Sample completions for multiple prompts.

        Args:
            prompts: List of prompt strings
            max_tokens: Maximum tokens per completion
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            seed: Random seed (different for each prompt if provided)
            stop: Stop sequences
            logprobs_k: Number of top logprobs per token
            max_workers: Number of parallel workers (None = use pool default)

        Returns:
            SampleBatchResult with all results

        Example:
            >>> prompts = ["Q: 2+2=", "Q: 3*4=", "Q: 10/2="]
            >>> results = client.sample_batch(prompts, max_tokens=10, logprobs_k=5)
            >>> for prompt, result in zip(prompts, results):
            ...     print(f"{prompt} {result.text}")
        """
        # Use different seeds for each prompt if seed provided
        seeds = [seed + i if seed is not None else None for i in range(len(prompts))]

        # Submit all samples to thread pool
        futures = [
            self.sample_async(
                prompt=p,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                seed=s,
                stop=stop,
                logprobs_k=logprobs_k,
                include_prompt_logprobs=include_prompt_logprobs,
                prompt_logprobs_k=prompt_logprobs_k,
            )
            for p, s in zip(prompts, seeds)
        ]

        # Collect results
        results = [f.result(timeout=self.timeout) for f in futures]
        total_tokens = sum(len(r.token_ids) for r in results)

        return SampleBatchResult(results=results, total_tokens=total_tokens)

    def sample_batch_async(
        self,
        prompts: Sequence[str],
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        stop: Optional[Sequence[str]] = None,
        logprobs_k: int = 0,
        include_prompt_logprobs: bool = False,
        prompt_logprobs_k: int = 0,
    ) -> APIFuture[SampleBatchResult]:
        """Async version of sample_batch().

        Returns:
            APIFuture that resolves to SampleBatchResult
        """
        def _do_batch():
            return self.sample_batch(
                prompts,
                max_tokens,
                temperature,
                top_p,
                top_k,
                seed,
                stop,
                logprobs_k,
                include_prompt_logprobs,
                prompt_logprobs_k,
            )

        return self.pool.submit(_do_batch)

    def get_logprobs_for_texts(
        self,
        prompts: Sequence[str],
        completions: Sequence[str],
        top_k: int = 0,
    ) -> List[List[Dict[str, float]]]:
        """Get logprobs for given prompt-completion pairs.

        This is useful for distillation where you have student-generated
        texts and want teacher logprobs for them.

        Args:
            prompts: List of prompts
            completions: List of completions (parallel to prompts)
            top_k: Number of top logprobs per token

        Returns:
            List of top-k logprobs for each completion

        Example:
            >>> prompts = ["Q: What is 2+2?"]
            >>> completions = ["The answer is 4."]
            >>> logprobs = client.get_logprobs_for_texts(prompts, completions, top_k=5)
            >>> # logprobs[0] is list of {token: logprob} for each token
        """
        results = []
        for prompt, completion in zip(prompts, completions):
            # Encode and get logprobs
            prompt_ids = self.backend.encode(prompt)
            full_ids = self.backend.encode(prompt + completion)

            # Try to get logprobs from the backend
            if hasattr(self.backend, 'generate_with_logprobs'):
                gen = self.backend.generate_with_logprobs(
                    prompt,
                    max_new_tokens=len(full_ids) - len(prompt_ids),
                    logprobs=top_k,
                )
                results.append(gen.top_k_logprobs or [])
            else:
                results.append([])

        return results

    def get_prompt_token_logprobs(
        self,
        prompts: Sequence[str],
    ) -> List[List[float]]:
        """Get per-token logprobs for prompt tokens.

        Returns a list of logprob lists (one per prompt). The first token is
        omitted because it has no previous context.
        """
        results: List[List[float]] = []
        for prompt in prompts:
            if not hasattr(self.backend, "token_logprobs"):
                results.append([])
                continue
            prompt_ids = self.backend.encode(prompt)
            logprobs, _ = self.backend.token_logprobs(
                prompt_ids,
                prompt_len=len(prompt_ids),
                top_k=0,
                include_prompt=True,
            )
            results.append(logprobs)
        return results

    def get_prompt_top_k_logprobs(
        self,
        prompts: Sequence[str],
        top_k: int = 5,
    ) -> List[List[Dict[str, float]]]:
        """Get top-k logprobs for each prompt token.

        Returns a list of per-token top-k dicts for each prompt.
        """
        results: List[List[Dict[str, float]]] = []
        for prompt in prompts:
            if not hasattr(self.backend, "token_logprobs"):
                results.append([])
                continue
            prompt_ids = self.backend.encode(prompt)
            _, top_k_logprobs = self.backend.token_logprobs(
                prompt_ids,
                prompt_len=len(prompt_ids),
                top_k=top_k,
                include_prompt=True,
            )
            results.append(top_k_logprobs or [])
        return results

    def compute_sequence_logprobs(
        self,
        prompts: Sequence[str],
        completions: Sequence[str],
    ) -> List[float]:
        """Compute total logprob for each prompt-completion pair.

        Args:
            prompts: List of prompts
            completions: List of completions

        Returns:
            List of total logprobs
        """
        results = []
        for prompt, completion in zip(prompts, completions):
            prompt_ids = self.backend.encode(prompt)
            full_ids = self.backend.encode(prompt + completion)

            if hasattr(self.backend, 'sequence_logprob'):
                logprob = self.backend.sequence_logprob(full_ids, prompt_len=len(prompt_ids))
                results.append(float(logprob))
            else:
                results.append(0.0)

        return results

    # ========================================================================
    # Internal methods
    # ========================================================================

    def _sample_local(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
        seed: Optional[int],
        stop: Optional[Sequence[str]],
        logprobs_k: int,
        include_prompt_logprobs: bool,
        prompt_logprobs_k: int,
    ) -> SampleResult:
        """Sample using local backend."""
        if logprobs_k > 0:
            gen = self.backend.generate_with_logprobs(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k_sampling=top_k,
                seed=seed,
                logprobs=logprobs_k,
            )
        else:
            gen = self.backend.generate(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                seed=seed,
            )

        # Extract completion text (remove prompt prefix if present)
        completion_text = gen.text
        if completion_text.startswith(prompt):
            completion_text = completion_text[len(prompt):]

        # Apply stop sequences
        finish_reason = "stop"
        if stop:
            for stop_seq in stop:
                if stop_seq in completion_text:
                    completion_text = completion_text[:completion_text.index(stop_seq)]
                    finish_reason = "stop"
                    break

        # Check if we hit length limit
        if len(gen.token_ids) - gen.prompt_len >= max_tokens:
            finish_reason = "length"

        prompt_logprobs = None
        prompt_top_k_logprobs = None
        if (include_prompt_logprobs or prompt_logprobs_k > 0) and hasattr(self.backend, "token_logprobs"):
            prompt_ids = self.backend.encode(prompt)
            try:
                plogps, ptopk = self.backend.token_logprobs(
                    prompt_ids,
                    prompt_len=len(prompt_ids),
                    top_k=prompt_logprobs_k if prompt_logprobs_k > 0 else 0,
                    include_prompt=True,
                )
                if include_prompt_logprobs:
                    prompt_logprobs = plogps
                if prompt_logprobs_k > 0:
                    prompt_top_k_logprobs = ptopk or []
            except Exception:
                prompt_logprobs = None
                prompt_top_k_logprobs = None

        return SampleResult(
            text=completion_text,
            token_ids=gen.token_ids,
            prompt_len=gen.prompt_len,
            logprobs=gen.logprobs or [],
            top_k_logprobs=gen.top_k_logprobs,
            prompt_logprobs=prompt_logprobs,
            prompt_top_k_logprobs=prompt_top_k_logprobs,
            finish_reason=finish_reason,
        )

    def _sample_api(
        self,
        prompt: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
        seed: Optional[int],
        stop: Optional[Sequence[str]],
        logprobs_k: int,
        include_prompt_logprobs: bool,
        prompt_logprobs_k: int,
    ) -> SampleResult:
        """Sample using remote API endpoint."""
        import urllib.request
        import urllib.error
        import json

        url = f"{self.api_endpoint}/internal/rollout"

        payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "seed": seed,
            "include_tokens": True,
            "include_logprobs": True,
            "include_top_k_logprobs": logprobs_k if logprobs_k > 0 else None,
            "include_prompt_logprobs": bool(include_prompt_logprobs),
            "include_prompt_top_k_logprobs": prompt_logprobs_k if prompt_logprobs_k > 0 else None,
            "include_text": True,
        }

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode(),
            headers=headers,
            method="POST",
        )

        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as response:
                data = json.loads(response.read().decode())

            return SampleResult(
                text=data.get("completion", ""),
                token_ids=data.get("token_ids", []),
                prompt_len=data.get("prompt_len", 0),
                logprobs=data.get("logprobs", []),
                top_k_logprobs=data.get("top_k_logprobs"),
                prompt_logprobs=data.get("prompt_logprobs"),
                prompt_top_k_logprobs=data.get("prompt_top_k_logprobs"),
                finish_reason="stop",
            )
        except urllib.error.HTTPError as e:
            raise RuntimeError(f"API error: {e.code} - {e.read().decode()}")
        except Exception as e:
            raise RuntimeError(f"Failed to sample from API: {e}")

    def shutdown(self) -> None:
        """Shutdown the client and its thread pool."""
        self.pool.shutdown(wait=True)


class DistillationSampler:
    """Helper for knowledge distillation workflows.

    Combines a student training client with a teacher sampling client
    to provide teacher logprobs for student-generated samples.

    Example:
        >>> sampler = DistillationSampler(
        ...     teacher_client=teacher_sampling_client,
        ...     student_client=student_training_client,
        ... )
        >>>
        >>> # Generate samples from student and get teacher logprobs
        >>> prompts = ["Q: What is 2+2?"]
        >>> samples, teacher_logprobs = sampler.sample_with_teacher_logprobs(
        ...     prompts,
        ...     max_tokens=50,
        ...     logprobs_k=5
        ... )
    """

    def __init__(
        self,
        teacher_client: SamplingClient,
        student_client: Optional['SamplingClient'] = None,  # type: ignore
    ):
        """Initialize DistillationSampler.

        Args:
            teacher_client: SamplingClient for the teacher model
            student_client: Optional SamplingClient for the student model
                (if None, uses teacher for sampling too)
        """
        self.teacher = teacher_client
        self.student = student_client or teacher_client

    def sample_with_teacher_logprobs(
        self,
        prompts: Sequence[str],
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        logprobs_k: int = 5,
    ) -> tuple[SampleBatchResult, List[List[Dict[str, float]]]]:
        """Sample from student and get teacher logprobs for those samples.

        This is the key operation for distillation: the student generates
        samples, then the teacher evaluates the log probability of each
        token in those samples.

        Args:
            prompts: List of prompts
            max_tokens: Maximum tokens per sample
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            logprobs_k: Number of top logprobs to get from teacher

        Returns:
            Tuple of (student samples, teacher top-k logprobs)

        Example:
            >>> prompts = ["Question: What is Python?"]
            >>> samples, teacher_lps = sampler.sample_with_teacher_logprobs(
            ...     prompts, max_tokens=100, logprobs_k=5
            ... )
            >>> # Use samples and teacher_lps for distillation training
        """
        # Sample from student
        student_samples = self.student.sample_batch(
            prompts,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs_k=0,  # Don't need student logprobs
        )

        # Get teacher logprobs for the student-generated completions
        completions = student_samples.texts
        teacher_logprobs = self.teacher.get_logprobs_for_texts(
            prompts, completions, top_k=logprobs_k
        )

        return student_samples, teacher_logprobs

    def sample_with_teacher_logprobs_async(
        self,
        prompts: Sequence[str],
        max_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 1.0,
        logprobs_k: int = 5,
    ) -> APIFuture[tuple[SampleBatchResult, List[List[Dict[str, float]]]]]:
        """Async version of sample_with_teacher_logprobs()."""
        def _do_sample():
            return self.sample_with_teacher_logprobs(
                prompts, max_tokens, temperature, top_p, logprobs_k
            )

        return self.student.pool.submit(_do_sample)
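
Editor's usage sketch (not part of the wheel): how the classes above are intended to fit together for distillation, with a remote teacher reached through api_endpoint and a local student backend. The endpoint URL, API key, and student_backend object are placeholders, and student_backend is assumed to be whatever loaded LLM backend mlxsmith normally passes to SamplingClient.

# Hypothetical wiring; teacher URL, API key, and student_backend are assumptions.
from mlxsmith.sdk import SamplingClient
from mlxsmith.sdk.sampling_client import DistillationSampler

teacher = SamplingClient(api_endpoint="http://teacher:8000", api_key="secret")
student = SamplingClient(backend=student_backend)  # placeholder local backend

sampler = DistillationSampler(teacher_client=teacher, student_client=student)

prompts = ["Q: What is 2+2?", "Q: Name a prime number."]
# Student generates; teacher returns top-5 logprobs per generated token.
samples, teacher_topk = sampler.sample_with_teacher_logprobs(
    prompts, max_tokens=50, logprobs_k=5
)

for prompt, result, topk in zip(prompts, samples.results, teacher_topk):
    print(prompt, "->", result.text)
    # topk is a list of {token: logprob} dicts, one entry per generated token.

teacher.shutdown()
student.shutdown()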
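
For reference, _sample_api above posts to {api_endpoint}/internal/rollout and relies only on the request and response keys visible in the code; the server-side schema itself is not shown in this diff. A sketch of that exchange as plain Python data, with illustrative values:

# Request body built by SamplingClient._sample_api (keys copied from the code above).
rollout_request = {
    "prompt": "What is 2+2?",
    "max_tokens": 32,
    "temperature": 0.8,
    "top_p": 1.0,
    "top_k": None,
    "seed": None,
    "include_tokens": True,
    "include_logprobs": True,
    "include_top_k_logprobs": 5,            # None when logprobs_k == 0
    "include_prompt_logprobs": False,
    "include_prompt_top_k_logprobs": None,  # set to prompt_logprobs_k when > 0
    "include_text": True,
}

# Response keys the client reads; anything absent falls back to the defaults
# used when constructing SampleResult in _sample_api. Values are illustrative.
rollout_response_example = {
    "completion": "4",              # -> SampleResult.text
    "token_ids": [1, 2, 3],         # -> SampleResult.token_ids
    "prompt_len": 2,                # -> SampleResult.prompt_len
    "logprobs": [-0.1],             # -> SampleResult.logprobs
    "top_k_logprobs": None,         # -> SampleResult.top_k_logprobs
    "prompt_logprobs": None,        # -> SampleResult.prompt_logprobs
    "prompt_top_k_logprobs": None,  # -> SampleResult.prompt_top_k_logprobs
}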