mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/server.py
ADDED
@@ -0,0 +1,376 @@
"""MLXSmith FastAPI server with OpenAI-compatible chat completions.

This module provides:
- OpenAI-compatible chat completions with streaming support
- Internal rollout endpoint (tokens + logprobs)
- Adapter hot-reload
- RLM state and history endpoints
- Web UI and RLM monitor (when enabled)
"""

from __future__ import annotations

import json
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel

from .config import ProjectConfig
from .models import resolve_model_spec
from .llm.registry import get_llm_backend

# Import new API handlers
from .api.handlers import create_router, InternalAuthMiddleware
from .api.schemas import (
    ChatMessage,
    ChatRequest,
    ChatResponse,
    RolloutRequest,
    RolloutResponse,
    AdapterReloadRequest,
    AdapterReloadResponse,
)

# Re-export schemas for backward compatibility
__all__ = [
    "create_app",
    "ChatMessage",
    "ChatRequest",
    "ChatResponse",
    "RolloutRequest",
    "RolloutResponse",
    "AdapterReloadRequest",
    "AdapterReloadResponse",
]


def _ui_html() -> str:
    return """<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>mlxsmith serve</title>
<style>
:root {
--bg: #f7f2ea;
--ink: #231f20;
--accent: #0b6e4f;
--accent-2: #b44a1d;
--card: #fff8ee;
}
* { box-sizing: border-box; }
body {
margin: 0;
font-family: "Space Grotesk", "Avenir Next", "Segoe UI", sans-serif;
color: var(--ink);
background:
radial-gradient(circle at 10% 10%, rgba(11,110,79,0.12), transparent 35%),
radial-gradient(circle at 90% 20%, rgba(180,74,29,0.12), transparent 40%),
var(--bg);
}
header {
padding: 28px 32px;
border-bottom: 1px solid rgba(35,31,32,0.1);
display: flex;
align-items: center;
justify-content: space-between;
gap: 16px;
}
.title {
font-size: 20px;
letter-spacing: 0.02em;
text-transform: uppercase;
}
.wrap {
display: grid;
grid-template-columns: 1.2fr 0.8fr;
gap: 24px;
padding: 24px 32px 40px;
}
.card {
background: var(--card);
border: 1px solid rgba(35,31,32,0.08);
border-radius: 16px;
padding: 18px;
box-shadow: 0 6px 20px rgba(35,31,32,0.06);
}
h2 { margin: 0 0 12px; font-size: 18px; }
textarea {
width: 100%;
min-height: 120px;
border: 1px solid rgba(35,31,32,0.2);
border-radius: 12px;
padding: 12px;
font: inherit;
background: #fffdf9;
}
.row { display: flex; gap: 12px; margin-top: 10px; }
button {
border: 0;
border-radius: 10px;
padding: 10px 14px;
font: inherit;
cursor: pointer;
}
.btn-primary { background: var(--accent); color: #fff; }
.btn-secondary { background: var(--accent-2); color: #fff; }
.log {
background: #14110f;
color: #fdf7f0;
border-radius: 12px;
padding: 12px;
height: 300px;
overflow: auto;
font-family: "Iosevka", "Menlo", monospace;
font-size: 13px;
}
.muted { color: rgba(35,31,32,0.6); font-size: 12px; }
@media (max-width: 900px) {
.wrap { grid-template-columns: 1fr; }
}
</style>
</head>
<body>
<header>
<div class="title">mlxsmith serve</div>
<div class="muted">OpenAI-compatible API + RLM monitor</div>
</header>
<div class="wrap">
<section class="card">
<h2>Chat</h2>
<textarea id="prompt" placeholder="Ask the model anything..."></textarea>
<div class="row">
<button class="btn-primary" id="send">Send</button>
<button class="btn-secondary" id="stream">Stream</button>
<a class="btn-secondary" href="/rlm/monitor" style="text-decoration:none;">RLM Monitor</a>
</div>
<div class="log" id="output"></div>
</section>
<section class="card">
<h2>Status</h2>
<div id="status" class="muted">Loading…</div>
<div class="muted" style="margin-top:12px;">Tip: enable <code>serve.ui</code> in mlxsmith.yaml to keep this page on by default.</div>
</section>
</div>
<script>
const output = document.getElementById('output');
const statusEl = document.getElementById('status');
const promptEl = document.getElementById('prompt');
const baseBody = () => ({
messages: [{role: 'user', content: promptEl.value}],
max_tokens: 256
});

async function refreshStatus() {
try {
const res = await fetch('/internal/rlm/state');
if (!res.ok) return;
const data = await res.json();
statusEl.textContent = JSON.stringify(data, null, 2);
} catch (err) {
statusEl.textContent = 'Status unavailable';
}
}
refreshStatus();
setInterval(refreshStatus, 5000);

document.getElementById('send').onclick = async () => {
output.textContent = '';
const res = await fetch('/v1/chat/completions', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify(baseBody())
});
const data = await res.json();
output.textContent = data.choices?.[0]?.message?.content || '';
};

document.getElementById('stream').onclick = async () => {
output.textContent = '';
const body = baseBody();
body.stream = true;
const res = await fetch('/v1/chat/completions', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify(body)
});
const reader = res.body.getReader();
const decoder = new TextDecoder('utf-8');
let buf = '';
while (true) {
const { value, done } = await reader.read();
if (done) break;
buf += decoder.decode(value, {stream: true});
const parts = buf.split('\\n\\n');
buf = parts.pop() || '';
for (const part of parts) {
const line = part.trim();
if (!line.startsWith('data:')) continue;
const payload = line.replace('data:', '').trim();
if (payload === '[DONE]') return;
try {
const obj = JSON.parse(payload);
const delta = obj.choices?.[0]?.delta?.content || '';
output.textContent += delta;
} catch (e) {}
}
}
};
</script>
</body>
</html>
"""


def _monitor_html() -> str:
    return """<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>RLM Monitor</title>
<style>
body {
margin: 0;
font-family: "Space Grotesk", "Avenir Next", "Segoe UI", sans-serif;
background: #f4f6f1;
color: #222;
}
header {
padding: 24px 32px;
background: #1f3b2c;
color: #f4f6f1;
display: flex;
justify-content: space-between;
align-items: center;
}
.wrap { padding: 24px 32px; display: grid; gap: 16px; }
.card {
background: #fff;
border-radius: 14px;
padding: 16px;
box-shadow: 0 8px 20px rgba(0,0,0,0.08);
}
canvas { width: 100%; height: 220px; }
pre { margin: 0; font-size: 12px; }
</style>
</head>
<body>
<header>
<div>RLM Monitor</div>
<a href="/" style="color:#f4f6f1;text-decoration:none;">Back to Serve</a>
</header>
<div class="wrap">
<section class="card">
<canvas id="chart" width="900" height="220"></canvas>
</section>
<section class="card">
<pre id="state">Loading…</pre>
</section>
</div>
<script>
const canvas = document.getElementById('chart');
const ctx = canvas.getContext('2d');
const stateEl = document.getElementById('state');

function draw(history) {
ctx.clearRect(0,0,canvas.width,canvas.height);
if (!history.length) return;
const scores = history.map(h => h.adapter_score || 0);
const max = Math.max(...scores, 1);
const min = Math.min(...scores, 0);
const pad = 20;
const w = canvas.width - pad * 2;
const h = canvas.height - pad * 2;
ctx.strokeStyle = '#1f3b2c';
ctx.lineWidth = 2;
ctx.beginPath();
scores.forEach((s, i) => {
const x = pad + (i / Math.max(1, scores.length - 1)) * w;
const y = pad + (1 - (s - min) / (max - min || 1)) * h;
if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y);
});
ctx.stroke();
}

async function refresh() {
try {
const h = await fetch('/internal/rlm/history').then(r => r.json());
draw(h);
const s = await fetch('/internal/rlm/state').then(r => r.json());
stateEl.textContent = JSON.stringify(s, null, 2);
} catch (e) {
stateEl.textContent = 'History unavailable';
}
}
refresh();
setInterval(refresh, 5000);
</script>
</body>
</html>
"""


def create_app(model_spec: str, cfg: ProjectConfig) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        model_spec: Model specification (path or HF repo ID)
        cfg: Project configuration

    Returns:
        Configured FastAPI application
    """
    app = FastAPI(
        title="mlxsmith",
        description="MLXSmith API server for local LLM inference and RLM training",
        version="0.1.0",
    )

    # Load LLM backend
    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_spec, cfg)
    llm.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    if adapter_path:
        llm.apply_adapter(str(adapter_path))
    current_adapter = str(adapter_path) if adapter_path else None

    # Add authentication middleware for internal endpoints
    app.add_middleware(
        InternalAuthMiddleware,
        api_token=None,  # Set via MLXSMITH_API_TOKEN env var
        internal_prefix="/internal",
        public_paths=["/health", "/v1/chat/completions", "/docs", "/openapi.json"],
    )

    # Create and include the API router
    router = create_router(
        llm_backend=llm,
        base_model=base_model,
        current_adapter=current_adapter,
        cfg=cfg,
    )
    app.include_router(router)

    # Add UI routes if enabled
    if cfg.serve.ui:
        @app.get("/")
        def ui_root():
            return HTMLResponse(_ui_html())

        @app.get("/rlm/monitor")
        def ui_monitor():
            return HTMLResponse(_monitor_html())

    return app
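As a quick way to exercise the server above once it is running, the sketch below mirrors the request that the bundled UI sends to /v1/chat/completions. The base URL and port are placeholders (the bind address is configured outside server.py), and the use of the requests library is an assumption made for brevity.

import requests  # third-party HTTP client, used here only for illustration

# Placeholder base URL; the actual host/port come from the serving configuration,
# not from server.py itself.
BASE = "http://127.0.0.1:8000"

payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 256,
}

resp = requests.post(f"{BASE}/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
# Same path the UI reads: choices[0].message.content
print(data["choices"][0]["message"]["content"])

Adding "stream": true to the same payload switches the endpoint to server-sent events; the UI's Stream button consumes these as "data: ..." chunks terminated by a "[DONE]" marker.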
mlxsmith/train/__init__.py
File without changes
mlxsmith/train/distill.py
ADDED
@@ -0,0 +1,279 @@
from __future__ import annotations

import json
import random
from pathlib import Path
from typing import Iterable

from rich.console import Console

from ..accel import get_backend
from ..config import ProjectConfig
from ..llm.interface import compute_logprobs
from ..models import resolve_model_spec
from ..runs import RunPaths, new_run, snapshot_config
from ..llm.registry import get_llm_backend
from ..llm.backend import BackendNotAvailable
from ..sdk.losses import importance_sampling_loss
from ..train.lora import LoRAConfig
from ..util import ensure_dir, write_jsonl, now_ts, tree_add, tree_scale
from .sft import run_sft

console = Console()


def _iter_prompts(path: Path) -> Iterable[str]:
    for line in path.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        row = json.loads(line)
        prompt = row.get("prompt") or row.get("instruction") or row.get("input") or ""
        if not prompt and "messages" in row:
            msgs = row.get("messages") or []
            if msgs:
                prompt = "\n".join([m.get("content", "") for m in msgs])
        if prompt:
            yield str(prompt)


def run_distill(
    project_root: Path,
    cfg: ProjectConfig,
    data_path: Path,
    *,
    teacher_model: str,
    student_model: str,
    accel: str,
    mode: str = "offline",
    max_new_tokens: int = 256,
    temperature: float = 0.7,
) -> RunPaths:
    run = new_run(project_root, "distill")
    snapshot_config(cfg.model_dump(), run.config_snapshot_path)

    mode = mode.lower()
    prompts = list(_iter_prompts(data_path))
    if not prompts:
        raise RuntimeError("No prompts found in distillation dataset")

    backend = get_backend(accel)
    backend.patch()
    console.print(f"[bold]DISTILL[/bold] run: {run.run_dir.name} mode={mode} accel={backend.name}")

    teacher = get_llm_backend(cfg.model.backend)
    student = get_llm_backend(cfg.model.backend)

    base_teacher, _, _meta_t = resolve_model_spec(project_root, teacher_model, cfg)
    base_student, adapter_path, _meta_s = resolve_model_spec(project_root, student_model, cfg)

    try:
        teacher.load(
            base_teacher,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        student.load(
            base_student,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
    except BackendNotAvailable as e:
        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
        (run.adapter_dir / "ADAPTER.txt").write_text(
            f"Backend unavailable in this environment.\nteacher={teacher_model}\nstudent={student_model}\n",
            encoding="utf-8",
        )
        return run

    distill_dir = ensure_dir(run.artifacts_dir / "distill_data")
    train_path = distill_dir / "train.jsonl"

    if mode == "opd":
        rows = [{"prompt": prompt} for prompt in prompts]
        write_jsonl(train_path, rows)
        console.print(f"[green]Distill dataset (OPD)[/green] {train_path}")

        if adapter_path:
            student.apply_adapter(str(adapter_path))
        else:
            lora_cfg = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
            student.apply_lora_from_config(lora_cfg)

        opt, _params = student.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

        rng = random.Random(cfg.train.seed)
        total = int(cfg.train.iters)
        grad_accum = max(1, int(cfg.train.grad_accum))
        max_len = int(cfg.model.max_seq_len)
        accum_grads = None

        def _to_float(val):
            if hasattr(val, "item"):
                try:
                    return float(val.item())
                except Exception:
                    pass
            return float(val)

        for step in range(1, total + 1):
            prompt = rng.choice(prompts)
            gen = student.generate_with_logprobs(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                logprobs=0,
            )
            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
            behavior_logprobs = list(gen.logprobs) if gen.logprobs else []
            if not behavior_logprobs:
                continue

            teacher_logprobs: list[float] = []
            teacher_avg = None
            teacher_res = compute_logprobs(
                teacher,
                prompt,
                completion,
                top_k=0,
                max_seq_len=max_len,
            )
            if teacher_res.token_logprobs:
                teacher_logprobs = [float(lp) for lp in teacher_res.token_logprobs]
            else:
                prompt_ids = teacher.encode(prompt)
                ids = teacher.encode(prompt + completion)
                if max_len and len(ids) > max_len:
                    overflow = len(ids) - max_len
                    ids = ids[overflow:]
                    prompt_len = max(0, len(prompt_ids) - overflow)
                else:
                    prompt_len = len(prompt_ids)
                teacher_len = max(0, len(ids) - prompt_len)
                if teacher_len > 0:
                    teacher_total = teacher.sequence_logprob(ids, prompt_len=prompt_len)
                    teacher_avg = _to_float(teacher_total) / float(teacher_len)

            ids = list(gen.token_ids)
            prompt_len = gen.prompt_len
            if max_len and len(ids) > max_len:
                overflow = len(ids) - max_len
                ids = ids[overflow:]
                if overflow >= prompt_len:
                    removed_completion = overflow - prompt_len
                    prompt_len = 0
                    if removed_completion > 0:
                        behavior_logprobs = behavior_logprobs[removed_completion:]
                        if teacher_logprobs:
                            teacher_logprobs = teacher_logprobs[removed_completion:]
                else:
                    prompt_len = max(0, prompt_len - overflow)

            completion_len = max(0, len(ids) - prompt_len)
            if completion_len == 0 or not behavior_logprobs:
                continue

            if teacher_logprobs:
                n = min(len(teacher_logprobs), len(behavior_logprobs), completion_len)
                if n <= 0:
                    continue
                if completion_len != n:
                    ids = ids[: prompt_len + n]
                    completion_len = n
                behavior_logprobs = behavior_logprobs[:n]
                teacher_logprobs = teacher_logprobs[:n]
                behavior_logprob = sum(behavior_logprobs)
                advantage = sum(t - b for t, b in zip(teacher_logprobs, behavior_logprobs)) / float(n)
            else:
                if teacher_avg is None:
                    continue
                if len(behavior_logprobs) > completion_len:
                    behavior_logprobs = behavior_logprobs[:completion_len]
                behavior_logprob = sum(behavior_logprobs)
                if not behavior_logprobs:
                    continue
                behavior_avg = behavior_logprob / float(len(behavior_logprobs))
                advantage = teacher_avg - behavior_avg

            def loss_fn(_model):
                return importance_sampling_loss(
                    student,
                    ids,
                    prompt_len=prompt_len,
                    advantage=advantage,
                    behavior_logprob=behavior_logprob,
                )

            lval, grads = student.value_and_grad(loss_fn)
            if grads is not None:
                accum_grads = tree_add(accum_grads, grads)

            if step % grad_accum == 0:
                if accum_grads is not None:
                    student.apply_grads(opt, tree_scale(accum_grads, 1.0 / grad_accum))
                    accum_grads = None

            if step % cfg.train.log_every == 0 or step == 1 or step == total:
                write_jsonl(
                    run.metrics_path,
                    [
                        {
                            "ts": now_ts(),
                            "step": step,
                            "kind": "distill_opd",
                            "loss": _to_float(lval),
                            "advantage": float(advantage),
                            "accel": backend.name,
                        }
                    ],
                )

            if step % cfg.train.save_every == 0 or step == total:
                student.save_adapter(
                    str(run.adapter_dir),
                    metadata={
                        "base_model": base_student,
                        "source_adapter": str(adapter_path) if adapter_path else None,
                        "run": run.run_dir.name,
                        "kind": "distill_opd",
                    },
                )
        child = run
    else:
        rows = []
        for prompt in prompts:
            gen = teacher.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )
            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
            rows.append({"prompt": prompt, "response": completion})
        write_jsonl(train_path, rows)
        console.print(f"[green]Distill dataset (offline)[/green] {train_path}")
        child = run_sft(project_root, cfg, distill_dir, student_model, accel)

    write_jsonl(
        run.metrics_path,
        [
            {
                "ts": now_ts(),
                "kind": "distill",
                "mode": mode,
                "teacher": teacher_model,
                "student": student_model,
                "child_run": str(child.run_dir),
                "samples": len(prompts),
            }
        ],
    )
    return run
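For reference, _iter_prompts above accepts JSONL rows keyed by "prompt", "instruction", or "input", or an OpenAI-style "messages" list whose "content" fields are joined with newlines. The sketch below prepares such a file and shows a hedged run_distill call; the file path, model identifiers, accel name, and the already-constructed ProjectConfig are placeholders, since those pieces are defined elsewhere in the package.

import json
from pathlib import Path

# Hypothetical dataset path and rows, for illustration only.
data_path = Path("data/distill_prompts.jsonl")
rows = [
    {"prompt": "Explain LoRA adapters in two sentences."},
    {"messages": [{"role": "user", "content": "Summarize importance sampling."}]},
]
data_path.parent.mkdir(parents=True, exist_ok=True)
with data_path.open("w", encoding="utf-8") as fh:
    for row in rows:
        fh.write(json.dumps(row) + "\n")

# `cfg` is assumed to be a ProjectConfig built elsewhere (e.g. from the project's
# mlxsmith.yaml); "offline" mode generates a teacher dataset and hands it to
# run_sft, while "opd" trains the student on-policy against teacher logprobs.
# from mlxsmith.train.distill import run_distill
# run = run_distill(
#     Path.cwd(), cfg, data_path,
#     teacher_model="path-or-hf-repo-of-teacher",  # placeholder
#     student_model="path-or-hf-repo-of-student",  # placeholder
#     accel="none",  # assumed backend name, see mlxsmith/accel/none.py
#     mode="offline",
# )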