mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/server.py ADDED
@@ -0,0 +1,376 @@
1
+ """MLXSmith FastAPI server with OpenAI-compatible chat completions.
2
+
3
+ This module provides:
4
+ - OpenAI-compatible chat completions with streaming support
5
+ - Internal rollout endpoint (tokens + logprobs)
6
+ - Adapter hot-reload
7
+ - RLM state and history endpoints
8
+ - Web UI and RLM monitor (when enabled)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import time
15
+ import uuid
16
+ from pathlib import Path
17
+ from typing import Any, Dict, List, Optional
18
+
19
+ from fastapi import FastAPI
20
+ from fastapi.responses import StreamingResponse, HTMLResponse
21
+ from pydantic import BaseModel
22
+
23
+ from .config import ProjectConfig
24
+ from .models import resolve_model_spec
25
+ from .llm.registry import get_llm_backend
26
+
27
+ # Import new API handlers
28
+ from .api.handlers import create_router, InternalAuthMiddleware
29
+ from .api.schemas import (
30
+ ChatMessage,
31
+ ChatRequest,
32
+ ChatResponse,
33
+ RolloutRequest,
34
+ RolloutResponse,
35
+ AdapterReloadRequest,
36
+ AdapterReloadResponse,
37
+ )
38
+
39
+ # Re-export schemas for backward compatibility
40
+ __all__ = [
41
+ "create_app",
42
+ "ChatMessage",
43
+ "ChatRequest",
44
+ "ChatResponse",
45
+ "RolloutRequest",
46
+ "RolloutResponse",
47
+ "AdapterReloadRequest",
48
+ "AdapterReloadResponse",
49
+ ]
50
+
51
+
52
+ def _ui_html() -> str:
53
+ return """<!doctype html>
54
+ <html lang="en">
55
+ <head>
56
+ <meta charset="utf-8" />
57
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
58
+ <title>mlxsmith serve</title>
59
+ <style>
60
+ :root {
61
+ --bg: #f7f2ea;
62
+ --ink: #231f20;
63
+ --accent: #0b6e4f;
64
+ --accent-2: #b44a1d;
65
+ --card: #fff8ee;
66
+ }
67
+ * { box-sizing: border-box; }
68
+ body {
69
+ margin: 0;
70
+ font-family: "Space Grotesk", "Avenir Next", "Segoe UI", sans-serif;
71
+ color: var(--ink);
72
+ background:
73
+ radial-gradient(circle at 10% 10%, rgba(11,110,79,0.12), transparent 35%),
74
+ radial-gradient(circle at 90% 20%, rgba(180,74,29,0.12), transparent 40%),
75
+ var(--bg);
76
+ }
77
+ header {
78
+ padding: 28px 32px;
79
+ border-bottom: 1px solid rgba(35,31,32,0.1);
80
+ display: flex;
81
+ align-items: center;
82
+ justify-content: space-between;
83
+ gap: 16px;
84
+ }
85
+ .title {
86
+ font-size: 20px;
87
+ letter-spacing: 0.02em;
88
+ text-transform: uppercase;
89
+ }
90
+ .wrap {
91
+ display: grid;
92
+ grid-template-columns: 1.2fr 0.8fr;
93
+ gap: 24px;
94
+ padding: 24px 32px 40px;
95
+ }
96
+ .card {
97
+ background: var(--card);
98
+ border: 1px solid rgba(35,31,32,0.08);
99
+ border-radius: 16px;
100
+ padding: 18px;
101
+ box-shadow: 0 6px 20px rgba(35,31,32,0.06);
102
+ }
103
+ h2 { margin: 0 0 12px; font-size: 18px; }
104
+ textarea {
105
+ width: 100%;
106
+ min-height: 120px;
107
+ border: 1px solid rgba(35,31,32,0.2);
108
+ border-radius: 12px;
109
+ padding: 12px;
110
+ font: inherit;
111
+ background: #fffdf9;
112
+ }
113
+ .row { display: flex; gap: 12px; margin-top: 10px; }
114
+ button {
115
+ border: 0;
116
+ border-radius: 10px;
117
+ padding: 10px 14px;
118
+ font: inherit;
119
+ cursor: pointer;
120
+ }
121
+ .btn-primary { background: var(--accent); color: #fff; }
122
+ .btn-secondary { background: var(--accent-2); color: #fff; }
123
+ .log {
124
+ background: #14110f;
125
+ color: #fdf7f0;
126
+ border-radius: 12px;
127
+ padding: 12px;
128
+ height: 300px;
129
+ overflow: auto;
130
+ font-family: "Iosevka", "Menlo", monospace;
131
+ font-size: 13px;
132
+ }
133
+ .muted { color: rgba(35,31,32,0.6); font-size: 12px; }
134
+ @media (max-width: 900px) {
135
+ .wrap { grid-template-columns: 1fr; }
136
+ }
137
+ </style>
138
+ </head>
139
+ <body>
140
+ <header>
141
+ <div class="title">mlxsmith serve</div>
142
+ <div class="muted">OpenAI-compatible API + RLM monitor</div>
143
+ </header>
144
+ <div class="wrap">
145
+ <section class="card">
146
+ <h2>Chat</h2>
147
+ <textarea id="prompt" placeholder="Ask the model anything..."></textarea>
148
+ <div class="row">
149
+ <button class="btn-primary" id="send">Send</button>
150
+ <button class="btn-secondary" id="stream">Stream</button>
151
+ <a class="btn-secondary" href="/rlm/monitor" style="text-decoration:none;">RLM Monitor</a>
152
+ </div>
153
+ <div class="log" id="output"></div>
154
+ </section>
155
+ <section class="card">
156
+ <h2>Status</h2>
157
+ <div id="status" class="muted">Loading…</div>
158
+ <div class="muted" style="margin-top:12px;">Tip: enable <code>serve.ui</code> in mlxsmith.yaml to keep this page on by default.</div>
159
+ </section>
160
+ </div>
161
+ <script>
162
+ const output = document.getElementById('output');
163
+ const statusEl = document.getElementById('status');
164
+ const promptEl = document.getElementById('prompt');
165
+ const baseBody = () => ({
166
+ messages: [{role: 'user', content: promptEl.value}],
167
+ max_tokens: 256
168
+ });
169
+
170
+ async function refreshStatus() {
171
+ try {
172
+ const res = await fetch('/internal/rlm/state');
173
+ if (!res.ok) return;
174
+ const data = await res.json();
175
+ statusEl.textContent = JSON.stringify(data, null, 2);
176
+ } catch (err) {
177
+ statusEl.textContent = 'Status unavailable';
178
+ }
179
+ }
180
+ refreshStatus();
181
+ setInterval(refreshStatus, 5000);
182
+
183
+ document.getElementById('send').onclick = async () => {
184
+ output.textContent = '';
185
+ const res = await fetch('/v1/chat/completions', {
186
+ method: 'POST',
187
+ headers: {'Content-Type': 'application/json'},
188
+ body: JSON.stringify(baseBody())
189
+ });
190
+ const data = await res.json();
191
+ output.textContent = data.choices?.[0]?.message?.content || '';
192
+ };
193
+
194
+ document.getElementById('stream').onclick = async () => {
195
+ output.textContent = '';
196
+ const body = baseBody();
197
+ body.stream = true;
198
+ const res = await fetch('/v1/chat/completions', {
199
+ method: 'POST',
200
+ headers: {'Content-Type': 'application/json'},
201
+ body: JSON.stringify(body)
202
+ });
203
+ const reader = res.body.getReader();
204
+ const decoder = new TextDecoder('utf-8');
205
+ let buf = '';
206
+ while (true) {
207
+ const { value, done } = await reader.read();
208
+ if (done) break;
209
+ buf += decoder.decode(value, {stream: true});
210
+ const parts = buf.split('\\n\\n');
211
+ buf = parts.pop() || '';
212
+ for (const part of parts) {
213
+ const line = part.trim();
214
+ if (!line.startsWith('data:')) continue;
215
+ const payload = line.replace('data:', '').trim();
216
+ if (payload === '[DONE]') return;
217
+ try {
218
+ const obj = JSON.parse(payload);
219
+ const delta = obj.choices?.[0]?.delta?.content || '';
220
+ output.textContent += delta;
221
+ } catch (e) {}
222
+ }
223
+ }
224
+ };
225
+ </script>
226
+ </body>
227
+ </html>
228
+ """
229
+
230
+
231
+ def _monitor_html() -> str:
232
+ return """<!doctype html>
233
+ <html lang="en">
234
+ <head>
235
+ <meta charset="utf-8" />
236
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
237
+ <title>RLM Monitor</title>
238
+ <style>
239
+ body {
240
+ margin: 0;
241
+ font-family: "Space Grotesk", "Avenir Next", "Segoe UI", sans-serif;
242
+ background: #f4f6f1;
243
+ color: #222;
244
+ }
245
+ header {
246
+ padding: 24px 32px;
247
+ background: #1f3b2c;
248
+ color: #f4f6f1;
249
+ display: flex;
250
+ justify-content: space-between;
251
+ align-items: center;
252
+ }
253
+ .wrap { padding: 24px 32px; display: grid; gap: 16px; }
254
+ .card {
255
+ background: #fff;
256
+ border-radius: 14px;
257
+ padding: 16px;
258
+ box-shadow: 0 8px 20px rgba(0,0,0,0.08);
259
+ }
260
+ canvas { width: 100%; height: 220px; }
261
+ pre { margin: 0; font-size: 12px; }
262
+ </style>
263
+ </head>
264
+ <body>
265
+ <header>
266
+ <div>RLM Monitor</div>
267
+ <a href="/" style="color:#f4f6f1;text-decoration:none;">Back to Serve</a>
268
+ </header>
269
+ <div class="wrap">
270
+ <section class="card">
271
+ <canvas id="chart" width="900" height="220"></canvas>
272
+ </section>
273
+ <section class="card">
274
+ <pre id="state">Loading…</pre>
275
+ </section>
276
+ </div>
277
+ <script>
278
+ const canvas = document.getElementById('chart');
279
+ const ctx = canvas.getContext('2d');
280
+ const stateEl = document.getElementById('state');
281
+
282
+ function draw(history) {
283
+ ctx.clearRect(0,0,canvas.width,canvas.height);
284
+ if (!history.length) return;
285
+ const scores = history.map(h => h.adapter_score || 0);
286
+ const max = Math.max(...scores, 1);
287
+ const min = Math.min(...scores, 0);
288
+ const pad = 20;
289
+ const w = canvas.width - pad * 2;
290
+ const h = canvas.height - pad * 2;
291
+ ctx.strokeStyle = '#1f3b2c';
292
+ ctx.lineWidth = 2;
293
+ ctx.beginPath();
294
+ scores.forEach((s, i) => {
295
+ const x = pad + (i / Math.max(1, scores.length - 1)) * w;
296
+ const y = pad + (1 - (s - min) / (max - min || 1)) * h;
297
+ if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y);
298
+ });
299
+ ctx.stroke();
300
+ }
301
+
302
+ async function refresh() {
303
+ try {
304
+ const h = await fetch('/internal/rlm/history').then(r => r.json());
305
+ draw(h);
306
+ const s = await fetch('/internal/rlm/state').then(r => r.json());
307
+ stateEl.textContent = JSON.stringify(s, null, 2);
308
+ } catch (e) {
309
+ stateEl.textContent = 'History unavailable';
310
+ }
311
+ }
312
+ refresh();
313
+ setInterval(refresh, 5000);
314
+ </script>
315
+ </body>
316
+ </html>
317
+ """
318
+
319
+
320
+ def create_app(model_spec: str, cfg: ProjectConfig) -> FastAPI:
321
+ """Create and configure the FastAPI application.
322
+
323
+ Args:
324
+ model_spec: Model specification (path or HF repo ID)
325
+ cfg: Project configuration
326
+
327
+ Returns:
328
+ Configured FastAPI application
329
+ """
330
+ app = FastAPI(
331
+ title="mlxsmith",
332
+ description="MLXSmith API server for local LLM inference and RLM training",
333
+ version="0.1.0",
334
+ )
335
+
336
+ # Load LLM backend
337
+ llm = get_llm_backend(cfg.model.backend)
338
+ base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_spec, cfg)
339
+ llm.load(
340
+ base_model,
341
+ max_seq_len=cfg.model.max_seq_len,
342
+ dtype=cfg.model.dtype,
343
+ trust_remote_code=cfg.model.trust_remote_code,
344
+ )
345
+ if adapter_path:
346
+ llm.apply_adapter(str(adapter_path))
347
+ current_adapter = str(adapter_path) if adapter_path else None
348
+
349
+ # Add authentication middleware for internal endpoints
350
+ app.add_middleware(
351
+ InternalAuthMiddleware,
352
+ api_token=None, # Set via MLXSMITH_API_TOKEN env var
353
+ internal_prefix="/internal",
354
+ public_paths=["/health", "/v1/chat/completions", "/docs", "/openapi.json"],
355
+ )
356
+
357
+ # Create and include the API router
358
+ router = create_router(
359
+ llm_backend=llm,
360
+ base_model=base_model,
361
+ current_adapter=current_adapter,
362
+ cfg=cfg,
363
+ )
364
+ app.include_router(router)
365
+
366
+ # Add UI routes if enabled
367
+ if cfg.serve.ui:
368
+ @app.get("/")
369
+ def ui_root():
370
+ return HTMLResponse(_ui_html())
371
+
372
+ @app.get("/rlm/monitor")
373
+ def ui_monitor():
374
+ return HTMLResponse(_monitor_html())
375
+
376
+ return app
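The endpoints listed in the module docstring can be exercised with any OpenAI-style HTTP client. The sketch below is illustrative only: it mirrors the request shape the bundled UI posts (messages plus max_tokens, with stream=true for the SSE variant, parsing "data:" chunks until "data: [DONE]"). The base URL is an assumption for a locally served app; per the middleware configuration above, /v1/chat/completions is a public path, while /internal/* routes may additionally require the token set via MLXSMITH_API_TOKEN.

import json
import requests

BASE = "http://localhost:8000"  # assumption: adjust to wherever create_app() is served

# Non-streaming chat completion, same payload shape the bundled UI posts.
resp = requests.post(
    f"{BASE}/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 256},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# Streaming variant: the server emits Server-Sent Events terminated by "data: [DONE]".
with requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 256,
        "stream": True,
    },
    stream=True,
    timeout=120,
) as stream:
    for line in stream.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
        print(delta, end="", flush=True)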
mlxsmith/train/distill.py ADDED
@@ -0,0 +1,279 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ from pathlib import Path
6
+ from typing import Iterable
7
+
8
+ from rich.console import Console
9
+
10
+ from ..accel import get_backend
11
+ from ..config import ProjectConfig
12
+ from ..llm.interface import compute_logprobs
13
+ from ..models import resolve_model_spec
14
+ from ..runs import RunPaths, new_run, snapshot_config
15
+ from ..llm.registry import get_llm_backend
16
+ from ..llm.backend import BackendNotAvailable
17
+ from ..sdk.losses import importance_sampling_loss
18
+ from ..train.lora import LoRAConfig
19
+ from ..util import ensure_dir, write_jsonl, now_ts, tree_add, tree_scale
20
+ from .sft import run_sft
21
+
22
+ console = Console()
23
+
24
+
25
+ def _iter_prompts(path: Path) -> Iterable[str]:
26
+ for line in path.read_text(encoding="utf-8").splitlines():
27
+ if not line.strip():
28
+ continue
29
+ row = json.loads(line)
30
+ prompt = row.get("prompt") or row.get("instruction") or row.get("input") or ""
31
+ if not prompt and "messages" in row:
32
+ msgs = row.get("messages") or []
33
+ if msgs:
34
+ prompt = "\n".join([m.get("content", "") for m in msgs])
35
+ if prompt:
36
+ yield str(prompt)
37
+
38
+
39
+ def run_distill(
40
+ project_root: Path,
41
+ cfg: ProjectConfig,
42
+ data_path: Path,
43
+ *,
44
+ teacher_model: str,
45
+ student_model: str,
46
+ accel: str,
47
+ mode: str = "offline",
48
+ max_new_tokens: int = 256,
49
+ temperature: float = 0.7,
50
+ ) -> RunPaths:
51
+ run = new_run(project_root, "distill")
52
+ snapshot_config(cfg.model_dump(), run.config_snapshot_path)
53
+
54
+ mode = mode.lower()
55
+ prompts = list(_iter_prompts(data_path))
56
+ if not prompts:
57
+ raise RuntimeError("No prompts found in distillation dataset")
58
+
59
+ backend = get_backend(accel)
60
+ backend.patch()
61
+ console.print(f"[bold]DISTILL[/bold] run: {run.run_dir.name} mode={mode} accel={backend.name}")
62
+
63
+ teacher = get_llm_backend(cfg.model.backend)
64
+ student = get_llm_backend(cfg.model.backend)
65
+
66
+ base_teacher, _, _meta_t = resolve_model_spec(project_root, teacher_model, cfg)
67
+ base_student, adapter_path, _meta_s = resolve_model_spec(project_root, student_model, cfg)
68
+
69
+ try:
70
+ teacher.load(
71
+ base_teacher,
72
+ max_seq_len=cfg.model.max_seq_len,
73
+ dtype=cfg.model.dtype,
74
+ trust_remote_code=cfg.model.trust_remote_code,
75
+ )
76
+ student.load(
77
+ base_student,
78
+ max_seq_len=cfg.model.max_seq_len,
79
+ dtype=cfg.model.dtype,
80
+ trust_remote_code=cfg.model.trust_remote_code,
81
+ )
82
+ except BackendNotAvailable as e:
83
+ console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
84
+ (run.adapter_dir / "ADAPTER.txt").write_text(
85
+ f"Backend unavailable in this environment.\nteacher={teacher_model}\nstudent={student_model}\n",
86
+ encoding="utf-8",
87
+ )
88
+ return run
89
+
90
+ distill_dir = ensure_dir(run.artifacts_dir / "distill_data")
91
+ train_path = distill_dir / "train.jsonl"
92
+
93
+ if mode == "opd":
94
+ rows = [{"prompt": prompt} for prompt in prompts]
95
+ write_jsonl(train_path, rows)
96
+ console.print(f"[green]Distill dataset (OPD)[/green] {train_path}")
97
+
98
+ if adapter_path:
99
+ student.apply_adapter(str(adapter_path))
100
+ else:
101
+ lora_cfg = LoRAConfig(
102
+ r=cfg.lora.r,
103
+ alpha=cfg.lora.alpha,
104
+ dropout=cfg.lora.dropout,
105
+ target_modules=list(cfg.lora.target_modules or []),
106
+ num_layers=cfg.lora.num_layers,
107
+ scale=cfg.lora.scale,
108
+ fine_tune_type=cfg.lora.fine_tune_type,
109
+ )
110
+ student.apply_lora_from_config(lora_cfg)
111
+
112
+ opt, _params = student.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
113
+
114
+ rng = random.Random(cfg.train.seed)
115
+ total = int(cfg.train.iters)
116
+ grad_accum = max(1, int(cfg.train.grad_accum))
117
+ max_len = int(cfg.model.max_seq_len)
118
+ accum_grads = None
119
+
120
+ def _to_float(val):
121
+ if hasattr(val, "item"):
122
+ try:
123
+ return float(val.item())
124
+ except Exception:
125
+ pass
126
+ return float(val)
127
+
128
+ for step in range(1, total + 1):
129
+ prompt = rng.choice(prompts)
130
+ gen = student.generate_with_logprobs(
131
+ prompt,
132
+ max_new_tokens=max_new_tokens,
133
+ temperature=temperature,
134
+ logprobs=0,
135
+ )
136
+ completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
137
+ behavior_logprobs = list(gen.logprobs) if gen.logprobs else []
138
+ if not behavior_logprobs:
139
+ continue
140
+
141
+ teacher_logprobs: list[float] = []
142
+ teacher_avg = None
143
+ teacher_res = compute_logprobs(
144
+ teacher,
145
+ prompt,
146
+ completion,
147
+ top_k=0,
148
+ max_seq_len=max_len,
149
+ )
150
+ if teacher_res.token_logprobs:
151
+ teacher_logprobs = [float(lp) for lp in teacher_res.token_logprobs]
152
+ else:
153
+ prompt_ids = teacher.encode(prompt)
154
+ ids = teacher.encode(prompt + completion)
155
+ if max_len and len(ids) > max_len:
156
+ overflow = len(ids) - max_len
157
+ ids = ids[overflow:]
158
+ prompt_len = max(0, len(prompt_ids) - overflow)
159
+ else:
160
+ prompt_len = len(prompt_ids)
161
+ teacher_len = max(0, len(ids) - prompt_len)
162
+ if teacher_len > 0:
163
+ teacher_total = teacher.sequence_logprob(ids, prompt_len=prompt_len)
164
+ teacher_avg = _to_float(teacher_total) / float(teacher_len)
165
+
166
+ ids = list(gen.token_ids)
167
+ prompt_len = gen.prompt_len
168
+ if max_len and len(ids) > max_len:
169
+ overflow = len(ids) - max_len
170
+ ids = ids[overflow:]
171
+ if overflow >= prompt_len:
172
+ removed_completion = overflow - prompt_len
173
+ prompt_len = 0
174
+ if removed_completion > 0:
175
+ behavior_logprobs = behavior_logprobs[removed_completion:]
176
+ if teacher_logprobs:
177
+ teacher_logprobs = teacher_logprobs[removed_completion:]
178
+ else:
179
+ prompt_len = max(0, prompt_len - overflow)
180
+
181
+ completion_len = max(0, len(ids) - prompt_len)
182
+ if completion_len == 0 or not behavior_logprobs:
183
+ continue
184
+
185
+ if teacher_logprobs:
186
+ n = min(len(teacher_logprobs), len(behavior_logprobs), completion_len)
187
+ if n <= 0:
188
+ continue
189
+ if completion_len != n:
190
+ ids = ids[: prompt_len + n]
191
+ completion_len = n
192
+ behavior_logprobs = behavior_logprobs[:n]
193
+ teacher_logprobs = teacher_logprobs[:n]
194
+ behavior_logprob = sum(behavior_logprobs)
195
+ advantage = sum(t - b for t, b in zip(teacher_logprobs, behavior_logprobs)) / float(n)
196
+ else:
197
+ if teacher_avg is None:
198
+ continue
199
+ if len(behavior_logprobs) > completion_len:
200
+ behavior_logprobs = behavior_logprobs[:completion_len]
201
+ behavior_logprob = sum(behavior_logprobs)
202
+ if not behavior_logprobs:
203
+ continue
204
+ behavior_avg = behavior_logprob / float(len(behavior_logprobs))
205
+ advantage = teacher_avg - behavior_avg
206
+
207
+ def loss_fn(_model):
208
+ return importance_sampling_loss(
209
+ student,
210
+ ids,
211
+ prompt_len=prompt_len,
212
+ advantage=advantage,
213
+ behavior_logprob=behavior_logprob,
214
+ )
215
+
216
+ lval, grads = student.value_and_grad(loss_fn)
217
+ if grads is not None:
218
+ accum_grads = tree_add(accum_grads, grads)
219
+
220
+ if step % grad_accum == 0:
221
+ if accum_grads is not None:
222
+ student.apply_grads(opt, tree_scale(accum_grads, 1.0 / grad_accum))
223
+ accum_grads = None
224
+
225
+ if step % cfg.train.log_every == 0 or step == 1 or step == total:
226
+ write_jsonl(
227
+ run.metrics_path,
228
+ [
229
+ {
230
+ "ts": now_ts(),
231
+ "step": step,
232
+ "kind": "distill_opd",
233
+ "loss": _to_float(lval),
234
+ "advantage": float(advantage),
235
+ "accel": backend.name,
236
+ }
237
+ ],
238
+ )
239
+
240
+ if step % cfg.train.save_every == 0 or step == total:
241
+ student.save_adapter(
242
+ str(run.adapter_dir),
243
+ metadata={
244
+ "base_model": base_student,
245
+ "source_adapter": str(adapter_path) if adapter_path else None,
246
+ "run": run.run_dir.name,
247
+ "kind": "distill_opd",
248
+ },
249
+ )
250
+ child = run
251
+ else:
252
+ rows = []
253
+ for prompt in prompts:
254
+ gen = teacher.generate(
255
+ prompt,
256
+ max_new_tokens=max_new_tokens,
257
+ temperature=temperature,
258
+ )
259
+ completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
260
+ rows.append({"prompt": prompt, "response": completion})
261
+ write_jsonl(train_path, rows)
262
+ console.print(f"[green]Distill dataset (offline)[/green] {train_path}")
263
+ child = run_sft(project_root, cfg, distill_dir, student_model, accel)
264
+
265
+ write_jsonl(
266
+ run.metrics_path,
267
+ [
268
+ {
269
+ "ts": now_ts(),
270
+ "kind": "distill",
271
+ "mode": mode,
272
+ "teacher": teacher_model,
273
+ "student": student_model,
274
+ "child_run": str(child.run_dir),
275
+ "samples": len(prompts),
276
+ }
277
+ ],
278
+ )
279
+ return run
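In the OPD branch above, each student sample is rescored by the teacher, the per-token log-probability gap becomes the advantage, and the result is passed to importance_sampling_loss from mlxsmith.sdk.losses together with the behavior log-probability. That loss implementation is not part of this diff; the snippet below is only an illustrative sketch, with hypothetical names and plain floats, of the usual importance-sampling surrogate those arguments suggest, not the package's actual code.

import math

def importance_sampling_surrogate(current_logprob: float,
                                  behavior_logprob: float,
                                  advantage: float,
                                  clip: float = 5.0) -> float:
    """Sketch of a surrogate loss for one sampled completion.

    The ratio p_current(completion) / p_behavior(completion) is formed in log
    space so long sequences do not underflow; the loss is -ratio * advantage,
    so minimizing it pushes the student toward completions the teacher scores
    higher (positive advantage) and away from those it scores lower.
    """
    log_ratio = current_logprob - behavior_logprob
    # Clamp the log-ratio so the ratio cannot explode once the student has
    # drifted far from the behavior policy that produced the sample.
    log_ratio = max(-clip, min(clip, log_ratio))
    return -math.exp(log_ratio) * advantage

# Example with made-up numbers: a completion the teacher prefers
# (advantage > 0) yields a negative loss, i.e. a gradient that raises
# the student's probability of that completion.
print(importance_sampling_surrogate(-42.0, -45.0, 0.8))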