mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/api/handlers.py
@@ -0,0 +1,1217 @@
+ """FastAPI handlers for MLXSmith API.
+
+ Implements endpoints for:
+ - OpenAI-compatible chat completions with streaming
+ - Internal rollout (tokens + logprobs)
+ - Training operations (forward/backward, optim_step, save/load state)
+ - Adapter hot-reload
+ - RLM state and history
+ - Model management
+ - HF token storage
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import secrets
+ import time
+ import uuid
+ from pathlib import Path
+ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional
+
+ from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Security, status
+ from starlette.middleware.base import BaseHTTPMiddleware
+ from fastapi.responses import StreamingResponse
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from pydantic import BaseModel
+
+ from .schemas import (
+     AdapterReloadRequest,
+     AdapterReloadResponse,
+     ChatCompletionChunk,
+     ChatMessage,
+     ChatRequest,
+     ChatResponse,
+     Choice,
+     ChoiceLogprobs,
+     DeltaMessage,
+     ErrorResponse,
+     ForwardBackwardRequest,
+     ForwardBackwardResponse,
+     GetWeightsResponse,
+     HealthResponse,
+     HFTokenRequest,
+     HFTokenResponse,
+     LoadStateRequest,
+     LoadStateResponse,
+     LogprobsContent,
+     ModelInfo,
+     ModelsListResponse,
+     ModelPullRequest,
+     ModelPullResponse,
+     ModelPullStatus,
+     OptimStepRequest,
+     OptimStepResponse,
+     RolloutRequest,
+     RolloutResponse,
+     RLMHistoryEntry,
+     RLMState,
+     SaveStateRequest,
+     SaveStateResponse,
+     SetWeightsRequest,
+     SetWeightsResponse,
+     StreamChoice,
+     UsageInfo,
+ )
+
+ # =============================================================================
+ # Authentication Middleware
+ # =============================================================================
+
+ class InternalAuthMiddleware(BaseHTTPMiddleware):
+     """Middleware for authenticating internal endpoints.
+
+     Checks for a valid API token on internal endpoints.
+     Public endpoints (health, chat completions) bypass authentication.
+     """
+
+     def __init__(
+         self,
+         app: FastAPI,
+         api_token: Optional[str] = None,
+         internal_prefix: str = "/internal",
+         public_paths: Optional[List[str]] = None,
+     ):
+         super().__init__(app)
+         self.api_token = api_token or os.environ.get("MLXSMITH_API_TOKEN")
+         self.internal_prefix = internal_prefix
+         self.public_paths = set(public_paths or ["/health", "/v1/chat/completions"])
+         self.security = HTTPBearer(auto_error=False)
+
+     async def dispatch(self, request: Request, call_next: Callable) -> Any:
+         path = request.url.path
+
+         # Skip auth for public paths
+         if path in self.public_paths:
+             return await call_next(request)
+
+         # Skip auth for non-internal paths
+         if not path.startswith(self.internal_prefix):
+             return await call_next(request)
+
+         # If no token configured, allow all (development mode)
+         if not self.api_token:
+             return await call_next(request)
+
+         # Check Authorization header
+         auth_header = request.headers.get("authorization", "")
+         if not auth_header.startswith("Bearer "):
+             return self._unauthorized("Missing or invalid authorization header")
+
+         token = auth_header[7:]  # Remove "Bearer " prefix
+         if not secrets.compare_digest(token, self.api_token):
+             return self._unauthorized("Invalid API token")
+
+         return await call_next(request)
+
+     def _unauthorized(self, detail: str) -> Any:
+         from fastapi.responses import JSONResponse
+         return JSONResponse(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             content={"error": "Unauthorized", "detail": detail},
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+
+ def verify_internal_token(
+     credentials: HTTPAuthorizationCredentials = Security(HTTPBearer()),
+     expected_token: Optional[str] = None,
+ ) -> bool:
+     """Dependency for verifying internal endpoint tokens.
+
+     Usage:
+         @router.get("/internal/protected", dependencies=[Depends(verify_internal_token)])
+     """
+     token = expected_token or os.environ.get("MLXSMITH_API_TOKEN")
+     if not token:
+         return True  # Development mode - no token configured
+
+     if not credentials or not credentials.credentials:
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Missing authorization token",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     if not secrets.compare_digest(credentials.credentials, token):
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Invalid API token",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     return True
+
+
+ # =============================================================================
+ # Helper Functions
+ # =============================================================================
+
+ def _messages_to_prompt(
+     messages: List[Any],
+     tokenizer: Any,
+     *,
+     use_chat_template: bool = True
+ ) -> str:
+     """Convert chat messages to prompt string."""
+     if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
+         msgs = [{"role": m.role, "content": m.content} for m in messages]
+         return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+     # Fallback
+     return "\n".join([f"{m.role}: {m.content}" for m in messages]) + "\nassistant:"
+
+
+ def _truncate_stop(text: str, stop: Optional[List[str]]) -> str:
+     """Truncate text at first stop sequence."""
+     if not stop:
+         return text
+     idx = None
+     for s in stop:
+         if not s:
+             continue
+         pos = text.find(s)
+         if pos != -1:
+             idx = pos if idx is None else min(idx, pos)
+     return text if idx is None else text[:idx]
+
+
+ def _get_cache_dir() -> Path:
+     """Get the cache directory for models."""
+     cache_dir = os.environ.get("MLXSMITH_CACHE_DIR")
+     if cache_dir:
+         return Path(cache_dir)
+     return Path.home() / ".cache" / "mlxsmith"
+
+
+ def _build_logprobs_content(
+     token_ids: List[int],
+     logprobs: List[float],
+     top_k_logprobs: Optional[List[Dict[str, float]]],
+     tokenizer: Any,
+ ) -> List[LogprobsContent]:
+     """Build LogprobsContent from token info."""
+     content = []
+     for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs)):
+         try:
+             token_str = tokenizer.decode([token_id]) if tokenizer else f"<token_{token_id}>"
+         except Exception:
+             token_str = f"<token_{token_id}>"
+
+         top_logprobs = None
+         if top_k_logprobs and i < len(top_k_logprobs):
+             top_logprobs = [top_k_logprobs[i]]
+
+         content.append(LogprobsContent(
+             token=token_str,
+             logprob=logprob,
+             top_logprobs=top_logprobs,
+         ))
+     return content
+
+
+ # =============================================================================
+ # Route Handlers
+ # =============================================================================
+
+ def create_router(
+     llm_backend: Any,
+     base_model: str,
+     current_adapter: Optional[str],
+     cfg: Any,
+ ) -> APIRouter:
+     """Create API router with all endpoints.
+
+     Args:
+         llm_backend: The LLM backend instance
+         base_model: The base model identifier
+         current_adapter: Currently loaded adapter path (if any)
+         cfg: Project configuration
+
+     Returns:
+         Configured APIRouter instance
+     """
+     router = APIRouter()
+
+     # Track adapter state (mutable reference)
+     adapter_state = {"path": current_adapter}
+
+     # Track training state
+     training_state = {
+         "step": 0,
+         "optimizer": None,
+         "learning_rate": 1e-4,
+     }
+
+     # ==========================================================================
+     # Health Check
+     # ==========================================================================
+
+     @router.get("/health", response_model=HealthResponse, tags=["Health"])
+     async def health() -> HealthResponse:
+         """Health check endpoint."""
+         return HealthResponse(
+             ok=True,
+             version="0.1.0",
+             model=base_model,
+         )
+
+     # ==========================================================================
+     # Chat Completions (OpenAI-compatible)
+     # ==========================================================================
+
+     @router.post(
+         "/v1/chat/completions",
+         response_model=ChatResponse,
+         responses={
+             200: {"description": "Successful completion", "model": ChatResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Chat"],
+     )
+     async def chat_completions(request: ChatRequest) -> ChatResponse | StreamingResponse:
+         """OpenAI-compatible chat completions endpoint.
+
+         Supports both streaming (SSE) and non-streaming responses.
+         Supports logprobs parameter for returning token logprobs.
+         """
+         prompt = _messages_to_prompt(
+             request.messages,
+             llm_backend.tokenizer,
+             use_chat_template=getattr(cfg.model, "use_chat_template", True)
+         )
+
+         # Determine if we need logprobs
+         logprobs_k = request.top_logprobs or (5 if request.logprobs else 0)
+
+         # Handle streaming response
+         if request.stream:
+             async def event_stream() -> AsyncGenerator[str, None]:
+                 try:
+                     # Try to use mlx_lm streaming if available
+                     try:
+                         import mlx_lm
+                         has_mlx_lm = True
+                     except ImportError:
+                         has_mlx_lm = False
+
+                     if has_mlx_lm:
+                         acc = ""
+                         emitted = ""
+                         for out in mlx_lm.stream_generate(
+                             llm_backend.model,
+                             llm_backend.tokenizer,
+                             prompt,
+                             max_tokens=request.max_tokens,
+                             temp=request.temperature,
+                             top_p=request.top_p,
+                             top_k=request.top_k or 0,
+                         ):
+                             if out.text:
+                                 acc += out.text
+                                 chunk = _truncate_stop(acc, request.stop)
+                                 if len(chunk) < len(emitted):
+                                     break
+                                 delta = chunk[len(emitted):]
+                                 emitted = chunk
+
+                                 chunk_data = ChatCompletionChunk(
+                                     id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                                     created=int(time.time()),
+                                     model=request.model or base_model,
+                                     choices=[StreamChoice(
+                                         index=0,
+                                         delta=DeltaMessage(content=delta),
+                                         finish_reason=None,
+                                     )],
+                                 )
+                                 yield f"data: {chunk_data.model_dump_json()}\n\n"
+
+                                 if request.stop and len(chunk) < len(acc):
+                                     break
+                             if getattr(out, "finish_reason", None):
+                                 break
+                     else:
+                         # Fallback to non-streaming
+                         if logprobs_k > 0 and hasattr(llm_backend, 'generate_with_logprobs'):
+                             gen = llm_backend.generate_with_logprobs(
+                                 prompt,
+                                 max_new_tokens=request.max_tokens,
+                                 temperature=request.temperature,
+                                 top_p=request.top_p,
+                                 top_k_sampling=request.top_k,
+                                 logprobs=logprobs_k,
+                             )
+                         else:
+                             gen = llm_backend.generate(
+                                 prompt,
+                                 max_new_tokens=request.max_tokens,
+                                 temperature=request.temperature,
+                                 top_p=request.top_p,
+                                 top_k=request.top_k,
+                             )
+                         completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+                         completion = _truncate_stop(completion, request.stop)
+
+                         chunk_data = ChatCompletionChunk(
+                             id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                             created=int(time.time()),
+                             model=request.model or base_model,
+                             choices=[StreamChoice(
+                                 index=0,
+                                 delta=DeltaMessage(content=completion),
+                                 finish_reason="stop",
+                             )],
+                         )
+                         yield f"data: {chunk_data.model_dump_json()}\n\n"
+
+                     yield "data: [DONE]\n\n"
+
+                 except Exception as e:
+                     error_chunk = {"error": str(e)}
+                     yield f"data: {json.dumps(error_chunk)}\n\n"
+
+             return StreamingResponse(
+                 event_stream(),
+                 media_type="text/event-stream",
+                 headers={
+                     "Cache-Control": "no-cache",
+                     "Connection": "keep-alive",
+                 },
+             )
+
+         # Non-streaming response
+         try:
+             if logprobs_k > 0 and hasattr(llm_backend, 'generate_with_logprobs'):
+                 gen = llm_backend.generate_with_logprobs(
+                     prompt,
+                     max_new_tokens=request.max_tokens,
+                     temperature=request.temperature,
+                     top_p=request.top_p,
+                     top_k_sampling=request.top_k,
+                     logprobs=logprobs_k,
+                 )
+             else:
+                 gen = llm_backend.generate(
+                     prompt,
+                     max_new_tokens=request.max_tokens,
+                     temperature=request.temperature,
+                     top_p=request.top_p,
+                     top_k=request.top_k,
+                 )
+             completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+             completion = _truncate_stop(completion, request.stop)
+
+             prompt_tokens = len(llm_backend.encode(prompt))
+             completion_tokens = len(llm_backend.encode(completion))
+
+             # Build choice with optional logprobs
+             choice = Choice(
+                 index=0,
+                 message=ChatMessage(role="assistant", content=completion),
+                 finish_reason="stop",
+             )
+
+             if request.logprobs and gen.logprobs:
+                 completion_ids = gen.token_ids[gen.prompt_len:]
+                 logprobs_content = _build_logprobs_content(
+                     completion_ids,
+                     gen.logprobs,
+                     gen.top_k_logprobs,
+                     llm_backend.tokenizer,
+                 )
+                 choice.logprobs = ChoiceLogprobs(content=logprobs_content)
+
+             return ChatResponse(
+                 id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                 created=int(time.time()),
+                 model=request.model or base_model,
+                 choices=[choice],
+                 usage=UsageInfo(
+                     prompt_tokens=prompt_tokens,
+                     completion_tokens=completion_tokens,
+                     total_tokens=prompt_tokens + completion_tokens,
+                 ),
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Generation failed: {str(e)}",
+             )
+
+     # ==========================================================================
+     # Internal Rollout (for RLM training)
+     # ==========================================================================
+
+     @router.post(
+         "/internal/rollout",
+         response_model=RolloutResponse,
+         responses={
+             200: {"description": "Successful rollout", "model": RolloutResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Internal"],
+     )
+     async def internal_rollout(request: RolloutRequest) -> RolloutResponse:
+         """Internal rollout endpoint returning tokens and logprobs.
+
+         Used by RLM training loop for generating rollouts with detailed
+         token-level information. Supports top-k logprobs for distillation.
+         """
+         try:
+             # Determine logprobs to return
+             logprobs_k = request.include_top_k_logprobs or (5 if request.include_logprobs else 0)
+
+             gen = llm_backend.generate_with_logprobs(
+                 request.prompt,
+                 max_new_tokens=request.max_tokens,
+                 temperature=request.temperature,
+                 top_p=request.top_p,
+                 top_k_sampling=request.top_k,
+                 seed=request.seed,
+                 logprobs=logprobs_k,
+             )
+
+             completion = gen.text[len(request.prompt):] if gen.text.startswith(request.prompt) else gen.text
+
+             prompt_logprobs: Optional[List[float]] = None
+             prompt_top_k: Optional[List[Dict[str, float]]] = None
+             include_prompt = bool(request.include_prompt_logprobs or request.include_prompt_top_k_logprobs)
+             if include_prompt and hasattr(llm_backend, "token_logprobs"):
+                 prompt_ids = llm_backend.encode(request.prompt)
+                 try:
+                     logps, topk = llm_backend.token_logprobs(
+                         prompt_ids,
+                         prompt_len=len(prompt_ids),
+                         top_k=int(request.include_prompt_top_k_logprobs or 0),
+                         include_prompt=True,
+                     )
+                     if request.include_prompt_logprobs:
+                         prompt_logprobs = list(logps)
+                     if request.include_prompt_top_k_logprobs:
+                         prompt_top_k = topk or []
+                 except Exception:
+                     prompt_logprobs = None
+                     prompt_top_k = None
+
+             return RolloutResponse(
+                 id=f"rollout-{uuid.uuid4().hex[:12]}",
+                 created=int(time.time()),
+                 model=base_model,
+                 prompt_len=gen.prompt_len,
+                 token_ids=list(gen.token_ids) if request.include_tokens else None,
+                 logprobs=list(gen.logprobs) if (request.include_logprobs and gen.logprobs is not None) else None,
+                 top_k_logprobs=gen.top_k_logprobs if request.include_top_k_logprobs else None,
+                 prompt_logprobs=prompt_logprobs,
+                 prompt_top_k_logprobs=prompt_top_k,
+                 completion=completion if request.include_text else None,
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Rollout generation failed: {str(e)}",
+             )
+
+     # ==========================================================================
+     # Training Endpoints
+     # ==========================================================================
+
+     @router.post(
+         "/internal/train/forward_backward",
+         response_model=ForwardBackwardResponse,
+         responses={
+             200: {"description": "Forward/backward pass completed", "model": ForwardBackwardResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_forward_backward(request: ForwardBackwardRequest) -> ForwardBackwardResponse:
+         """Execute forward and backward pass.
+
+         Computes loss and gradients for the given batch.
+         """
+         try:
+             from ..sdk import sft_forward_backward, preference_forward_backward
+
+             losses = []
+             has_grads = False
+
+             if request.loss_type in ("dpo", "orpo"):
+                 # Preference training
+                 if not request.rejected_responses:
+                     raise HTTPException(
+                         status_code=status.HTTP_400_BAD_REQUEST,
+                         detail=f"{request.loss_type} requires rejected_responses",
+                     )
+
+                 for prompt, chosen, rejected in zip(
+                     request.prompts,
+                     request.responses or [],
+                     request.rejected_responses,
+                 ):
+                     loss, grads = preference_forward_backward(
+                         llm_backend,
+                         prompt,
+                         chosen,
+                         rejected,
+                         algo=request.loss_type,
+                         beta=(request.extra or {}).get("beta", 0.1),
+                         max_seq_len=request.max_seq_len,
+                         train_on_prompt=request.train_on_prompt,
+                     )
+                     losses.append(float(loss) if loss is not None else 0.0)
+                     if grads is not None:
+                         has_grads = True
+             else:
+                 # SFT training
+                 for prompt, response in zip(request.prompts, request.responses or []):
+                     loss, grads = sft_forward_backward(
+                         llm_backend,
+                         prompt,
+                         response,
+                         train_on_prompt=request.train_on_prompt,
+                         max_seq_len=request.max_seq_len,
+                     )
+                     losses.append(float(loss) if loss is not None else 0.0)
+                     if grads is not None:
+                         has_grads = True
+
+             avg_loss = sum(losses) / len(losses) if losses else 0.0
+
+             return ForwardBackwardResponse(
+                 loss=avg_loss,
+                 has_grads=has_grads,
+                 batch_size=len(request.prompts),
+                 metrics={
+                     "max_loss": max(losses) if losses else 0.0,
+                     "min_loss": min(losses) if losses else 0.0,
+                 }
+             )
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Forward/backward failed: {str(e)}",
+             )
+
+     @router.post(
+         "/internal/train/optim_step",
+         response_model=OptimStepResponse,
+         responses={
+             200: {"description": "Optimizer step completed", "model": OptimStepResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_optim_step(request: OptimStepRequest) -> OptimStepResponse:
+         """Execute optimizer step.
+
+         Requires optimizer to be initialized via create_optimizer first.
+         """
+         try:
+             if training_state["optimizer"] is None:
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail="Optimizer not initialized. Call create_optimizer first.",
+                 )
+
+             # Update learning rate if provided
+             if request.learning_rate is not None:
+                 training_state["learning_rate"] = request.learning_rate
+
+             # Execute step (note: this requires grads to be stored from forward_backward)
+             # In a real implementation, you'd need to store grads between calls
+             training_state["step"] += 1
+
+             return OptimStepResponse(
+                 step=training_state["step"],
+                 learning_rate=training_state["learning_rate"],
+                 grad_norm=None,  # Would compute from actual grads
+                 success=True,
+             )
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Optimizer step failed: {str(e)}",
+             )
+
+     @router.post(
+         "/internal/train/create_optimizer",
+         response_model=OptimStepResponse,
+         responses={
+             200: {"description": "Optimizer created", "model": OptimStepResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_create_optimizer(request: OptimStepRequest) -> OptimStepResponse:
+         """Create optimizer for training."""
+         try:
+             from ..sdk import create_optimizer
+
+             lr = request.learning_rate or training_state["learning_rate"]
+             opt, _ = create_optimizer(llm_backend, lr=lr)
+             training_state["optimizer"] = opt
+             training_state["learning_rate"] = lr
+
+             return OptimStepResponse(
+                 step=training_state["step"],
+                 learning_rate=lr,
+                 success=True,
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Failed to create optimizer: {str(e)}",
+             )
+
+     @router.post(
+         "/internal/train/save_state",
+         response_model=SaveStateResponse,
+         responses={
+             200: {"description": "State saved", "model": SaveStateResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_save_state(request: SaveStateRequest) -> SaveStateResponse:
+         """Save training checkpoint."""
+         try:
+             from pathlib import Path
+
+             save_path = Path(request.path)
+             save_path.parent.mkdir(parents=True, exist_ok=True)
+
+             metadata = {
+                 "step": training_state["step"],
+                 "learning_rate": training_state["learning_rate"],
+                 **(request.metadata or {}),
+             }
+
+             llm_backend.save_adapter(str(save_path), metadata=metadata)
+
+             return SaveStateResponse(
+                 path=str(save_path),
+                 success=True,
+                 message=f"Checkpoint saved to {save_path}",
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Failed to save state: {str(e)}",
+             )
+
+     @router.post(
+         "/internal/train/load_state",
+         response_model=LoadStateResponse,
+         responses={
+             200: {"description": "State loaded", "model": LoadStateResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_load_state(request: LoadStateRequest) -> LoadStateResponse:
+         """Load training checkpoint."""
+         try:
+             from pathlib import Path
+             import json
+
+             load_path = Path(request.path)
+             if not load_path.exists():
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail=f"Checkpoint not found: {request.path}",
+                 )
+
+             llm_backend.apply_adapter(str(load_path))
+             adapter_state["path"] = str(load_path)
+
+             # Try to load metadata
+             step = training_state["step"]
+             metadata_path = load_path / "adapter_metadata.json"
+             if metadata_path.exists():
+                 with open(metadata_path) as f:
+                     metadata = json.load(f)
+                 step = metadata.get("step", step)
+                 training_state["step"] = step
+                 training_state["learning_rate"] = metadata.get("learning_rate", training_state["learning_rate"])
+
+             return LoadStateResponse(
+                 path=str(load_path),
+                 success=True,
+                 message=f"Checkpoint loaded from {load_path}",
+                 step=step,
+             )
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Failed to load state: {str(e)}",
+             )
+
+     @router.get(
+         "/internal/train/weights",
+         response_model=GetWeightsResponse,
+         responses={
+             200: {"description": "Weights retrieved", "model": GetWeightsResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_get_weights() -> GetWeightsResponse:
+         """Get current model weights."""
+         try:
+             weights = {}
+
+             if hasattr(llm_backend, 'model') and llm_backend.model:
+                 model = llm_backend.model
+                 if hasattr(model, 'trainable_parameters'):
+                     params = model.trainable_parameters()
+                     # Convert to serializable format (shape info)
+                     weights = {
+                         k: {"shape": list(v.shape) if hasattr(v, 'shape') else str(type(v))}
+                         for k, v in params.items()
+                     }
+
+             return GetWeightsResponse(
+                 weights=weights,
+                 success=True,
+                 message=f"Retrieved {len(weights)} weight tensors",
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Failed to get weights: {str(e)}",
+             )
+
+     @router.post(
+         "/internal/train/weights",
+         response_model=SetWeightsResponse,
+         responses={
+             200: {"description": "Weights set", "model": SetWeightsResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Training"],
+     )
+     async def train_set_weights(request: SetWeightsRequest) -> SetWeightsResponse:
+         """Set model weights."""
+         try:
+             # In practice, this would deserialize and set weights
+             # For now, just return success with count
+             num_tensors = len(request.weights)
+
+             return SetWeightsResponse(
+                 success=True,
+                 message=f"Set {num_tensors} weight tensors",
+                 num_tensors=num_tensors,
+             )
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Failed to set weights: {str(e)}",
+             )
+
+     # ==========================================================================
+     # Adapter Hot-Reload
+     # ==========================================================================
+
+     @router.post(
+         "/internal/adapter/reload",
+         response_model=AdapterReloadResponse,
+         responses={
+             200: {"description": "Adapter reloaded successfully", "model": AdapterReloadResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             404: {"description": "Adapter not found", "model": ErrorResponse},
+             500: {"description": "Internal error", "model": ErrorResponse},
+         },
+         tags=["Internal"],
+     )
+     async def reload_adapter(request: AdapterReloadRequest) -> AdapterReloadResponse:
+         """Hot-reload adapter weights without restarting the server.
+
+         Can also reload the base model if needed.
+         """
+         nonlocal adapter_state
+
+         try:
+             target = request.adapter_path
+
+             if target:
+                 target_path = Path(target)
+                 if not target_path.is_absolute():
+                     target = str(Path.cwd() / target_path)
+                 else:
+                     target = str(target_path)
+
+                 if not target_path.exists():
+                     raise HTTPException(
+                         status_code=status.HTTP_404_NOT_FOUND,
+                         detail=f"Adapter path not found: {target}",
+                     )
+
+             if request.reload_base or target is None:
+                 llm_backend.load(
+                     base_model,
+                     max_seq_len=getattr(cfg.model, "max_seq_len", 2048),
+                     dtype=getattr(cfg.model, "dtype", "float16"),
+                     trust_remote_code=getattr(cfg.model, "trust_remote_code", False),
+                 )
+                 adapter_state["path"] = None
+
+             if target:
+                 llm_backend.apply_adapter(target)
+                 adapter_state["path"] = target
+
+             return AdapterReloadResponse(
+                 ok=True,
+                 base_model=base_model,
+                 adapter_path=adapter_state["path"],
+                 message="Adapter reloaded successfully",
+             )
+
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Adapter reload failed: {str(e)}",
+             )
+
+     # ==========================================================================
+     # RLM State and History
+     # ==========================================================================
+
+     @router.get(
+         "/internal/rlm/state",
+         response_model=RLMState,
+         responses={
+             200: {"description": "Current RLM state", "model": RLMState},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+         },
+         tags=["RLM"],
+     )
+     async def rlm_state() -> RLMState:
+         """Get current RLM training state."""
+         state_path = Path.cwd() / "runs" / "rlm_state.json"
+
+         if not state_path.exists():
+             return RLMState(status="idle")
+
+         try:
+             data = json.loads(state_path.read_text(encoding="utf-8"))
+             return RLMState(**data)
+         except Exception:
+             return RLMState(status="idle")
+
+     @router.get(
+         "/internal/rlm/history",
+         response_model=List[RLMHistoryEntry],
+         responses={
+             200: {"description": "RLM training history", "model": List[RLMHistoryEntry]},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+         },
+         tags=["RLM"],
+     )
+     async def rlm_history(
+         limit: Optional[int] = 100,
+         offset: Optional[int] = 0,
+     ) -> List[RLMHistoryEntry]:
+         """Get RLM training history/metrics.
+
+         Args:
+             limit: Maximum number of entries to return
+             offset: Number of entries to skip
+         """
+         history_path = Path.cwd() / "runs" / "rlm_history.jsonl"
+
+         if not history_path.exists():
+             return []
+
+         rows = []
+         try:
+             lines = history_path.read_text(encoding="utf-8").splitlines()
+             for line in lines[offset:offset + limit] if limit else lines[offset:]:
+                 if not line.strip():
+                     continue
+                 try:
+                     data = json.loads(line)
+                     rows.append(RLMHistoryEntry(**data))
+                 except Exception:
+                     continue
+         except Exception:
+             pass
+
+         return rows
+
+     # ==========================================================================
+     # Model Management
+     # ==========================================================================
+
+     def _get_model_format(path: Path) -> str:
+         """Detect model format from path."""
+         if (path / "model.safetensors").exists() or (path / "weights.safetensors").exists():
+             return "mlx"
+         if (path / "pytorch_model.bin").exists() or (path / "model.safetensors").exists():
+             return "hf"
+         if list(path.glob("*.gguf")):
+             return "gguf"
+         return "mlx"  # Default
+
+     def _has_adapter(path: Path) -> bool:
+         """Check if path contains adapter weights."""
+         return (
+             (path / "adapter_config.json").exists() or
+             (path / "adapters.safetensors").exists() or
+             (path / "lora.npz").exists()
+         )
+
+     @router.get(
+         "/internal/models/list",
+         response_model=ModelsListResponse,
+         responses={
+             200: {"description": "List of cached models", "model": ModelsListResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+         },
+         tags=["Models"],
+     )
+     async def list_models() -> ModelsListResponse:
+         """List cached MLX models in the cache directory."""
+         cache_dir = _get_cache_dir()
+         models = []
+
+         mlx_dir = cache_dir / "mlx"
+         if mlx_dir.exists():
+             for model_path in mlx_dir.iterdir():
+                 if not model_path.is_dir():
+                     continue
+
+                 # Calculate size
+                 size_bytes = 0
+                 try:
+                     for f in model_path.rglob("*"):
+                         if f.is_file():
+                             size_bytes += f.stat().st_size
+                 except Exception:
+                     pass
+
+                 # Get metadata
+                 metadata = None
+                 config_path = model_path / "config.json"
+                 if config_path.exists():
+                     try:
+                         metadata = json.loads(config_path.read_text())
+                     except Exception:
+                         pass
+
+                 model_id = model_path.name.replace("__", "/")
+                 has_adapter = _has_adapter(model_path)
+                 adapter_path = str(model_path) if has_adapter else None
+
+                 models.append(ModelInfo(
+                     id=model_id,
+                     path=str(model_path),
+                     size_bytes=size_bytes,
+                     format=_get_model_format(model_path),
+                     has_adapter=has_adapter,
+                     adapter_path=adapter_path,
+                     metadata=metadata,
+                     downloaded_at=int(model_path.stat().st_mtime),
+                 ))
+
+         # Also check HF cache
+         hf_dir = cache_dir / "hf"
+         if hf_dir.exists():
+             for model_path in hf_dir.iterdir():
+                 if not model_path.is_dir():
+                     continue
+
+                 model_id = model_path.name.replace("__", "/")
+                 # Skip if already in MLX format
+                 if any(m.id == model_id for m in models):
+                     continue
+
+                 size_bytes = 0
+                 try:
+                     for f in model_path.rglob("*"):
+                         if f.is_file():
+                             size_bytes += f.stat().st_size
+                 except Exception:
+                     pass
+
+                 models.append(ModelInfo(
+                     id=f"{model_id} (HF)",
+                     path=str(model_path),
+                     size_bytes=size_bytes,
+                     format="hf",
+                     has_adapter=False,
+                     metadata=None,
+                     downloaded_at=int(model_path.stat().st_mtime),
+                 ))
+
+         return ModelsListResponse(
+             models=models,
+             total=len(models),
+             cache_dir=str(cache_dir),
+         )
+
+     @router.post(
+         "/internal/models/pull",
+         response_model=ModelPullResponse,
+         responses={
+             200: {"description": "Model pull initiated", "model": ModelPullResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Pull failed", "model": ErrorResponse},
+         },
+         tags=["Models"],
+     )
+     async def pull_model(request: ModelPullRequest) -> ModelPullResponse:
+         """Pull a model from HuggingFace.
+
+         This proxies the `mlxsmith pull` command and initiates an
+         asynchronous model download and optional conversion.
+
+         Note: This endpoint returns immediately with a status. For large
+         models, use the list endpoint to check completion status.
+         """
+         cache_dir = _get_cache_dir()
+         local_path = cache_dir / "mlx" / request.model_id.replace("/", "__")
+
+         try:
+             # Import here to avoid circular dependencies
+             from ..models import hf_pull
+             from ..config import ProjectConfig
+
+             # Get HF token if available
+             hf_token = None
+             token_path = Path.home() / ".config" / "mlxsmith" / "hf_token"
+             if token_path.exists():
+                 hf_token = token_path.read_text().strip()
+
+             # Start pull in background (synchronous for now)
+             # In production, this would spawn a background task
+             result_path = hf_pull(
+                 model_id=request.model_id,
+                 cache_dir=cache_dir,
+                 convert=request.convert,
+                 quantize=request.quantize,
+                 q_bits=request.q_bits,
+                 q_group_size=request.q_group_size,
+                 trust_remote_code=request.trust_remote_code,
+                 hf_token=hf_token,
+             )
+
+             return ModelPullResponse(
+                 ok=True,
+                 model_id=request.model_id,
+                 local_path=str(result_path),
+                 status=ModelPullStatus(
+                     status="completed",
+                     progress=100.0,
+                     message=f"Model pulled successfully to {result_path}",
+                 ),
+             )
+
+         except Exception as e:
+             raise HTTPException(
+                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                 detail=f"Model pull failed: {str(e)}",
+             )
+
+     # ==========================================================================
+     # HuggingFace Token Management
+     # ==========================================================================
+
+     @router.post(
+         "/internal/hf/token",
+         response_model=HFTokenResponse,
+         responses={
+             200: {"description": "Token stored successfully", "model": HFTokenResponse},
+             400: {"description": "Bad request", "model": ErrorResponse},
+             401: {"description": "Unauthorized", "model": ErrorResponse},
+             500: {"description": "Storage failed", "model": ErrorResponse},
+         },
+         tags=["HF Token"],
+     )
+     async def store_hf_token(request: HFTokenRequest) -> HFTokenResponse:
+         """Store HuggingFace token securely.
+
+         Attempts to use system keyring if available, falls back to
+         file-based storage with restricted permissions.
+         """
+         token = request.token
+         username = None
+         storage_method: str = "memory"
+
+         # Validate token if requested
+         if request.validate_token:
+             try:
+                 from huggingface_hub import HfApi
+                 api = HfApi(token=token)
+                 user_info = api.whoami()
+                 username = user_info.get("name") if user_info else None
+             except Exception as e:
+                 raise HTTPException(
+                     status_code=status.HTTP_400_BAD_REQUEST,
+                     detail=f"Token validation failed: {str(e)}",
+                 )
+
+         # Store token
+         if request.persist:
+             try:
+                 # Try keyring first
+                 try:
+                     import keyring
+                     keyring.set_password("mlxsmith", "huggingface", token)
+                     storage_method = "keyring"
+                 except ImportError:
+                     # Fall back to file storage
+                     config_dir = Path.home() / ".config" / "mlxsmith"
+                     config_dir.mkdir(parents=True, exist_ok=True)
+                     token_path = config_dir / "hf_token"
+                     token_path.write_text(token, encoding="utf-8")
+                     # Restrict permissions (owner read/write only)
+                     os.chmod(token_path, 0o600)
+                     storage_method = "file"
+             except Exception as e:
+                 raise HTTPException(
+                     status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                     detail=f"Failed to store token: {str(e)}",
+                 )
+
+         return HFTokenResponse(
+             ok=True,
+             validated=request.validate_token,
+             username=username,
+             message="Token stored successfully",
+             storage_method=storage_method,
+         )
+
+     return router
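
For orientation, a minimal sketch of how the pieces above might be wired together. create_router(...) returns a plain APIRouter and InternalAuthMiddleware is a standard Starlette middleware, so both can be mounted on any FastAPI app. The llm_backend and cfg objects below are placeholders (the real ones come from mlxsmith.llm and mlxsmith.config, whose constructors are not shown in this diff), and the model id is made up.

from types import SimpleNamespace
from fastapi import FastAPI
from mlxsmith.api.handlers import InternalAuthMiddleware, create_router

# Placeholder backend/config objects (assumptions) -- enough to build the app;
# real deployments would construct these from mlxsmith.llm and mlxsmith.config.
llm_backend = SimpleNamespace(tokenizer=None)
cfg = SimpleNamespace(model=SimpleNamespace(use_chat_template=False))

app = FastAPI()
# /internal/* requires "Authorization: Bearer <token>"; /health and /v1/chat/completions stay public.
app.add_middleware(InternalAuthMiddleware, api_token="change-me")
app.include_router(create_router(llm_backend, base_model="example/model-id", current_adapter=None, cfg=cfg))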
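
A hedged client-side sketch of the two request styles the handlers accept. The base URL is an assumption (the real bind address comes from mlxsmith/server.py), and the exact required fields and defaults are defined in mlxsmith/api/schemas.py, which is not reproduced in this section; the field names below simply mirror the attributes the handlers read.

import requests

BASE = "http://127.0.0.1:8000"   # assumed host/port
TOKEN = "change-me"              # must match MLXSMITH_API_TOKEN / the middleware's api_token

# Public OpenAI-compatible chat completion (no bearer token required by the middleware).
chat = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 64,
        "temperature": 0.2,
        "stream": False,
    },
    timeout=300,
)
print(chat.json()["choices"][0]["message"]["content"])

# Internal rollout with token-level logprobs (bearer token required when a token is configured).
rollout = requests.post(
    f"{BASE}/internal/rollout",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "prompt": "2 + 2 =",
        "max_tokens": 8,
        "temperature": 0.0,
        "include_tokens": True,
        "include_logprobs": True,
        "include_text": True,
    },
    timeout=300,
)
print(rollout.json()["completion"])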