mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/api/handlers.py
ADDED
|
@@ -0,0 +1,1217 @@
|
|
|
1
|
+
"""FastAPI handlers for MLXSmith API.
|
|
2
|
+
|
|
3
|
+
Implements endpoints for:
|
|
4
|
+
- OpenAI-compatible chat completions with streaming
|
|
5
|
+
- Internal rollout (tokens + logprobs)
|
|
6
|
+
- Training operations (forward/backward, optim_step, save/load state)
|
|
7
|
+
- Adapter hot-reload
|
|
8
|
+
- RLM state and history
|
|
9
|
+
- Model management
|
|
10
|
+
- HF token storage
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import secrets
|
|
18
|
+
import time
|
|
19
|
+
import uuid
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional
|
|
22
|
+
|
|
23
|
+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Security, status
|
|
24
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
25
|
+
from fastapi.responses import StreamingResponse
|
|
26
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
27
|
+
from pydantic import BaseModel
|
|
28
|
+
|
|
29
|
+
from .schemas import (
|
|
30
|
+
AdapterReloadRequest,
|
|
31
|
+
AdapterReloadResponse,
|
|
32
|
+
ChatCompletionChunk,
|
|
33
|
+
ChatMessage,
|
|
34
|
+
ChatRequest,
|
|
35
|
+
ChatResponse,
|
|
36
|
+
Choice,
|
|
37
|
+
ChoiceLogprobs,
|
|
38
|
+
DeltaMessage,
|
|
39
|
+
ErrorResponse,
|
|
40
|
+
ForwardBackwardRequest,
|
|
41
|
+
ForwardBackwardResponse,
|
|
42
|
+
GetWeightsResponse,
|
|
43
|
+
HealthResponse,
|
|
44
|
+
HFTokenRequest,
|
|
45
|
+
HFTokenResponse,
|
|
46
|
+
LoadStateRequest,
|
|
47
|
+
LoadStateResponse,
|
|
48
|
+
LogprobsContent,
|
|
49
|
+
ModelInfo,
|
|
50
|
+
ModelsListResponse,
|
|
51
|
+
ModelPullRequest,
|
|
52
|
+
ModelPullResponse,
|
|
53
|
+
ModelPullStatus,
|
|
54
|
+
OptimStepRequest,
|
|
55
|
+
OptimStepResponse,
|
|
56
|
+
RolloutRequest,
|
|
57
|
+
RolloutResponse,
|
|
58
|
+
RLMHistoryEntry,
|
|
59
|
+
RLMState,
|
|
60
|
+
SaveStateRequest,
|
|
61
|
+
SaveStateResponse,
|
|
62
|
+
SetWeightsRequest,
|
|
63
|
+
SetWeightsResponse,
|
|
64
|
+
StreamChoice,
|
|
65
|
+
UsageInfo,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# =============================================================================
|
|
69
|
+
# Authentication Middleware
|
|
70
|
+
# =============================================================================
|
|
71
|
+
|
|
72
|
+
class InternalAuthMiddleware(BaseHTTPMiddleware):
    """Middleware for authenticating internal endpoints.

    Requests whose path is listed in ``public_paths``, or whose path does not
    start with ``internal_prefix``, are passed through untouched. Requests to
    internal paths must carry ``Authorization: Bearer <token>`` matching the
    configured API token.

    When no token is configured (neither ``api_token`` nor the
    ``MLXSMITH_API_TOKEN`` environment variable), every request is allowed.
    This is the development mode.

    Args:
        app: The application being wrapped.
        api_token: Expected bearer token; falls back to ``MLXSMITH_API_TOKEN``.
        internal_prefix: Path prefix whose requests require authentication.
        public_paths: Exact paths that never require authentication. Defaults
            to ``["/health", "/v1/chat/completions"]``.
    """

    def __init__(
        self,
        app: FastAPI,
        api_token: Optional[str] = None,
        internal_prefix: str = "/internal",
        public_paths: Optional[List[str]] = None,
    ):
        super().__init__(app)
        self.api_token = api_token or os.environ.get("MLXSMITH_API_TOKEN")
        self.internal_prefix = internal_prefix
        self.public_paths = set(public_paths or ["/health", "/v1/chat/completions"])
        # Retained for backward compatibility; dispatch() parses the header itself.
        self.security = HTTPBearer(auto_error=False)

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Authenticate internal requests; forward everything else unchanged."""
        path = request.url.path

        # Public paths and anything outside the internal prefix skip auth.
        if path in self.public_paths or not path.startswith(self.internal_prefix):
            return await call_next(request)

        # No token configured: development mode, allow all.
        if not self.api_token:
            return await call_next(request)

        # The auth scheme is case-insensitive (RFC 7235 §2.1), and one or more
        # spaces may separate it from the token (RFC 6750 §2.1).
        scheme, _, raw_token = request.headers.get("authorization", "").partition(" ")
        token = raw_token.strip()
        if scheme.lower() != "bearer" or not token:
            return self._unauthorized("Missing or invalid authorization header")

        if not self._token_matches(token, self.api_token):
            return self._unauthorized("Invalid API token")

        return await call_next(request)

    @staticmethod
    def _token_matches(presented: str, expected: str) -> bool:
        """Compare two tokens in constant time, tolerating non-ASCII input.

        ``secrets.compare_digest`` raises ``TypeError`` when given ``str``
        arguments that contain non-ASCII characters. A client could trigger
        that with a crafted header and turn a 401 into a 500. Comparing the
        UTF-8 encodings avoids the exception while keeping the comparison
        constant-time.
        """
        return secrets.compare_digest(presented.encode("utf-8"), expected.encode("utf-8"))

    def _unauthorized(self, detail: str) -> Any:
        """Build a 401 JSON response with a ``WWW-Authenticate: Bearer`` challenge."""
        from fastapi.responses import JSONResponse
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"error": "Unauthorized", "detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _no_expected_token() -> None:
    """FastAPI dependency that always supplies ``None`` for ``expected_token``.

    If ``expected_token`` is declared as a plain ``Optional[str] = None``
    parameter, FastAPI exposes it as a query parameter. A client could then
    supply the very token it is checked against (``?expected_token=x`` plus
    ``Authorization: Bearer x``). Resolving it through this dependency keeps
    it out of the request surface, while direct Python callers can still
    pass a value explicitly.
    """
    return None


def verify_internal_token(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(HTTPBearer(auto_error=False)),
    expected_token: Optional[str] = Depends(_no_expected_token),
) -> bool:
    """Dependency for verifying internal endpoint tokens.

    The expected token is ``expected_token`` when it is passed explicitly from
    Python; otherwise it is the ``MLXSMITH_API_TOKEN`` environment variable.
    If neither is set, the check passes (development mode).

    Usage:
        @router.get("/internal/protected", dependencies=[Depends(verify_internal_token)])

    Args:
        credentials: Bearer credentials parsed by ``HTTPBearer``. The parser
            uses ``auto_error=False``, so a missing header reaches this
            function (and development mode) instead of being rejected up front.
        expected_token: Optional explicit token for direct Python callers.
            It is never read from the HTTP request.

    Returns:
        ``True`` when the request is authorized.

    Raises:
        HTTPException: 401 if a token is configured and the presented
            credentials are missing or do not match.
    """
    # When called directly without these arguments, the defaults are FastAPI
    # marker objects rather than real values; treat them as "not provided".
    configured = expected_token if isinstance(expected_token, str) else None
    token = configured or os.environ.get("MLXSMITH_API_TOKEN")
    if not token:
        return True  # Development mode - no token configured

    presented = getattr(credentials, "credentials", None)
    if not presented:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authorization token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Compare bytes: compare_digest raises TypeError on non-ASCII str input,
    # which would otherwise surface as a 500 instead of a 401.
    if not secrets.compare_digest(presented.encode("utf-8"), token.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return True
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# =============================================================================
|
|
158
|
+
# Helper Functions
|
|
159
|
+
# =============================================================================
|
|
160
|
+
|
|
161
|
+
def _messages_to_prompt(
|
|
162
|
+
messages: List[Any],
|
|
163
|
+
tokenizer: Any,
|
|
164
|
+
*,
|
|
165
|
+
use_chat_template: bool = True
|
|
166
|
+
) -> str:
|
|
167
|
+
"""Convert chat messages to prompt string."""
|
|
168
|
+
if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
|
|
169
|
+
msgs = [{"role": m.role, "content": m.content} for m in messages]
|
|
170
|
+
return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
|
|
171
|
+
# Fallback
|
|
172
|
+
return "\n".join([f"{m.role}: {m.content}" for m in messages]) + "\nassistant:"
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _truncate_stop(text: str, stop: Optional[List[str]]) -> str:
|
|
176
|
+
"""Truncate text at first stop sequence."""
|
|
177
|
+
if not stop:
|
|
178
|
+
return text
|
|
179
|
+
idx = None
|
|
180
|
+
for s in stop:
|
|
181
|
+
if not s:
|
|
182
|
+
continue
|
|
183
|
+
pos = text.find(s)
|
|
184
|
+
if pos != -1:
|
|
185
|
+
idx = pos if idx is None else min(idx, pos)
|
|
186
|
+
return text if idx is None else text[:idx]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _get_cache_dir() -> Path:
|
|
190
|
+
"""Get the cache directory for models."""
|
|
191
|
+
cache_dir = os.environ.get("MLXSMITH_CACHE_DIR")
|
|
192
|
+
if cache_dir:
|
|
193
|
+
return Path(cache_dir)
|
|
194
|
+
return Path.home() / ".cache" / "mlxsmith"
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _build_logprobs_content(
    token_ids: List[int],
    logprobs: List[float],
    top_k_logprobs: Optional[List[Dict[str, float]]],
    tokenizer: Any,
) -> List[LogprobsContent]:
    """Assemble one ``LogprobsContent`` entry per generated token.

    Tokens and logprobs are paired positionally, stopping at the shorter of
    the two lists. Each token is decoded on its own via ``tokenizer.decode``.
    When there is no tokenizer, or decoding raises, the placeholder
    ``<token_{id}>`` is used instead. If ``top_k_logprobs`` has an entry at
    the same position, that dict is attached wrapped in a one-element list;
    otherwise ``top_logprobs`` is ``None``.
    """

    def _render(tid: int) -> str:
        # Best-effort decode: any failure falls back to the placeholder.
        try:
            if tokenizer:
                return tokenizer.decode([tid])
        except Exception:
            pass
        return f"<token_{tid}>"

    available_top = len(top_k_logprobs) if top_k_logprobs else 0

    entries = []
    for position, (tid, lp) in enumerate(zip(token_ids, logprobs)):
        text = _render(tid)
        alternatives = [top_k_logprobs[position]] if position < available_top else None
        entries.append(LogprobsContent(token=text, logprob=lp, top_logprobs=alternatives))
    return entries
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# =============================================================================
|
|
224
|
+
# Route Handlers
|
|
225
|
+
# =============================================================================
|
|
226
|
+
|
|
227
|
+
def create_router(
|
|
228
|
+
llm_backend: Any,
|
|
229
|
+
base_model: str,
|
|
230
|
+
current_adapter: Optional[str],
|
|
231
|
+
cfg: Any,
|
|
232
|
+
) -> APIRouter:
|
|
233
|
+
"""Create API router with all endpoints.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
llm_backend: The LLM backend instance
|
|
237
|
+
base_model: The base model identifier
|
|
238
|
+
current_adapter: Currently loaded adapter path (if any)
|
|
239
|
+
cfg: Project configuration
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
Configured APIRouter instance
|
|
243
|
+
"""
|
|
244
|
+
router = APIRouter()
|
|
245
|
+
|
|
246
|
+
# Track adapter state (mutable reference)
|
|
247
|
+
adapter_state = {"path": current_adapter}
|
|
248
|
+
|
|
249
|
+
# Track training state
|
|
250
|
+
training_state = {
|
|
251
|
+
"step": 0,
|
|
252
|
+
"optimizer": None,
|
|
253
|
+
"learning_rate": 1e-4,
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
# ==========================================================================
|
|
257
|
+
# Health Check
|
|
258
|
+
# ==========================================================================
|
|
259
|
+
|
|
260
|
+
@router.get("/health", response_model=HealthResponse, tags=["Health"])
|
|
261
|
+
async def health() -> HealthResponse:
|
|
262
|
+
"""Health check endpoint."""
|
|
263
|
+
return HealthResponse(
|
|
264
|
+
ok=True,
|
|
265
|
+
version="0.1.0",
|
|
266
|
+
model=base_model,
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
# ==========================================================================
|
|
270
|
+
# Chat Completions (OpenAI-compatible)
|
|
271
|
+
# ==========================================================================
|
|
272
|
+
|
|
273
|
+
@router.post(
|
|
274
|
+
"/v1/chat/completions",
|
|
275
|
+
response_model=ChatResponse,
|
|
276
|
+
responses={
|
|
277
|
+
200: {"description": "Successful completion", "model": ChatResponse},
|
|
278
|
+
400: {"description": "Bad request", "model": ErrorResponse},
|
|
279
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
280
|
+
},
|
|
281
|
+
tags=["Chat"],
|
|
282
|
+
)
|
|
283
|
+
async def chat_completions(request: ChatRequest) -> ChatResponse | StreamingResponse:
|
|
284
|
+
"""OpenAI-compatible chat completions endpoint.
|
|
285
|
+
|
|
286
|
+
Supports both streaming (SSE) and non-streaming responses.
|
|
287
|
+
Supports logprobs parameter for returning token logprobs.
|
|
288
|
+
"""
|
|
289
|
+
prompt = _messages_to_prompt(
|
|
290
|
+
request.messages,
|
|
291
|
+
llm_backend.tokenizer,
|
|
292
|
+
use_chat_template=getattr(cfg.model, "use_chat_template", True)
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
# Determine if we need logprobs
|
|
296
|
+
logprobs_k = request.top_logprobs or (5 if request.logprobs else 0)
|
|
297
|
+
|
|
298
|
+
# Handle streaming response
|
|
299
|
+
if request.stream:
|
|
300
|
+
async def event_stream() -> AsyncGenerator[str, None]:
|
|
301
|
+
try:
|
|
302
|
+
# Try to use mlx_lm streaming if available
|
|
303
|
+
try:
|
|
304
|
+
import mlx_lm
|
|
305
|
+
has_mlx_lm = True
|
|
306
|
+
except ImportError:
|
|
307
|
+
has_mlx_lm = False
|
|
308
|
+
|
|
309
|
+
if has_mlx_lm:
|
|
310
|
+
acc = ""
|
|
311
|
+
emitted = ""
|
|
312
|
+
for out in mlx_lm.stream_generate(
|
|
313
|
+
llm_backend.model,
|
|
314
|
+
llm_backend.tokenizer,
|
|
315
|
+
prompt,
|
|
316
|
+
max_tokens=request.max_tokens,
|
|
317
|
+
temp=request.temperature,
|
|
318
|
+
top_p=request.top_p,
|
|
319
|
+
top_k=request.top_k or 0,
|
|
320
|
+
):
|
|
321
|
+
if out.text:
|
|
322
|
+
acc += out.text
|
|
323
|
+
chunk = _truncate_stop(acc, request.stop)
|
|
324
|
+
if len(chunk) < len(emitted):
|
|
325
|
+
break
|
|
326
|
+
delta = chunk[len(emitted):]
|
|
327
|
+
emitted = chunk
|
|
328
|
+
|
|
329
|
+
chunk_data = ChatCompletionChunk(
|
|
330
|
+
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
|
331
|
+
created=int(time.time()),
|
|
332
|
+
model=request.model or base_model,
|
|
333
|
+
choices=[StreamChoice(
|
|
334
|
+
index=0,
|
|
335
|
+
delta=DeltaMessage(content=delta),
|
|
336
|
+
finish_reason=None,
|
|
337
|
+
)],
|
|
338
|
+
)
|
|
339
|
+
yield f"data: {chunk_data.model_dump_json()}\n\n"
|
|
340
|
+
|
|
341
|
+
if request.stop and len(chunk) < len(acc):
|
|
342
|
+
break
|
|
343
|
+
if getattr(out, "finish_reason", None):
|
|
344
|
+
break
|
|
345
|
+
else:
|
|
346
|
+
# Fallback to non-streaming
|
|
347
|
+
if logprobs_k > 0 and hasattr(llm_backend, 'generate_with_logprobs'):
|
|
348
|
+
gen = llm_backend.generate_with_logprobs(
|
|
349
|
+
prompt,
|
|
350
|
+
max_new_tokens=request.max_tokens,
|
|
351
|
+
temperature=request.temperature,
|
|
352
|
+
top_p=request.top_p,
|
|
353
|
+
top_k_sampling=request.top_k,
|
|
354
|
+
logprobs=logprobs_k,
|
|
355
|
+
)
|
|
356
|
+
else:
|
|
357
|
+
gen = llm_backend.generate(
|
|
358
|
+
prompt,
|
|
359
|
+
max_new_tokens=request.max_tokens,
|
|
360
|
+
temperature=request.temperature,
|
|
361
|
+
top_p=request.top_p,
|
|
362
|
+
top_k=request.top_k,
|
|
363
|
+
)
|
|
364
|
+
completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
|
|
365
|
+
completion = _truncate_stop(completion, request.stop)
|
|
366
|
+
|
|
367
|
+
chunk_data = ChatCompletionChunk(
|
|
368
|
+
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
|
369
|
+
created=int(time.time()),
|
|
370
|
+
model=request.model or base_model,
|
|
371
|
+
choices=[StreamChoice(
|
|
372
|
+
index=0,
|
|
373
|
+
delta=DeltaMessage(content=completion),
|
|
374
|
+
finish_reason="stop",
|
|
375
|
+
)],
|
|
376
|
+
)
|
|
377
|
+
yield f"data: {chunk_data.model_dump_json()}\n\n"
|
|
378
|
+
|
|
379
|
+
yield "data: [DONE]\n\n"
|
|
380
|
+
|
|
381
|
+
except Exception as e:
|
|
382
|
+
error_chunk = {"error": str(e)}
|
|
383
|
+
yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
384
|
+
|
|
385
|
+
return StreamingResponse(
|
|
386
|
+
event_stream(),
|
|
387
|
+
media_type="text/event-stream",
|
|
388
|
+
headers={
|
|
389
|
+
"Cache-Control": "no-cache",
|
|
390
|
+
"Connection": "keep-alive",
|
|
391
|
+
},
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Non-streaming response
|
|
395
|
+
try:
|
|
396
|
+
if logprobs_k > 0 and hasattr(llm_backend, 'generate_with_logprobs'):
|
|
397
|
+
gen = llm_backend.generate_with_logprobs(
|
|
398
|
+
prompt,
|
|
399
|
+
max_new_tokens=request.max_tokens,
|
|
400
|
+
temperature=request.temperature,
|
|
401
|
+
top_p=request.top_p,
|
|
402
|
+
top_k_sampling=request.top_k,
|
|
403
|
+
logprobs=logprobs_k,
|
|
404
|
+
)
|
|
405
|
+
else:
|
|
406
|
+
gen = llm_backend.generate(
|
|
407
|
+
prompt,
|
|
408
|
+
max_new_tokens=request.max_tokens,
|
|
409
|
+
temperature=request.temperature,
|
|
410
|
+
top_p=request.top_p,
|
|
411
|
+
top_k=request.top_k,
|
|
412
|
+
)
|
|
413
|
+
completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
|
|
414
|
+
completion = _truncate_stop(completion, request.stop)
|
|
415
|
+
|
|
416
|
+
prompt_tokens = len(llm_backend.encode(prompt))
|
|
417
|
+
completion_tokens = len(llm_backend.encode(completion))
|
|
418
|
+
|
|
419
|
+
# Build choice with optional logprobs
|
|
420
|
+
choice = Choice(
|
|
421
|
+
index=0,
|
|
422
|
+
message=ChatMessage(role="assistant", content=completion),
|
|
423
|
+
finish_reason="stop",
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
if request.logprobs and gen.logprobs:
|
|
427
|
+
completion_ids = gen.token_ids[gen.prompt_len:]
|
|
428
|
+
logprobs_content = _build_logprobs_content(
|
|
429
|
+
completion_ids,
|
|
430
|
+
gen.logprobs,
|
|
431
|
+
gen.top_k_logprobs,
|
|
432
|
+
llm_backend.tokenizer,
|
|
433
|
+
)
|
|
434
|
+
choice.logprobs = ChoiceLogprobs(content=logprobs_content)
|
|
435
|
+
|
|
436
|
+
return ChatResponse(
|
|
437
|
+
id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
|
438
|
+
created=int(time.time()),
|
|
439
|
+
model=request.model or base_model,
|
|
440
|
+
choices=[choice],
|
|
441
|
+
usage=UsageInfo(
|
|
442
|
+
prompt_tokens=prompt_tokens,
|
|
443
|
+
completion_tokens=completion_tokens,
|
|
444
|
+
total_tokens=prompt_tokens + completion_tokens,
|
|
445
|
+
),
|
|
446
|
+
)
|
|
447
|
+
except Exception as e:
|
|
448
|
+
raise HTTPException(
|
|
449
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
450
|
+
detail=f"Generation failed: {str(e)}",
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
# ==========================================================================
|
|
454
|
+
# Internal Rollout (for RLM training)
|
|
455
|
+
# ==========================================================================
|
|
456
|
+
|
|
457
|
+
@router.post(
|
|
458
|
+
"/internal/rollout",
|
|
459
|
+
response_model=RolloutResponse,
|
|
460
|
+
responses={
|
|
461
|
+
200: {"description": "Successful rollout", "model": RolloutResponse},
|
|
462
|
+
400: {"description": "Bad request", "model": ErrorResponse},
|
|
463
|
+
401: {"description": "Unauthorized", "model": ErrorResponse},
|
|
464
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
465
|
+
},
|
|
466
|
+
tags=["Internal"],
|
|
467
|
+
)
|
|
468
|
+
async def internal_rollout(request: RolloutRequest) -> RolloutResponse:
|
|
469
|
+
"""Internal rollout endpoint returning tokens and logprobs.
|
|
470
|
+
|
|
471
|
+
Used by RLM training loop for generating rollouts with detailed
|
|
472
|
+
token-level information. Supports top-k logprobs for distillation.
|
|
473
|
+
"""
|
|
474
|
+
try:
|
|
475
|
+
# Determine logprobs to return
|
|
476
|
+
logprobs_k = request.include_top_k_logprobs or (5 if request.include_logprobs else 0)
|
|
477
|
+
|
|
478
|
+
gen = llm_backend.generate_with_logprobs(
|
|
479
|
+
request.prompt,
|
|
480
|
+
max_new_tokens=request.max_tokens,
|
|
481
|
+
temperature=request.temperature,
|
|
482
|
+
top_p=request.top_p,
|
|
483
|
+
top_k_sampling=request.top_k,
|
|
484
|
+
seed=request.seed,
|
|
485
|
+
logprobs=logprobs_k,
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
completion = gen.text[len(request.prompt):] if gen.text.startswith(request.prompt) else gen.text
|
|
489
|
+
|
|
490
|
+
prompt_logprobs: Optional[List[float]] = None
|
|
491
|
+
prompt_top_k: Optional[List[Dict[str, float]]] = None
|
|
492
|
+
include_prompt = bool(request.include_prompt_logprobs or request.include_prompt_top_k_logprobs)
|
|
493
|
+
if include_prompt and hasattr(llm_backend, "token_logprobs"):
|
|
494
|
+
prompt_ids = llm_backend.encode(request.prompt)
|
|
495
|
+
try:
|
|
496
|
+
logps, topk = llm_backend.token_logprobs(
|
|
497
|
+
prompt_ids,
|
|
498
|
+
prompt_len=len(prompt_ids),
|
|
499
|
+
top_k=int(request.include_prompt_top_k_logprobs or 0),
|
|
500
|
+
include_prompt=True,
|
|
501
|
+
)
|
|
502
|
+
if request.include_prompt_logprobs:
|
|
503
|
+
prompt_logprobs = list(logps)
|
|
504
|
+
if request.include_prompt_top_k_logprobs:
|
|
505
|
+
prompt_top_k = topk or []
|
|
506
|
+
except Exception:
|
|
507
|
+
prompt_logprobs = None
|
|
508
|
+
prompt_top_k = None
|
|
509
|
+
|
|
510
|
+
return RolloutResponse(
|
|
511
|
+
id=f"rollout-{uuid.uuid4().hex[:12]}",
|
|
512
|
+
created=int(time.time()),
|
|
513
|
+
model=base_model,
|
|
514
|
+
prompt_len=gen.prompt_len,
|
|
515
|
+
token_ids=list(gen.token_ids) if request.include_tokens else None,
|
|
516
|
+
logprobs=list(gen.logprobs) if (request.include_logprobs and gen.logprobs is not None) else None,
|
|
517
|
+
top_k_logprobs=gen.top_k_logprobs if request.include_top_k_logprobs else None,
|
|
518
|
+
prompt_logprobs=prompt_logprobs,
|
|
519
|
+
prompt_top_k_logprobs=prompt_top_k,
|
|
520
|
+
completion=completion if request.include_text else None,
|
|
521
|
+
)
|
|
522
|
+
except Exception as e:
|
|
523
|
+
raise HTTPException(
|
|
524
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
525
|
+
detail=f"Rollout generation failed: {str(e)}",
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# ==========================================================================
|
|
529
|
+
# Training Endpoints
|
|
530
|
+
# ==========================================================================
|
|
531
|
+
|
|
532
|
+
@router.post(
|
|
533
|
+
"/internal/train/forward_backward",
|
|
534
|
+
response_model=ForwardBackwardResponse,
|
|
535
|
+
responses={
|
|
536
|
+
200: {"description": "Forward/backward pass completed", "model": ForwardBackwardResponse},
|
|
537
|
+
400: {"description": "Bad request", "model": ErrorResponse},
|
|
538
|
+
401: {"description": "Unauthorized", "model": ErrorResponse},
|
|
539
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
540
|
+
},
|
|
541
|
+
tags=["Training"],
|
|
542
|
+
)
|
|
543
|
+
async def train_forward_backward(request: ForwardBackwardRequest) -> ForwardBackwardResponse:
|
|
544
|
+
"""Execute forward and backward pass.
|
|
545
|
+
|
|
546
|
+
Computes loss and gradients for the given batch.
|
|
547
|
+
"""
|
|
548
|
+
try:
|
|
549
|
+
from ..sdk import sft_forward_backward, preference_forward_backward
|
|
550
|
+
|
|
551
|
+
losses = []
|
|
552
|
+
has_grads = False
|
|
553
|
+
|
|
554
|
+
if request.loss_type in ("dpo", "orpo"):
|
|
555
|
+
# Preference training
|
|
556
|
+
if not request.rejected_responses:
|
|
557
|
+
raise HTTPException(
|
|
558
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
559
|
+
detail=f"{request.loss_type} requires rejected_responses",
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
for prompt, chosen, rejected in zip(
|
|
563
|
+
request.prompts,
|
|
564
|
+
request.responses or [],
|
|
565
|
+
request.rejected_responses,
|
|
566
|
+
):
|
|
567
|
+
loss, grads = preference_forward_backward(
|
|
568
|
+
llm_backend,
|
|
569
|
+
prompt,
|
|
570
|
+
chosen,
|
|
571
|
+
rejected,
|
|
572
|
+
algo=request.loss_type,
|
|
573
|
+
beta=(request.extra or {}).get("beta", 0.1),
|
|
574
|
+
max_seq_len=request.max_seq_len,
|
|
575
|
+
train_on_prompt=request.train_on_prompt,
|
|
576
|
+
)
|
|
577
|
+
losses.append(float(loss) if loss is not None else 0.0)
|
|
578
|
+
if grads is not None:
|
|
579
|
+
has_grads = True
|
|
580
|
+
else:
|
|
581
|
+
# SFT training
|
|
582
|
+
for prompt, response in zip(request.prompts, request.responses or []):
|
|
583
|
+
loss, grads = sft_forward_backward(
|
|
584
|
+
llm_backend,
|
|
585
|
+
prompt,
|
|
586
|
+
response,
|
|
587
|
+
train_on_prompt=request.train_on_prompt,
|
|
588
|
+
max_seq_len=request.max_seq_len,
|
|
589
|
+
)
|
|
590
|
+
losses.append(float(loss) if loss is not None else 0.0)
|
|
591
|
+
if grads is not None:
|
|
592
|
+
has_grads = True
|
|
593
|
+
|
|
594
|
+
avg_loss = sum(losses) / len(losses) if losses else 0.0
|
|
595
|
+
|
|
596
|
+
return ForwardBackwardResponse(
|
|
597
|
+
loss=avg_loss,
|
|
598
|
+
has_grads=has_grads,
|
|
599
|
+
batch_size=len(request.prompts),
|
|
600
|
+
metrics={
|
|
601
|
+
"max_loss": max(losses) if losses else 0.0,
|
|
602
|
+
"min_loss": min(losses) if losses else 0.0,
|
|
603
|
+
}
|
|
604
|
+
)
|
|
605
|
+
except HTTPException:
|
|
606
|
+
raise
|
|
607
|
+
except Exception as e:
|
|
608
|
+
raise HTTPException(
|
|
609
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
610
|
+
detail=f"Forward/backward failed: {str(e)}",
|
|
611
|
+
)
|
|
612
|
+
|
|
613
|
+
@router.post(
|
|
614
|
+
"/internal/train/optim_step",
|
|
615
|
+
response_model=OptimStepResponse,
|
|
616
|
+
responses={
|
|
617
|
+
200: {"description": "Optimizer step completed", "model": OptimStepResponse},
|
|
618
|
+
400: {"description": "Bad request", "model": ErrorResponse},
|
|
619
|
+
401: {"description": "Unauthorized", "model": ErrorResponse},
|
|
620
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
621
|
+
},
|
|
622
|
+
tags=["Training"],
|
|
623
|
+
)
|
|
624
|
+
async def train_optim_step(request: OptimStepRequest) -> OptimStepResponse:
|
|
625
|
+
"""Execute optimizer step.
|
|
626
|
+
|
|
627
|
+
Requires optimizer to be initialized via create_optimizer first.
|
|
628
|
+
"""
|
|
629
|
+
try:
|
|
630
|
+
if training_state["optimizer"] is None:
|
|
631
|
+
raise HTTPException(
|
|
632
|
+
status_code=status.HTTP_400_BAD_REQUEST,
|
|
633
|
+
detail="Optimizer not initialized. Call create_optimizer first.",
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
# Update learning rate if provided
|
|
637
|
+
if request.learning_rate is not None:
|
|
638
|
+
training_state["learning_rate"] = request.learning_rate
|
|
639
|
+
|
|
640
|
+
# Execute step (note: this requires grads to be stored from forward_backward)
|
|
641
|
+
# In a real implementation, you'd need to store grads between calls
|
|
642
|
+
training_state["step"] += 1
|
|
643
|
+
|
|
644
|
+
return OptimStepResponse(
|
|
645
|
+
step=training_state["step"],
|
|
646
|
+
learning_rate=training_state["learning_rate"],
|
|
647
|
+
grad_norm=None, # Would compute from actual grads
|
|
648
|
+
success=True,
|
|
649
|
+
)
|
|
650
|
+
except HTTPException:
|
|
651
|
+
raise
|
|
652
|
+
except Exception as e:
|
|
653
|
+
raise HTTPException(
|
|
654
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
655
|
+
detail=f"Optimizer step failed: {str(e)}",
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
@router.post(
|
|
659
|
+
"/internal/train/create_optimizer",
|
|
660
|
+
response_model=OptimStepResponse,
|
|
661
|
+
responses={
|
|
662
|
+
200: {"description": "Optimizer created", "model": OptimStepResponse},
|
|
663
|
+
401: {"description": "Unauthorized", "model": ErrorResponse},
|
|
664
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
665
|
+
},
|
|
666
|
+
tags=["Training"],
|
|
667
|
+
)
|
|
668
|
+
async def train_create_optimizer(request: OptimStepRequest) -> OptimStepResponse:
|
|
669
|
+
"""Create optimizer for training."""
|
|
670
|
+
try:
|
|
671
|
+
from ..sdk import create_optimizer
|
|
672
|
+
|
|
673
|
+
lr = request.learning_rate or training_state["learning_rate"]
|
|
674
|
+
opt, _ = create_optimizer(llm_backend, lr=lr)
|
|
675
|
+
training_state["optimizer"] = opt
|
|
676
|
+
training_state["learning_rate"] = lr
|
|
677
|
+
|
|
678
|
+
return OptimStepResponse(
|
|
679
|
+
step=training_state["step"],
|
|
680
|
+
learning_rate=lr,
|
|
681
|
+
success=True,
|
|
682
|
+
)
|
|
683
|
+
except Exception as e:
|
|
684
|
+
raise HTTPException(
|
|
685
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
686
|
+
detail=f"Failed to create optimizer: {str(e)}",
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
@router.post(
|
|
690
|
+
"/internal/train/save_state",
|
|
691
|
+
response_model=SaveStateResponse,
|
|
692
|
+
responses={
|
|
693
|
+
200: {"description": "State saved", "model": SaveStateResponse},
|
|
694
|
+
401: {"description": "Unauthorized", "model": ErrorResponse},
|
|
695
|
+
500: {"description": "Internal error", "model": ErrorResponse},
|
|
696
|
+
},
|
|
697
|
+
tags=["Training"],
|
|
698
|
+
)
|
|
699
|
+
async def train_save_state(request: SaveStateRequest) -> SaveStateResponse:
|
|
700
|
+
"""Save training checkpoint."""
|
|
701
|
+
try:
|
|
702
|
+
from pathlib import Path
|
|
703
|
+
|
|
704
|
+
save_path = Path(request.path)
|
|
705
|
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
706
|
+
|
|
707
|
+
metadata = {
|
|
708
|
+
"step": training_state["step"],
|
|
709
|
+
"learning_rate": training_state["learning_rate"],
|
|
710
|
+
**(request.metadata or {}),
|
|
711
|
+
}
|
|
712
|
+
|
|
713
|
+
llm_backend.save_adapter(str(save_path), metadata=metadata)
|
|
714
|
+
|
|
715
|
+
return SaveStateResponse(
|
|
716
|
+
path=str(save_path),
|
|
717
|
+
success=True,
|
|
718
|
+
message=f"Checkpoint saved to {save_path}",
|
|
719
|
+
)
|
|
720
|
+
except Exception as e:
|
|
721
|
+
raise HTTPException(
|
|
722
|
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
723
|
+
detail=f"Failed to save state: {str(e)}",
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
@router.post(
    "/internal/train/load_state",
    response_model=LoadStateResponse,
    responses={
        200: {"description": "State loaded", "model": LoadStateResponse},
        400: {"description": "Checkpoint not found", "model": ErrorResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
    },
    tags=["Training"],
)
async def train_load_state(request: LoadStateRequest) -> LoadStateResponse:
    """Load a training checkpoint (adapter weights plus optional metadata).

    ``request.path`` must exist.  If it contains ``adapter_metadata.json``,
    that file is parsed *before* the adapter is applied, so a malformed file
    fails the request without leaving the backend half-updated.  On success
    the adapter is applied, ``adapter_state["path"]`` is set, and ``step`` /
    ``learning_rate`` in ``training_state`` are restored from the metadata
    when present.

    Returns:
        LoadStateResponse carrying the restored (or unchanged) training step.

    Raises:
        HTTPException: 400 if the checkpoint path does not exist; 500 if
            reading the metadata or applying the adapter fails.
    """
    try:
        from pathlib import Path
        import json

        load_path = Path(request.path)
        if not load_path.exists():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Checkpoint not found: {request.path}",
            )

        # Parse metadata first: a bad file must not leave the adapter applied
        # while the request reports failure.
        metadata = None
        metadata_path = load_path / "adapter_metadata.json"
        if metadata_path.exists():
            with open(metadata_path, encoding="utf-8") as f:
                metadata = json.load(f)
            if not isinstance(metadata, dict):
                raise ValueError(
                    f"expected a JSON object in {metadata_path}, "
                    f"got {type(metadata).__name__}"
                )

        llm_backend.apply_adapter(str(load_path))
        adapter_state["path"] = str(load_path)

        step = training_state["step"]
        if metadata is not None:
            step = metadata.get("step", step)
            training_state["step"] = step
            training_state["learning_rate"] = metadata.get(
                "learning_rate", training_state["learning_rate"]
            )

        return LoadStateResponse(
            path=str(load_path),
            success=True,
            message=f"Checkpoint loaded from {load_path}",
            step=step,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to load state: {str(e)}",
        ) from e
|
|
775
|
+
|
|
776
|
+
@router.get(
    "/internal/train/weights",
    response_model=GetWeightsResponse,
    responses={
        200: {"description": "Weights retrieved", "model": GetWeightsResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
    },
    tags=["Training"],
)
async def train_get_weights() -> GetWeightsResponse:
    """Describe the backend model's trainable parameters.

    Returns a mapping from dotted parameter name (for example
    ``layers.0.self_attn.q_proj.lora_a``) to ``{"shape": [...]}``.  The
    tensor values themselves are not sent.  Leaves without a ``shape``
    attribute are reported as the string form of their type.  The mapping
    is empty when the backend exposes no model, or the model has no
    ``trainable_parameters`` method.

    Raises:
        HTTPException: 500 if inspecting the model fails.
    """

    def _iter_leaves(tree, prefix=""):
        # trainable_parameters() on MLX modules returns a nested dict/list
        # tree, not a flat mapping.  Walk it so each tensor gets its own
        # dotted key instead of a whole sub-module being one opaque entry.
        if isinstance(tree, dict):
            children = tree.items()
        elif isinstance(tree, (list, tuple)):
            children = enumerate(tree)
        else:
            yield prefix, tree
            return
        for key, child in children:
            name = f"{prefix}.{key}" if prefix else str(key)
            yield from _iter_leaves(child, name)

    try:
        weights = {}

        # Use an explicit None check: MLX modules subclass dict, so their
        # truthiness depends on how many children they hold.
        model = getattr(llm_backend, "model", None)
        if model is not None and hasattr(model, "trainable_parameters"):
            for name, value in _iter_leaves(model.trainable_parameters()):
                weights[name] = {
                    "shape": list(value.shape) if hasattr(value, "shape") else str(type(value))
                }

        return GetWeightsResponse(
            weights=weights,
            success=True,
            message=f"Retrieved {len(weights)} weight tensors",
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get weights: {str(e)}",
        ) from e
|
|
811
|
+
|
|
812
|
+
@router.post(
    "/internal/train/weights",
    response_model=SetWeightsResponse,
    responses={
        200: {"description": "Weights set", "model": SetWeightsResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
    },
    tags=["Training"],
)
async def train_set_weights(request: SetWeightsRequest) -> SetWeightsResponse:
    """Accept a batch of weight tensors for the training model.

    NOTE: this is currently a placeholder.  The payload is only counted and
    is never applied to ``llm_backend``, yet the response still reports
    success with that count.

    Raises:
        HTTPException: 500 if counting the payload fails.
    """
    try:
        # TODO: deserialize request.weights and load them into the model.
        tensor_count = len(request.weights)
        summary = f"Set {tensor_count} weight tensors"
        return SetWeightsResponse(
            success=True,
            message=summary,
            num_tensors=tensor_count,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to set weights: {str(exc)}",
        )
|
|
839
|
+
|
|
840
|
+
# ==========================================================================
|
|
841
|
+
# Adapter Hot-Reload
|
|
842
|
+
# ==========================================================================
|
|
843
|
+
|
|
844
|
+
@router.post(
    "/internal/adapter/reload",
    response_model=AdapterReloadResponse,
    responses={
        200: {"description": "Adapter reloaded successfully", "model": AdapterReloadResponse},
        400: {"description": "Bad request", "model": ErrorResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        404: {"description": "Adapter not found", "model": ErrorResponse},
        500: {"description": "Internal error", "model": ErrorResponse},
    },
    tags=["Internal"],
)
async def reload_adapter(request: AdapterReloadRequest) -> AdapterReloadResponse:
    """Hot-reload adapter weights without restarting the server.

    Behaviour depends on the request:

    * ``adapter_path`` given: the path, resolved against the current working
      directory when relative, must exist (otherwise 404).  The base model
      is reloaded first when ``reload_base`` is true; then the adapter is
      applied.
    * ``adapter_path`` missing or empty: the base model ``base_model`` is
      reloaded and the active adapter is cleared.

    Returns:
        AdapterReloadResponse with the base model id and the active adapter
        path (``None`` when no adapter is applied).

    Raises:
        HTTPException: 404 if the adapter path does not exist; 500 on any
            backend failure.
    """
    # adapter_state is only mutated via item assignment, so it needs no
    # ``nonlocal`` declaration.
    try:
        # Treat an empty string like "no adapter" so it reloads the base model
        # instead of silently doing nothing.
        target = request.adapter_path or None

        if target is not None:
            target_path = Path(target)
            if not target_path.is_absolute():
                target_path = Path.cwd() / target_path
            target = str(target_path)

            if not target_path.exists():
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Adapter path not found: {target}",
                )

        if request.reload_base or target is None:
            llm_backend.load(
                base_model,
                max_seq_len=getattr(cfg.model, "max_seq_len", 2048),
                dtype=getattr(cfg.model, "dtype", "float16"),
                trust_remote_code=getattr(cfg.model, "trust_remote_code", False),
            )
            adapter_state["path"] = None

        if target is not None:
            llm_backend.apply_adapter(target)
            adapter_state["path"] = target

        return AdapterReloadResponse(
            ok=True,
            base_model=base_model,
            adapter_path=adapter_state["path"],
            message="Adapter reloaded successfully",
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Adapter reload failed: {str(e)}",
        ) from e
|
|
906
|
+
|
|
907
|
+
# ==========================================================================
|
|
908
|
+
# RLM State and History
|
|
909
|
+
# ==========================================================================
|
|
910
|
+
|
|
911
|
+
@router.get(
    "/internal/rlm/state",
    response_model=RLMState,
    responses={
        200: {"description": "Current RLM state", "model": RLMState},
        401: {"description": "Unauthorized", "model": ErrorResponse},
    },
    tags=["RLM"],
)
async def rlm_state() -> RLMState:
    """Report the RLM training state stored in ``runs/rlm_state.json``.

    The file is resolved against the current working directory.  If it is
    missing, unreadable, not valid JSON, or rejected by ``RLMState``, the
    endpoint reports ``status="idle"`` instead of raising.
    """
    state_file = Path.cwd() / "runs" / "rlm_state.json"

    if state_file.exists():
        try:
            payload = json.loads(state_file.read_text(encoding="utf-8"))
            return RLMState(**payload)
        except Exception:
            # Deliberate best-effort: fall through to the idle state.
            pass

    return RLMState(status="idle")
|
|
932
|
+
|
|
933
|
+
@router.get(
    "/internal/rlm/history",
    response_model=List[RLMHistoryEntry],
    responses={
        200: {"description": "RLM training history", "model": List[RLMHistoryEntry]},
        401: {"description": "Unauthorized", "model": ErrorResponse},
    },
    tags=["RLM"],
)
async def rlm_history(
    limit: Optional[int] = 100,
    offset: Optional[int] = 0,
) -> List[RLMHistoryEntry]:
    """Get RLM training history/metrics from ``runs/rlm_history.jsonl``.

    The file is resolved against the current working directory and read
    line by line.  Blank lines and lines that fail JSON parsing or
    ``RLMHistoryEntry`` validation are skipped and do not count toward
    ``offset`` or ``limit``, so ``offset += len(page)`` pagination is
    consistent.  Reading stops once ``limit`` entries have been collected.

    Args:
        limit: Maximum number of entries to return.  ``None``, ``0`` or a
            negative value means no limit.
        offset: Number of valid entries to skip.  ``None`` or a negative
            value is treated as 0.

    Returns:
        The selected entries in file order.  Returns an empty list when the
        file does not exist; if the file becomes unreadable partway through,
        returns the entries collected so far.
    """
    history_path = Path.cwd() / "runs" / "rlm_history.jsonl"

    if not history_path.exists():
        return []

    to_skip = max(offset or 0, 0)
    max_rows = limit if limit and limit > 0 else None

    rows: List[RLMHistoryEntry] = []
    try:
        with history_path.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue
                try:
                    entry = RLMHistoryEntry(**json.loads(line))
                except (ValueError, TypeError):
                    # Malformed JSON or schema mismatch: skip this row.
                    continue
                if to_skip:
                    to_skip -= 1
                    continue
                rows.append(entry)
                if max_rows is not None and len(rows) >= max_rows:
                    break
    except (OSError, UnicodeDecodeError):
        # Best-effort endpoint: an unreadable file yields what was collected.
        pass

    return rows
|
|
972
|
+
|
|
973
|
+
# ==========================================================================
|
|
974
|
+
# Model Management
|
|
975
|
+
# ==========================================================================
|
|
976
|
+
|
|
977
|
+
def _get_model_format(path: Path) -> str:
    """Guess the on-disk format of a cached model directory.

    Checks run in order and the first match wins:

    1. ``model.safetensors`` or ``weights.safetensors`` -> ``"mlx"``.
       Hugging Face checkpoints also ship ``model.safetensors``, so such
       directories are reported as ``"mlx"`` as well.
    2. ``pytorch_model.bin`` or sharded ``pytorch_model-*.bin`` -> ``"hf"``.
    3. Any ``*.gguf`` file -> ``"gguf"``.

    Anything else defaults to ``"mlx"``.

    Args:
        path: Directory to inspect (only its direct children are checked).

    Returns:
        One of ``"mlx"``, ``"hf"`` or ``"gguf"``.
    """
    if (path / "model.safetensors").exists() or (path / "weights.safetensors").exists():
        return "mlx"
    # Sharded PyTorch checkpoints are named pytorch_model-00001-of-0000N.bin.
    if (path / "pytorch_model.bin").exists() or any(path.glob("pytorch_model-*.bin")):
        return "hf"
    if any(path.glob("*.gguf")):
        return "gguf"
    return "mlx"  # Default
|
|
986
|
+
|
|
987
|
+
def _has_adapter(path: Path) -> bool:
    """Return True when *path* directly contains adapter files.

    Any one of ``adapter_config.json``, ``adapters.safetensors`` or
    ``lora.npz`` inside *path* (non-recursive) counts as an adapter.
    """
    marker_names = ("adapter_config.json", "adapters.safetensors", "lora.npz")
    return any((path / marker).exists() for marker in marker_names)
|
|
994
|
+
|
|
995
|
+
@router.get(
    "/internal/models/list",
    response_model=ModelsListResponse,
    responses={
        200: {"description": "List of cached models", "model": ModelsListResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
    },
    tags=["Models"],
)
async def list_models() -> ModelsListResponse:
    """List models cached under ``<cache_dir>/mlx`` and ``<cache_dir>/hf``.

    Each sub-directory is one model.  Its directory name is mapped back to
    a repo id by replacing ``__`` with ``/``.  Entries under ``hf`` whose id
    already appears under ``mlx`` are omitted; the remaining ones get a
    ``" (HF)"`` suffix.  ``size_bytes`` is the sum of all regular files
    beneath the model directory, and files that cannot be stat'ed are
    skipped.  For MLX entries, ``metadata`` is the parsed ``config.json``
    when it exists and parses.

    Returns:
        ModelsListResponse with the models (sorted by directory name within
        each cache) and the cache directory that was scanned.
    """

    def _dir_size(root: Path) -> int:
        # Count file by file so one unreadable or vanished entry does not
        # truncate the total for the whole model.
        total = 0
        try:
            for entry in root.rglob("*"):
                try:
                    if entry.is_file():
                        total += entry.stat().st_size
                except OSError:
                    continue
        except OSError:
            pass
        return total

    def _model_dirs(parent: Path):
        # Sorted for a stable response order; iterdir() order is arbitrary.
        if not parent.is_dir():
            return []
        return sorted(child for child in parent.iterdir() if child.is_dir())

    def _read_config(model_path: Path):
        # Best-effort: a missing or malformed config.json yields None.
        config_path = model_path / "config.json"
        if not config_path.exists():
            return None
        try:
            return json.loads(config_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return None

    cache_dir = _get_cache_dir()
    models = []
    mlx_ids = set()

    for model_path in _model_dirs(cache_dir / "mlx"):
        model_id = model_path.name.replace("__", "/")
        has_adapter = _has_adapter(model_path)
        mlx_ids.add(model_id)

        models.append(ModelInfo(
            id=model_id,
            path=str(model_path),
            size_bytes=_dir_size(model_path),
            format=_get_model_format(model_path),
            has_adapter=has_adapter,
            adapter_path=str(model_path) if has_adapter else None,
            metadata=_read_config(model_path),
            downloaded_at=int(model_path.stat().st_mtime),
        ))

    # Also check HF cache
    for model_path in _model_dirs(cache_dir / "hf"):
        model_id = model_path.name.replace("__", "/")
        # Skip if already in MLX format
        if model_id in mlx_ids:
            continue

        models.append(ModelInfo(
            id=f"{model_id} (HF)",
            path=str(model_path),
            size_bytes=_dir_size(model_path),
            format="hf",
            has_adapter=False,
            metadata=None,
            downloaded_at=int(model_path.stat().st_mtime),
        ))

    return ModelsListResponse(
        models=models,
        total=len(models),
        cache_dir=str(cache_dir),
    )
|
|
1083
|
+
|
|
1084
|
+
@router.post(
    "/internal/models/pull",
    response_model=ModelPullResponse,
    responses={
        200: {"description": "Model pull initiated", "model": ModelPullResponse},
        400: {"description": "Bad request", "model": ErrorResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        500: {"description": "Pull failed", "model": ErrorResponse},
    },
    tags=["Models"],
)
async def pull_model(request: ModelPullRequest) -> ModelPullResponse:
    """Pull a model from HuggingFace into the local cache.

    This proxies the ``mlxsmith pull`` command: the model is downloaded and,
    depending on the request, converted and/or quantized.

    Note: ``hf_pull`` is called synchronously inside this coroutine.  The
    request does not return until the pull has finished, and the event loop
    is blocked meanwhile.  A successful response always reports status
    ``"completed"``.

    The HuggingFace token is looked up where ``store_hf_token`` persists
    it: first the system keyring (service ``"mlxsmith"``, username
    ``"huggingface"``), then ``~/.config/mlxsmith/hf_token``.  When neither
    is available, no token is passed.

    Raises:
        HTTPException: 500 if the pull (or token lookup) fails.
    """

    def _load_hf_token():
        # Keyring first, matching store_hf_token's preferred storage.
        try:
            import keyring

            stored = keyring.get_password("mlxsmith", "huggingface")
            if stored:
                return stored
        except Exception:
            # keyring not installed or no usable backend: try the file.
            pass
        token_path = Path.home() / ".config" / "mlxsmith" / "hf_token"
        if token_path.exists():
            return token_path.read_text(encoding="utf-8").strip() or None
        return None

    cache_dir = _get_cache_dir()

    try:
        # Import here to avoid circular dependencies
        from ..models import hf_pull

        hf_token = _load_hf_token()

        # Synchronous: blocks until download/conversion completes.
        result_path = hf_pull(
            model_id=request.model_id,
            cache_dir=cache_dir,
            convert=request.convert,
            quantize=request.quantize,
            q_bits=request.q_bits,
            q_group_size=request.q_group_size,
            trust_remote_code=request.trust_remote_code,
            hf_token=hf_token,
        )

        return ModelPullResponse(
            ok=True,
            model_id=request.model_id,
            local_path=str(result_path),
            status=ModelPullStatus(
                status="completed",
                progress=100.0,
                message=f"Model pulled successfully to {result_path}",
            ),
        )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model pull failed: {str(e)}",
        ) from e
|
|
1147
|
+
|
|
1148
|
+
# ==========================================================================
|
|
1149
|
+
# HuggingFace Token Management
|
|
1150
|
+
# ==========================================================================
|
|
1151
|
+
|
|
1152
|
+
@router.post(
    "/internal/hf/token",
    response_model=HFTokenResponse,
    responses={
        200: {"description": "Token stored successfully", "model": HFTokenResponse},
        400: {"description": "Bad request", "model": ErrorResponse},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        500: {"description": "Storage failed", "model": ErrorResponse},
    },
    tags=["HF Token"],
)
async def store_hf_token(request: HFTokenRequest) -> HFTokenResponse:
    """Store a HuggingFace token, optionally validating it first.

    When ``request.validate_token`` is set, the token is checked with
    ``HfApi.whoami()`` and the account name is returned as ``username``.

    When ``request.persist`` is set, the token is saved to the system
    keyring (service ``"mlxsmith"``, username ``"huggingface"``).  If
    ``keyring`` is not installed, or raises because no usable backend
    exists, the token is written to ``~/.config/mlxsmith/hf_token`` instead.
    That file is created with owner-only (0600) permissions before the
    secret is written.

    Raises:
        HTTPException: 400 if validation fails; 500 if persisting fails.
    """

    def _write_token_file(value: str) -> None:
        config_dir = Path.home() / ".config" / "mlxsmith"
        config_dir.mkdir(parents=True, exist_ok=True)
        token_path = config_dir / "hf_token"
        # Create with 0600 so the secret is never written to a file with
        # looser permissions.  Re-chmod before writing in case the file
        # already existed with wider bits.
        fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            os.chmod(token_path, 0o600)
            fh.write(value)

    def _persist(value: str) -> str:
        # Returns the storage method actually used.
        try:
            import keyring

            keyring.set_password("mlxsmith", "huggingface", value)
            return "keyring"
        except Exception:
            # ImportError, or keyring present without a usable backend
            # (e.g. headless Linux): fall back to file storage.
            _write_token_file(value)
            return "file"

    token = request.token
    username = None
    # NOTE(review): with persist=False nothing in this handler retains the
    # token, although the response reports storage_method="memory" — confirm.
    storage_method: str = "memory"

    # Validate token if requested
    if request.validate_token:
        try:
            from huggingface_hub import HfApi

            user_info = HfApi(token=token).whoami()
            username = user_info.get("name") if user_info else None
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Token validation failed: {str(e)}",
            ) from e

    # Store token
    if request.persist:
        try:
            storage_method = _persist(token)
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to store token: {str(e)}",
            ) from e

    return HFTokenResponse(
        ok=True,
        validated=request.validate_token,
        username=username,
        message="Token stored successfully",
        storage_method=storage_method,
    )
|
|
1216
|
+
|
|
1217
|
+
return router
|