superlinear-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. apps/__init__.py +4 -0
  2. apps/cli/__init__.py +8 -0
  3. apps/cli/bm25_rag.py +471 -0
  4. apps/cli/chat_repl.py +1497 -0
  5. apps/cli/client.py +195 -0
  6. apps/cli/docs_repl.py +2275 -0
  7. apps/cli/light_rag.py +729 -0
  8. apps/cli/local_snapshots.py +139 -0
  9. apps/cli/locks.py +214 -0
  10. apps/cli/main.py +457 -0
  11. apps/cli/output.py +32 -0
  12. apps/cli/server_cmds.py +516 -0
  13. apps/cli/session_cmds.py +491 -0
  14. apps/cli/snapshot_cmds.py +303 -0
  15. apps/cli/state.py +265 -0
  16. apps/server/__init__.py +4 -0
  17. apps/server/app.py +1363 -0
  18. apps/server/main.py +313 -0
  19. superlinear/__init__.py +114 -0
  20. superlinear/_version.py +3 -0
  21. superlinear/engine/__init__.py +10 -0
  22. superlinear/engine/adapters/__init__.py +12 -0
  23. superlinear/engine/adapters/base.py +91 -0
  24. superlinear/engine/adapters/superlinear.py +1233 -0
  25. superlinear/engine/chat_engine.py +1173 -0
  26. superlinear/engine/chat_types.py +130 -0
  27. superlinear/engine/registry.py +51 -0
  28. superlinear/engine/repetition.py +203 -0
  29. superlinear/engine/session_snapshots.py +451 -0
  30. superlinear/engine/tool_parser.py +83 -0
  31. superlinear/engine/types.py +42 -0
  32. superlinear/kernels/__init__.py +2 -0
  33. superlinear/kernels/common/__init__.py +21 -0
  34. superlinear/kernels/common/adjustment.py +106 -0
  35. superlinear/kernels/common/power.py +154 -0
  36. superlinear/kernels/superlinear/__init__.py +10 -0
  37. superlinear/kernels/superlinear/attention/__init__.py +78 -0
  38. superlinear/kernels/superlinear/attention/_prefill.py +940 -0
  39. superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
  40. superlinear/kernels/superlinear/attention/api.py +433 -0
  41. superlinear/kernels/superlinear/search/__init__.py +33 -0
  42. superlinear/kernels/superlinear/search/_reference.py +204 -0
  43. superlinear/kernels/superlinear/search/_triton.py +488 -0
  44. superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
  45. superlinear/kernels/superlinear/search/api.py +200 -0
  46. superlinear/kernels/superlinear/span/__init__.py +41 -0
  47. superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
  48. superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
  49. superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
  50. superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
  51. superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
  52. superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
  53. superlinear/kernels/superlinear/span/api.py +296 -0
  54. superlinear/kernels/superlinear/span/masks.py +187 -0
  55. superlinear/py.typed +0 -0
  56. superlinear/runtime.py +71 -0
  57. superlinear-0.1.0.dist-info/METADATA +469 -0
  58. superlinear-0.1.0.dist-info/RECORD +62 -0
  59. superlinear-0.1.0.dist-info/WHEEL +5 -0
  60. superlinear-0.1.0.dist-info/entry_points.txt +2 -0
  61. superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
  62. superlinear-0.1.0.dist-info/top_level.txt +2 -0
apps/server/app.py ADDED
@@ -0,0 +1,1363 @@
+ """FastAPI app for OpenAI-style Chat Completions.
+
+ The HTTP layer lives under `apps/` and can depend on heavier deps (FastAPI, uvicorn).
+ All model execution is delegated to the core engine (`superlinear/engine`).
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import os
+ import threading
+ import time
+ import uuid
+ from dataclasses import dataclass, field
+ from typing import Any, AsyncIterator
+
+ from fastapi import FastAPI, HTTPException, Request
+ from starlette.responses import JSONResponse, StreamingResponse
+
+ from superlinear.engine.chat_engine import ChatEngine
+ from superlinear.engine.chat_types import ChatMessage, ChatRequest, StreamOptions, Timing, ToolCall, Usage
+ from superlinear.engine.session_snapshots import (
+     SnapshotCompatibilityError,
+     SnapshotStoreV1,
+     compute_model_compatibility,
+     export_hybrid_mamba_attention_static_cache,
+     import_hybrid_mamba_attention_static_cache,
+ )
+
+
+ def create_app(
+     *,
+     engine: ChatEngine,
+     model_id: str,
+     http_max_concurrency: int | None = None,
+     http_max_completion_tokens: int | None = None,
+ ) -> FastAPI:
+     app = FastAPI(title="Superlinear Inference Server", version="0.1.0")
+
+     default_max_seq_len = 131_072
+
+     @dataclass
+     class _HttpSession:
+         max_seq_len: int
+         messages: list[dict[str, Any]] = field(default_factory=list)
+
+     _sessions_lock = threading.Lock()
+     _sessions: dict[str, _HttpSession] = {}
+
+     _engine_lock = getattr(engine, "_lock", threading.Lock())
+
+     _snapshot_store_lock = threading.Lock()
+     _snapshot_store: SnapshotStoreV1 | None = None
+
+     http_semaphore: asyncio.Semaphore | None = None
+     if http_max_concurrency is not None:
+         try:
+             http_max_concurrency = int(http_max_concurrency)
+         except Exception as exc:
+             raise ValueError("http_max_concurrency must be an integer") from exc
+         if http_max_concurrency > 0:
+             http_semaphore = asyncio.Semaphore(http_max_concurrency)
+         elif http_max_concurrency < 0:
+             raise ValueError("http_max_concurrency must be >= 0")
+
+     if http_max_completion_tokens is not None:
+         try:
+             http_max_completion_tokens = int(http_max_completion_tokens)
+         except Exception as exc:
+             raise ValueError("http_max_completion_tokens must be an integer") from exc
+         if http_max_completion_tokens <= 0:
+             raise ValueError("http_max_completion_tokens must be > 0")
+
+     async def _wait_for_disconnect(request: Request, poll_s: float = 0.1) -> None:
+         while True:
+             if await request.is_disconnected():
+                 return
+             await asyncio.sleep(poll_s)
+
+     async def _run_with_disconnect_cancellation(request: Request, coro: Any) -> Any:
+         task = asyncio.create_task(coro)
+         disconnect_task = asyncio.create_task(_wait_for_disconnect(request))
+         # Yield to let both tasks start (handles coroutines that return synchronously).
+         await asyncio.sleep(0)
+         done, pending = await asyncio.wait(
+             {task, disconnect_task},
+             return_when=asyncio.FIRST_COMPLETED,
+         )
+         if disconnect_task in done:
+             task.cancel()
+             try:
+                 await task
+             except asyncio.CancelledError:
+                 pass
+             raise HTTPException(status_code=499, detail="Client disconnected")
+
+         disconnect_task.cancel()
+         try:
+             await disconnect_task
+         except asyncio.CancelledError:
+             pass
+         return task.result()
+
+     async def _try_acquire_semaphore() -> None:
+         if http_semaphore is None:
+             return
+         try:
+             await asyncio.wait_for(http_semaphore.acquire(), timeout=0.001)
+         except TimeoutError as exc:
+             raise HTTPException(status_code=429, detail="Server is busy") from exc
+
+     def _get_snapshot_store() -> SnapshotStoreV1:
+         nonlocal _snapshot_store
+         with _snapshot_store_lock:
+             if _snapshot_store is not None:
+                 return _snapshot_store
+
+             adapter = getattr(engine, "adapter", None)
+             if adapter is None:
+                 raise HTTPException(status_code=500, detail="Engine does not expose an adapter.")
+
+             xdg_cache = os.environ.get("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
+             default_snapshot_dir = os.path.join(xdg_cache, "spl", "snapshots")
+             root_dir = os.environ.get("SUPERLINEAR_SNAPSHOT_DIR", default_snapshot_dir)
+             compat = compute_model_compatibility(adapter=adapter, model_id=model_id)
+             _snapshot_store = SnapshotStoreV1(root_dir=root_dir, model_id=model_id, compat=compat)
+             return _snapshot_store
+
+     async def _json_dict_or_empty(request: Request) -> dict[str, Any]:
+         try:
+             payload = await request.json()
+         except Exception:
+             return {}
+         if isinstance(payload, dict):
+             return payload
+         raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
+
+     # -------------------------------------------------------------------------
+     # Health & Models
+     # -------------------------------------------------------------------------
+
+     @app.get("/health")
+     async def health() -> dict[str, str]:
+         return {"status": "ok"}
+
+     @app.get("/v1/models")
+     async def list_models() -> dict[str, Any]:
+         now = int(time.time())
+         return {
+             "object": "list",
+             "data": [
+                 {
+                     "id": model_id,
+                     "object": "model",
+                     "created": now,
+                     "owned_by": "superlinear",
+                 }
+             ],
+         }
+
+     # -------------------------------------------------------------------------
+     # Session Management
+     # -------------------------------------------------------------------------
+
+     @app.post("/v1/sessions")
+     async def create_session(request: Request) -> Any:
+         """Create a new stateful session for multi-turn conversations."""
+         payload = await request.json()
+         session_id = payload.get("session_id")
+         if not session_id or not isinstance(session_id, str):
+             raise HTTPException(status_code=400, detail="'session_id' is required and must be a string.")
+
+         max_seq_len = payload.get("max_seq_len", default_max_seq_len)
+         try:
+             max_seq_len = int(max_seq_len)
+         except (ValueError, TypeError) as exc:
+             raise HTTPException(status_code=400, detail="'max_seq_len' must be an integer.") from exc
+
+         try:
+             with _engine_lock:
+                 engine.adapter.create_session(
+                     cache_id=session_id,
+                     max_seq_len=max_seq_len,
+                 )
+         except ValueError as exc:
+             raise HTTPException(status_code=409, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         with _sessions_lock:
+             _sessions[session_id] = _HttpSession(max_seq_len=max_seq_len)
+
+         return JSONResponse({"status": "created", "session_id": session_id})
+
+     @app.get("/v1/sessions")
+     async def list_sessions() -> Any:
+         """List all active sessions."""
+         with _engine_lock:
+             sessions = engine.adapter.list_sessions()
+         return JSONResponse({"sessions": sessions})
+
+     @app.get("/v1/sessions/{session_id}")
+     async def get_session_info(session_id: str) -> Any:
+         """Get information about a specific session."""
+         try:
+             with _engine_lock:
+                 info = engine.adapter.get_session_info(session_id)
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         # Convenience aliases for client code/tests.
+         info = dict(info)
+         info["cache_position"] = info.get("current_pos")
+         with _sessions_lock:
+             meta = _sessions.get(session_id)
+             if meta is not None:
+                 info["message_count"] = len(meta.messages)
+         return JSONResponse(info)
+
+     @app.delete("/v1/sessions/{session_id}")
+     async def close_session(session_id: str) -> Any:
+         """Close a session and free its resources."""
+         # Check if the session exists.
+         try:
+             with _engine_lock:
+                 engine.adapter.get_session_info(session_id)
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+         with _engine_lock:
+             engine.adapter.close_session(session_id)
+         with _sessions_lock:
+             _sessions.pop(session_id, None)
+         return JSONResponse({"status": "closed", "session_id": session_id})
+
+     @app.get("/v1/sessions/{session_id}/history")
+     async def get_session_history(session_id: str) -> Any:
+         """Get the stored chat history for a session."""
+         try:
+             with _engine_lock:
+                 engine.adapter.get_session_info(session_id)
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+         with _sessions_lock:
+             meta = _sessions.get(session_id)
+             if meta is None:
+                 return JSONResponse({"session_id": session_id, "messages": []})
+             return JSONResponse({"session_id": session_id, "messages": meta.messages})
+
+     @app.post("/v1/sessions/{session_id}/rollback")
+     async def rollback_session(session_id: str, request: Request) -> Any:
+         """Rollback a session to an earlier message index by replaying history.
+
+         Body:
+           - keep_messages: int (number of messages to keep from the start)
+         """
+         payload = await request.json()
+         keep_messages = payload.get("keep_messages")
+         try:
+             keep_messages = int(keep_messages)
+         except Exception as exc:
+             raise HTTPException(status_code=400, detail="'keep_messages' must be an integer.") from exc
+         if keep_messages < 0:
+             raise HTTPException(status_code=400, detail="'keep_messages' must be >= 0.")
+
+         # Ensure the session exists before touching stored history.
+         try:
+             with _engine_lock:
+                 engine.adapter.get_session_info(session_id)
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+         with _sessions_lock:
+             meta = _sessions.get(session_id)
+             if meta is None:
+                 raise HTTPException(status_code=404, detail=f"No stored history for session: {session_id}")
+             meta.messages = meta.messages[:keep_messages]
+             max_seq_len = meta.max_seq_len
+             history_msgs = list(meta.messages)
+
+         # Recreate the adapter session and replay the history prompt.
+         with _engine_lock:
+             engine.adapter.close_session(session_id)
+             engine.adapter.create_session(cache_id=session_id, max_seq_len=max_seq_len)
+
+         if history_msgs:
+             # Build prompt from stored history WITHOUT adding a generation prompt.
+             # The next /chat/completions call will add user msg + generation prompt.
+             chat_req = _parse_chat_request({"messages": history_msgs, "max_tokens": 1})
+             tokenizer = getattr(engine.adapter, "tokenizer", None)
+             apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
+             if not callable(apply_chat_template):
+                 raise HTTPException(status_code=500, detail="Tokenizer does not support apply_chat_template().")
+
+             injected = engine._inject_tool_choice(list(chat_req.messages), chat_req.tool_choice)  # type: ignore[attr-defined]
+             template_messages = engine._messages_for_template(injected)  # type: ignore[attr-defined]
+             kwargs = engine._chat_template_kwargs(  # type: ignore[attr-defined]
+                 chat_req, enable_thinking=engine._effective_enable_thinking(chat_req)  # type: ignore[attr-defined]
+             )
+             input_ids = apply_chat_template(
+                 template_messages,
+                 add_generation_prompt=False,
+                 **kwargs,
+             )
+             with _engine_lock:
+                 engine.adapter.append_to_session(cache_id=session_id, input_ids=input_ids)
+
+         with _engine_lock:
+             new_info = engine.adapter.get_session_info(session_id)
+         new_info = dict(new_info)
+         new_info["cache_position"] = new_info.get("current_pos")
+         new_info["message_count"] = keep_messages
+         return JSONResponse({"status": "ok", "session_id": session_id, "session": new_info})
+
+     def _parse_resize_strategy(payload: dict[str, Any]) -> str:
+         strategy = payload.get("strategy", "auto")
+         if strategy is None:
+             strategy = "auto"
+         if not isinstance(strategy, str):
+             raise HTTPException(status_code=400, detail="'strategy' must be a string.")
+         strategy = strategy.lower().strip()
+         if strategy not in {"auto", "gpu", "disk"}:
+             raise HTTPException(status_code=400, detail="'strategy' must be one of: auto, gpu, disk.")
+         return strategy
+
+     def _next_pow2_strictly_greater(n: int) -> int:
+         if n <= 0:
+             return 1
+         p = 1 << ((n - 1).bit_length())
+         if p == n:
+             p *= 2
+         return p
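+
+     # Worked examples for _next_pow2_strictly_greater (annotation added in
+     # review, not in the original source): 1 -> 2, 5 -> 8, 8 -> 16,
+     # 131072 -> 262144. The result is always strictly greater than n, which
+     # /resize/next_pow2 below relies on to guarantee the cache actually grows.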
+
+     def _resize_session_to(*, session_id: str, target_max_seq_len: int, strategy: str) -> dict[str, Any]:
+         if target_max_seq_len <= 0:
+             raise HTTPException(status_code=400, detail="'max_seq_len' must be > 0.")
+
+         def _allocate_and_restore(*, close_first: bool) -> dict[str, Any]:
+             # Everything here runs under _engine_lock.
+             exported = engine.adapter.export_session(session_id)
+             current_pos = int(exported.get("current_pos") or 0)
+             old_max = int(exported.get("max_seq_len") or 0)
+
+             if current_pos < 0:
+                 raise HTTPException(status_code=500, detail="Invalid session current_pos.")
+             if old_max <= 0:
+                 raise HTTPException(status_code=500, detail="Invalid session max_seq_len.")
+             if target_max_seq_len < current_pos:
+                 raise HTTPException(
+                     status_code=400,
+                     detail=f"'max_seq_len' ({target_max_seq_len}) must be >= current_pos ({current_pos}).",
+                 )
+             if target_max_seq_len == old_max:
+                 info = engine.adapter.get_session_info(session_id)
+                 info = dict(info)
+                 info["cache_position"] = info.get("current_pos")
+                 return {"status": "noop", "session_id": session_id, "session": info}
+
+             cache_payload = export_hybrid_mamba_attention_static_cache(
+                 cache=exported["past_key_values"],
+                 current_pos=current_pos,
+             )
+
+             model = getattr(engine.adapter, "model", None)
+             if model is None or not hasattr(model, "create_static_cache"):
+                 raise HTTPException(status_code=500, detail="Adapter does not expose create_static_cache().")
+
+             if close_first:
+                 # Free the old cache to reduce peak VRAM.
+                 engine.adapter.close_session(session_id)
+
+             past_key_values = model.create_static_cache(batch_size=1, max_seq_len=target_max_seq_len)
+             restored_pos = import_hybrid_mamba_attention_static_cache(cache=past_key_values, payload=cache_payload)
+             engine.adapter.restore_session(
+                 cache_id=session_id,
+                 past_key_values=past_key_values,
+                 current_pos=restored_pos,
+                 max_seq_len=target_max_seq_len,
+                 next_token_logits=None,
+                 overwrite=True,
+             )
+
+             info = engine.adapter.get_session_info(session_id)
+             info = dict(info)
+             info["cache_position"] = info.get("current_pos")
+             return {
+                 "status": "resized",
+                 "session_id": session_id,
+                 "old_max_seq_len": old_max,
+                 "max_seq_len": target_max_seq_len,
+                 "current_pos": restored_pos,
+                 "session": info,
+             }
+
+         try:
+             with _engine_lock:
+                 if strategy == "gpu":
+                     result = _allocate_and_restore(close_first=False)
+                 elif strategy == "disk":
+                     result = _allocate_and_restore(close_first=True)
+                 else:  # auto
+                     try:
+                         result = _allocate_and_restore(close_first=False)
+                     except Exception as exc:
+                         # Best-effort fallback on CUDA OOM: free the old cache first, then retry.
+                         import torch  # type: ignore
+
+                         if isinstance(exc, torch.cuda.OutOfMemoryError):
+                             result = _allocate_and_restore(close_first=True)
+                         else:
+                             raise
+         except HTTPException:
+             raise
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         with _sessions_lock:
+             meta = _sessions.get(session_id)
+             if meta is not None:
+                 meta.max_seq_len = int(target_max_seq_len)
+                 if isinstance(result, dict) and isinstance(result.get("session"), dict):
+                     result["session"]["message_count"] = len(meta.messages)
+
+         return result
+
+     @app.post("/v1/sessions/{session_id}/resize")
+     async def resize_session(session_id: str, request: Request) -> Any:
+         """Resize a session KV cache to a new max sequence length.
+
+         Body:
+           - max_seq_len: int (required)
+           - strategy: "auto" | "gpu" | "disk" (optional; default: "auto")
+
+         Strategies:
+           - gpu: allocate new cache while old cache is still resident (higher peak VRAM)
+           - disk: free old cache before allocating new cache (lower peak VRAM)
+           - auto: try gpu, fall back to disk on OOM
+         """
+
+         payload = await _json_dict_or_empty(request)
+         raw_max = payload.get("max_seq_len")
+         if raw_max is None:
+             raise HTTPException(status_code=400, detail="'max_seq_len' is required.")
+         try:
+             new_max_seq_len = int(raw_max)
+         except Exception as exc:
+             raise HTTPException(status_code=400, detail="'max_seq_len' must be an integer.") from exc
+
+         strategy = _parse_resize_strategy(payload)
+         return JSONResponse(
+             _resize_session_to(session_id=session_id, target_max_seq_len=new_max_seq_len, strategy=strategy)
+         )
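+
+     # Illustrative request (a sketch added in review, not in the original source):
+     #
+     #   POST /v1/sessions/my-session/resize
+     #   {"max_seq_len": 262144, "strategy": "auto"}
+     #
+     # With "auto", the server first tries the higher-VRAM in-place path and only
+     # frees the old cache before allocating if CUDA reports out-of-memory.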
+
+     @app.post("/v1/sessions/{session_id}/resize/next_pow2")
+     async def resize_session_next_pow2(session_id: str, request: Request) -> Any:
+         """Resize a session KV cache to the next power-of-two max sequence length.
+
+         Body (optional):
+           - strategy: "auto" | "gpu" | "disk" (default: "auto")
+
+         Example:
+           - 131072 -> 262144
+           - 262144 -> 524288
+         """
+
+         payload = await _json_dict_or_empty(request)
+         strategy = _parse_resize_strategy(payload)
+
+         try:
+             with _engine_lock:
+                 info = engine.adapter.get_session_info(session_id)
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         old_max = int(info.get("max_seq_len") or 0)
+         new_max = _next_pow2_strictly_greater(old_max)
+
+         result = _resize_session_to(session_id=session_id, target_max_seq_len=new_max, strategy=strategy)
+         if isinstance(result, dict):
+             result = dict(result)
+             result["mode"] = "next_pow2"
+         return JSONResponse(result)
+
+     # -------------------------------------------------------------------------
+     # Snapshot Management (v1)
+     # -------------------------------------------------------------------------
+
+     @app.post("/v1/sessions/{session_id}/save")
+     async def save_session_snapshot(session_id: str, request: Request) -> Any:
+         """Save a session to an immutable on-disk snapshot."""
+         payload = await _json_dict_or_empty(request)
+
+         title = payload.get("title")
+         description = payload.get("description")
+         tags = payload.get("tags")
+         if tags is not None and not isinstance(tags, list):
+             raise HTTPException(status_code=400, detail="'tags' must be a list of strings.")
+
+         transcript: list[dict[str, Any]] = []
+         try:
+             with _engine_lock:
+                 exported = engine.adapter.export_session(session_id)
+                 with _sessions_lock:
+                     meta = _sessions.get(session_id)
+                     transcript = list(meta.messages) if meta is not None else []
+                 cache_payload = export_hybrid_mamba_attention_static_cache(
+                     cache=exported["past_key_values"],
+                     current_pos=int(exported["current_pos"]),
+                 )
+         except KeyError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         store = _get_snapshot_store()
+         try:
+             manifest = store.create_snapshot(
+                 transcript=transcript,
+                 cache_payload=cache_payload,
+                 session={"max_seq_len": int(exported["max_seq_len"]), "current_pos": int(exported["current_pos"])},
+                 title=title if isinstance(title, str) else None,
+                 description=description if isinstance(description, str) else None,
+                 tags=[str(t) for t in tags] if isinstance(tags, list) else None,
+             )
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         return JSONResponse({"status": "saved", "snapshot_id": manifest.snapshot_id})
+
+     @app.get("/v1/snapshots")
+     async def list_snapshots() -> Any:
+         store = _get_snapshot_store()
+         snaps = [m.to_dict() for m in store.list_snapshots()]
+         return JSONResponse({"snapshots": snaps})
+
+     @app.get("/v1/snapshots/{snapshot_id}")
+     async def get_snapshot(snapshot_id: str) -> Any:
+         store = _get_snapshot_store()
+         try:
+             manifest = store.get_manifest(snapshot_id)
+         except FileNotFoundError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         return JSONResponse(manifest.to_dict())
+
+     @app.patch("/v1/snapshots/{snapshot_id}")
+     async def patch_snapshot(snapshot_id: str, request: Request) -> Any:
+         payload = await request.json()
+         if not isinstance(payload, dict):
+             raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
+
+         title = payload.get("title")
+         description = payload.get("description")
+         tags = payload.get("tags")
+         if tags is not None and not isinstance(tags, list):
+             raise HTTPException(status_code=400, detail="'tags' must be a list of strings.")
+
+         store = _get_snapshot_store()
+         try:
+             updated = store.patch_metadata(
+                 snapshot_id,
+                 title=title if isinstance(title, str) else None,
+                 description=description if isinstance(description, str) else None,
+                 tags=[str(t) for t in tags] if isinstance(tags, list) else None,
+             )
+         except FileNotFoundError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         return JSONResponse(updated.to_dict())
+
+     @app.delete("/v1/snapshots/{snapshot_id}")
+     async def delete_snapshot(snapshot_id: str) -> Any:
+         store = _get_snapshot_store()
+         try:
+             store.delete_snapshot(snapshot_id)
+         except FileNotFoundError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         return JSONResponse({"status": "deleted", "snapshot_id": snapshot_id})
+
+     @app.post("/v1/snapshots/{snapshot_id}/load")
+     async def load_snapshot(snapshot_id: str, request: Request) -> Any:
+         payload = await _json_dict_or_empty(request)
+
+         target_session_id = payload.get("session_id")
+         if target_session_id is not None and not isinstance(target_session_id, str):
+             raise HTTPException(status_code=400, detail="'session_id' must be a string.")
+         force = bool(payload.get("force", False))
+         if not target_session_id:
+             target_session_id = f"sess_{uuid.uuid4().hex}"
+
+         store = _get_snapshot_store()
+         try:
+             manifest, transcript, cache_payload = store.load_snapshot_payload(snapshot_id)
+         except SnapshotCompatibilityError as exc:
+             raise HTTPException(status_code=409, detail=str(exc)) from exc
+         except FileNotFoundError as exc:
+             raise HTTPException(status_code=404, detail=str(exc)) from exc
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         max_seq_len = int(manifest.session.get("max_seq_len") or default_max_seq_len)
+         expected_pos = int(manifest.session.get("current_pos") or 0)
+         restored_pos = expected_pos
+
+         try:
+             with _engine_lock:
+                 # Avoid accidental overwrite unless explicitly forced.
+                 try:
+                     engine.adapter.get_session_info(target_session_id)
+                     if not force:
+                         raise HTTPException(
+                             status_code=409,
+                             detail=f"Session already exists: {target_session_id} (use force=true to overwrite).",
+                         )
+                     engine.adapter.close_session(target_session_id)
+                 except KeyError:
+                     pass
+
+                 model = getattr(engine.adapter, "model", None)
+                 if model is None or not hasattr(model, "create_static_cache"):
+                     raise HTTPException(status_code=500, detail="Adapter does not expose create_static_cache().")
+
+                 past_key_values = model.create_static_cache(batch_size=1, max_seq_len=max_seq_len)
+                 restored_pos = import_hybrid_mamba_attention_static_cache(
+                     cache=past_key_values, payload=cache_payload
+                 )
+                 engine.adapter.restore_session(
+                     cache_id=target_session_id,
+                     past_key_values=past_key_values,
+                     current_pos=restored_pos,
+                     max_seq_len=max_seq_len,
+                     next_token_logits=None,
+                     overwrite=False,
+                 )
+         except HTTPException:
+             raise
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+
+         with _sessions_lock:
+             _sessions[target_session_id] = _HttpSession(max_seq_len=max_seq_len, messages=transcript)
+
+         return JSONResponse(
+             {
+                 "status": "loaded",
+                 "snapshot_id": snapshot_id,
+                 "session_id": target_session_id,
+                 "session": {
+                     "current_pos": restored_pos,
+                     "max_seq_len": max_seq_len,
+                     "message_count": len(transcript),
+                 },
+             }
+         )
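+
+     # Illustrative save/load round trip (a sketch added in review, not in the
+     # original source):
+     #
+     #   POST /v1/sessions/my-session/save        {"title": "before refactor"}
+     #   GET  /v1/snapshots                       -> {"snapshots": [{"snapshot_id": ...}]}
+     #   POST /v1/snapshots/<snapshot_id>/load    {"session_id": "restored", "force": false}
+     #
+     # Loading rebuilds the static KV cache from the on-disk payload and registers
+     # the transcript, so /v1/chat/completions can continue the conversation.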
+
+     # -------------------------------------------------------------------------
+     # Chat Completions
+     # -------------------------------------------------------------------------
+
+     @app.post("/v1/chat/completions")
+     async def chat_completions(request: Request) -> Any:
+         payload = await request.json()
+
+         req_model = payload.get("model")
+         if req_model is not None and req_model != model_id:
+             raise HTTPException(status_code=404, detail=f"Unknown model: {req_model}")
+
+         chat_req = _parse_chat_request(payload, http_max_completion_tokens=http_max_completion_tokens)
+
+         # Session chat: maintain server-side message history and only append delta tokens.
+         if chat_req.session_id:
+             session_id = chat_req.session_id
+             # Ensure the session exists.
+             try:
+                 with _engine_lock:
+                     sess_info = engine.adapter.get_session_info(session_id)
+             except KeyError as exc:
+                 raise HTTPException(status_code=404, detail=str(exc)) from exc
+
+             try:
+                 current_pos = int(sess_info.get("current_pos", 0) or 0)
+             except Exception:
+                 current_pos = 0
+
+             with _sessions_lock:
+                 meta = _sessions.get(session_id)
+                 if meta is None:
+                     # Session exists in adapter but not in HTTP store (e.g., server restarted).
+                     meta = _HttpSession(max_seq_len=int(sess_info.get("max_seq_len", default_max_seq_len)))
+                     _sessions[session_id] = meta
+
+                 # Safety: a non-empty KV cache with an empty HTTP transcript means the server cannot
+                 # correctly compute delta tokens to append. Proceeding would cause the model to ignore
+                 # new user input or append mismatched tokens.
+                 if current_pos > 0 and not meta.messages:
+                     raise HTTPException(
+                         status_code=409,
+                         detail=(
+                             "Session KV cache is non-empty but HTTP transcript is empty. "
+                             "This indicates a corrupted/incomplete session state. "
+                             "Start a new session or restore from a snapshot."
+                         ),
+                     )
+
+                 # Treat incoming messages as a delta and append them to stored history.
+                 incoming_raw = payload.get("messages")
+                 if isinstance(incoming_raw, list):
+                     incoming_msgs = [m for m in incoming_raw if isinstance(m, dict)]
+
+                     # Keep at most one leading system message in the stored transcript.
+                     # The CLI may send the same system prompt every turn; accumulating
+                     # duplicates wastes context and can degrade multi-turn coherence.
+                     if incoming_msgs and incoming_msgs[0].get("role") == "system":
+                         if meta.messages and isinstance(meta.messages[0], dict) and meta.messages[0].get("role") == "system":
+                             # Never mutate the prompt prefix once the KV cache has advanced.
+                             # Even a small change to the system message would invalidate the cached
+                             # token stream and corrupt append-from slicing.
+                             if current_pos <= 0:
+                                 meta.messages[0] = incoming_msgs[0]
+                             # Drop the incoming system message to avoid duplicates.
+                             incoming_msgs = incoming_msgs[1:]
+
+                     meta.messages.extend(incoming_msgs)
+
+                 full_messages = list(meta.messages)
+
+             # Build a new ChatRequest with the full history messages.
+             # The append-from position is the current KV cache position.
+             chat_req = ChatRequest(
+                 messages=_parse_chat_request({"messages": full_messages}).messages,
+                 tools=chat_req.tools,
+                 tool_choice=chat_req.tool_choice,
+                 max_tokens=chat_req.max_tokens,
+                 temperature=chat_req.temperature,
+                 top_p=chat_req.top_p,
+                 stop=chat_req.stop,
+                 stream=chat_req.stream,
+                 stream_options=chat_req.stream_options,
+                 chat_template_kwargs=chat_req.chat_template_kwargs,
+                 reasoning_budget=chat_req.reasoning_budget,
+                 discard_thinking=chat_req.discard_thinking,
+                 stream_thinking=chat_req.stream_thinking,
+                 session_id=session_id,
+                 session_append_from_pos=current_pos,
+                 extra=chat_req.extra,
+             )
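+
+             # Stateful-session wire protocol, illustrated (comment added in
+             # review; values are hypothetical):
+             #
+             #   POST /v1/chat/completions
+             #   {"session_id": "my-session",
+             #    "messages": [{"role": "user", "content": "and in Rust?"}]}
+             #
+             # The client sends only the new turn; the server prepends the stored
+             # transcript and tells the engine to append from `current_pos` so the
+             # prefix already in the KV cache is not re-prefilled.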
+
+         created = int(time.time())
+         chatcmpl_id = f"chatcmpl-{uuid.uuid4().hex}"
+
+         if chat_req.stream:
+             await _try_acquire_semaphore()
+             event_iter = _stream_chat_completions(
+                 engine=engine,
+                 chat_request=chat_req,
+                 model_id=model_id,
+                 created=created,
+                 chatcmpl_id=chatcmpl_id,
+                 sessions=_sessions,
+                 sessions_lock=_sessions_lock,
+                 request=request,
+                 http_semaphore=http_semaphore,
+                 semaphore_already_acquired=True,
+             )
+             return StreamingResponse(event_iter, media_type="text/event-stream")
+
+         await _try_acquire_semaphore()
+         try:
+             result = await _run_with_disconnect_cancellation(request, engine.generate_chat(chat_req))
+         except HTTPException:
+             raise
+         except ValueError as exc:
+             raise HTTPException(status_code=400, detail=str(exc)) from exc
+         except Exception as exc:
+             raise HTTPException(status_code=500, detail=str(exc)) from exc
+         finally:
+             if http_semaphore is not None:
+                 http_semaphore.release()
+
+         usage = result.get("usage")
+         finish_reason = result.get("finish_reason") or "stop"
+         content = result.get("content")
+         tool_calls = result.get("tool_calls") or []
+         raw_content = result.get("raw_content")
+
+         message: dict[str, Any] = {"role": "assistant", "content": content}
+         if tool_calls:
+             message["content"] = None
+             message["tool_calls"] = [_openai_tool_call(tc) for tc in tool_calls]
+
+         # Persist assistant message to session history (non-streaming).
+         if chat_req.session_id:
+             # When discard_thinking=True, persist only the stripped content (no <think> blocks).
+             # Otherwise, persist raw_content (with thinking) if available.
+             if chat_req.discard_thinking:
+                 history_content = content if content is not None else ""
+             else:
+                 if raw_content is not None:
+                     history_content = raw_content
+                 else:
+                     history_content = content if content is not None else ""
+             history_msg: dict[str, Any] = {"role": "assistant", "content": history_content}
+             if tool_calls:
+                 history_msg["content"] = None
+                 history_msg["tool_calls"] = [_openai_tool_call(tc) for tc in tool_calls]
+
+             # Persist empty-string assistant messages too (never null).
+             if tool_calls or isinstance(history_content, str):
+                 with _sessions_lock:
+                     meta = _sessions.get(chat_req.session_id)
+                     if meta is not None:
+                         meta.messages.append(history_msg)
+
+         resp: dict[str, Any] = {
+             "id": chatcmpl_id,
+             "object": "chat.completion",
+             "created": created,
+             "model": model_id,
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": message,
+                     "finish_reason": finish_reason,
+                 }
+             ],
+         }
+
+         if isinstance(usage, Usage):
+             resp["usage"] = {
+                 "prompt_tokens": usage.prompt_tokens,
+                 "completion_tokens": usage.completion_tokens,
+                 "total_tokens": usage.total_tokens,
+             }
+
+         return JSONResponse(resp)
+
+     return app
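+
+
+ # Usage sketch (illustrative; not part of this module). The real entry point is
+ # apps/server/main.py, which is not reproduced here, but the wiring is roughly:
+ #
+ #   engine = ...  # construct a ChatEngine over the model adapter
+ #   app = create_app(engine=engine, model_id="superlinear", http_max_concurrency=1)
+ #   uvicorn.run(app, host="127.0.0.1", port=8000)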
+
+
+ def _sse(data: str) -> str:
+     return f"data: {data}\n\n"
+
+
+ def _openai_tool_call(tc: ToolCall) -> dict[str, Any]:
+     return {
+         "id": tc.id,
+         "type": "function",
+         "function": {"name": tc.name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
+     }
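+
+
+ # Shape note (added in review; values hypothetical): a ToolCall named
+ # "get_weather" with arguments {"city": "Paris"} serializes to
+ #
+ #   {"id": "call_...", "type": "function",
+ #    "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}}
+ #
+ # matching the OpenAI Chat Completions tool_calls schema (arguments as a JSON
+ # string, not an object).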
+
+
+ def _openai_stream_tool_call(tc: ToolCall, *, index: int) -> dict[str, Any]:
+     d = _openai_tool_call(tc)
+     d["index"] = index
+     return d
+
+
+ async def _stream_chat_completions(
+     *,
+     engine: ChatEngine,
+     chat_request: ChatRequest,
+     model_id: str,
+     created: int,
+     chatcmpl_id: str,
+     sessions: dict[str, Any],
+     sessions_lock: threading.Lock,
+     request: Request,
+     http_semaphore: asyncio.Semaphore | None,
+     semaphore_already_acquired: bool = False,
+ ) -> AsyncIterator[str]:
+     if http_semaphore is not None and not semaphore_already_acquired:
+         try:
+             await asyncio.wait_for(http_semaphore.acquire(), timeout=0.001)
+         except TimeoutError as exc:
+             raise HTTPException(status_code=429, detail="Server is busy") from exc
+
+     # Initial chunk announces the role.
+     yield _sse(
+         json.dumps(
+             {
+                 "id": chatcmpl_id,
+                 "object": "chat.completion.chunk",
+                 "created": created,
+                 "model": model_id,
+                 "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+             },
+             ensure_ascii=False,
+         )
+     )
+
+     final_finish_reason: str | None = None
+     final_usage: Usage | None = None
+     final_timing: Timing | None = None
+     assistant_text_parts: list[str] = []
+     assistant_tool_calls: list[ToolCall] = []
+     raw_content_for_history: str | None = None  # Raw content with thinking for discard_thinking=False
+     cancelled = False
+
+     from superlinear.engine.chat_types import DeltaEvent, ErrorEvent, FinalEvent, ThinkingDeltaEvent, ToolCallEvent
+
+     try:
+         async for event in engine.astream_chat(chat_request):
+             # If the client disconnects mid-stream, stop consuming promptly.
+             # The engine stream is cancelled/closed when this generator unwinds.
+             if await request.is_disconnected():
+                 cancelled = True
+                 final_finish_reason = "cancelled"
+                 break
+
+             if isinstance(event, DeltaEvent):
+                 if not event.text:
+                     continue
+                 assistant_text_parts.append(event.text)
+                 yield _sse(
+                     json.dumps(
+                         {
+                             "id": chatcmpl_id,
+                             "object": "chat.completion.chunk",
+                             "created": created,
+                             "model": model_id,
+                             "choices": [
+                                 {
+                                     "index": 0,
+                                     "delta": {"content": event.text},
+                                     "finish_reason": None,
+                                 }
+                             ],
+                         },
+                         ensure_ascii=False,
+                     )
+                 )
+                 continue
+
+             if isinstance(event, ThinkingDeltaEvent):
+                 if not event.text:
+                     continue
+                 yield _sse(
+                     json.dumps(
+                         {
+                             "id": chatcmpl_id,
+                             "object": "chat.completion.chunk",
+                             "created": created,
+                             "model": model_id,
+                             "choices": [
+                                 {
+                                     "index": 0,
+                                     "delta": {"thinking": event.text},
+                                     "finish_reason": None,
+                                 }
+                             ],
+                         },
+                         ensure_ascii=False,
+                     )
+                 )
+                 continue
+
+             if isinstance(event, ToolCallEvent):
+                 # Tool call detected: set the finish reason to tool_calls.
+                 final_finish_reason = "tool_calls"
+                 assistant_tool_calls = list(event.tool_calls)
+                 yield _sse(
+                     json.dumps(
+                         {
+                             "id": chatcmpl_id,
+                             "object": "chat.completion.chunk",
+                             "created": created,
+                             "model": model_id,
+                             "choices": [
+                                 {
+                                     "index": 0,
+                                     "delta": {
+                                         "tool_calls": [
+                                             _openai_stream_tool_call(tc, index=i)
+                                             for i, tc in enumerate(event.tool_calls)
+                                         ]
+                                     },
+                                     "finish_reason": None,
+                                 }
+                             ],
+                         },
+                         ensure_ascii=False,
+                     )
+                 )
+                 continue
+
+             if isinstance(event, FinalEvent):
+                 # Don't override a tool_calls finish reason.
+                 if final_finish_reason != "tool_calls":
+                     final_finish_reason = event.finish_reason
+                 # Capture raw content if provided (for discard_thinking=False sessions).
+                 if event.raw_content is not None:
+                     raw_content_for_history = event.raw_content
+                 final_usage = event.usage
+                 final_timing = event.timing
+                 continue
+
+             if isinstance(event, ErrorEvent):
+                 yield _sse(
+                     json.dumps(
+                         {
+                             "error": {
+                                 "message": event.message,
+                                 "type": "server_error",
+                                 "param": None,
+                                 "code": None,
+                             }
+                         },
+                         ensure_ascii=False,
+                     )
+                 )
+                 final_finish_reason = "error"
+                 break
+     except asyncio.CancelledError:
+         cancelled = True
+         final_finish_reason = "cancelled"
+         raise
+     except ValueError as exc:
+         yield _sse(
+             json.dumps(
+                 {
+                     "error": {
+                         "message": str(exc),
+                         "type": "invalid_request_error",
+                         "param": None,
+                         "code": None,
+                     }
+                 },
+                 ensure_ascii=False,
+             )
+         )
+         final_finish_reason = "error"
+     except Exception as exc:
+         yield _sse(
+             json.dumps(
+                 {
+                     "error": {
+                         "message": str(exc),
+                         "type": "server_error",
+                         "param": None,
+                         "code": None,
+                     }
+                 },
+                 ensure_ascii=False,
+             )
+         )
+         final_finish_reason = "error"
+     finally:
+         if http_semaphore is not None:
+             http_semaphore.release()
+
+     # Persist assistant message to session history (best-effort).
+     # Important: on cancelled/incomplete streams, persist an *empty string* (not null)
+     # so the HTTP transcript stays aligned with the adapter session KV state.
+     if chat_request.session_id:
+         if chat_request.discard_thinking:
+             text = "".join(assistant_text_parts) if assistant_text_parts else ""
+             history_content: str | None = text
+         else:
+             if raw_content_for_history is not None:
+                 history_content = raw_content_for_history
+             else:
+                 history_content = "".join(assistant_text_parts) if assistant_text_parts else ""
+
+         msg: dict[str, Any] = {
+             "role": "assistant",
+             "content": history_content,
+         }
+         if assistant_tool_calls:
+             msg["content"] = None
+             msg["tool_calls"] = [_openai_tool_call(tc) for tc in assistant_tool_calls]
+
+         # Avoid persisting pure-null assistant messages.
+         should_persist = True
+         if not assistant_tool_calls and history_content is None:
+             should_persist = False
+
+         if should_persist:
+             with sessions_lock:
+                 meta = sessions.get(chat_request.session_id)
+                 if meta is not None and hasattr(meta, "messages"):
+                     meta.messages.append(msg)  # type: ignore[attr-defined]
+
+     # Terminal chunk + DONE (skip if the request was cancelled/disconnected).
+     if not cancelled:
+         if final_finish_reason is None:
+             final_finish_reason = "stop"
+
+         terminal: dict[str, Any] = {
+             "id": chatcmpl_id,
+             "object": "chat.completion.chunk",
+             "created": created,
+             "model": model_id,
+             "choices": [{"index": 0, "delta": {}, "finish_reason": final_finish_reason}],
+         }
+         if isinstance(final_usage, Usage):
+             terminal["usage"] = {
+                 "prompt_tokens": final_usage.prompt_tokens,
+                 "completion_tokens": final_usage.completion_tokens,
+                 "total_tokens": final_usage.total_tokens,
+             }
+         if isinstance(final_timing, Timing):
+             terminal["x_superlinear_timing"] = {
+                 "prefill_s": final_timing.prefill_s,
+                 "decode_s": final_timing.decode_s,
+                 "total_s": final_timing.total_s,
+                 "tok_per_s": final_timing.tok_per_s,
+             }
+         yield _sse(json.dumps(terminal, ensure_ascii=False))
+         yield "data: [DONE]\n\n"
+
+
+ def _parse_chat_request(payload: Any, *, http_max_completion_tokens: int | None = None) -> ChatRequest:
+     if not isinstance(payload, dict):
+         raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
+
+     raw_messages = payload.get("messages")
+     if not isinstance(raw_messages, list) or not raw_messages:
+         raise HTTPException(status_code=400, detail="'messages' must be a non-empty list.")
+
+     messages: list[ChatMessage] = []
+     for msg in raw_messages:
+         if not isinstance(msg, dict):
+             raise HTTPException(status_code=400, detail="Each message must be an object.")
+
+         role = msg.get("role")
+         if role not in {"system", "user", "assistant", "tool"}:
+             raise HTTPException(status_code=400, detail=f"Invalid message role: {role!r}.")
+
+         content = _coerce_content(msg.get("content"))
+
+         tool_call_id = msg.get("tool_call_id") if role == "tool" else None
+         tool_calls: list[ToolCall] = []
+
+         if role == "assistant" and msg.get("tool_calls") is not None:
+             raw_tool_calls = msg.get("tool_calls")
+             if not isinstance(raw_tool_calls, list):
+                 raise HTTPException(status_code=400, detail="'tool_calls' must be a list.")
+
+             for tc in raw_tool_calls:
+                 tool_calls.append(_parse_assistant_tool_call(tc))
+
+         messages.append(
+             ChatMessage(
+                 role=role,
+                 content=content,
+                 tool_calls=tool_calls,
+                 tool_call_id=tool_call_id,
+             )
+         )
+
+     tools = payload.get("tools") or []
+     if not isinstance(tools, list):
+         raise HTTPException(status_code=400, detail="'tools' must be a list.")
+
+     tool_choice = payload.get("tool_choice")
+
+     max_tokens = payload.get("max_tokens")
+     max_completion_tokens = payload.get("max_completion_tokens")
+
+     if max_tokens is None and max_completion_tokens is None:
+         max_tokens = 4096
+     elif max_tokens is not None and max_completion_tokens is not None:
+         try:
+             if int(max_tokens) != int(max_completion_tokens):
+                 raise HTTPException(
+                     status_code=400,
+                     detail="'max_tokens' and 'max_completion_tokens' must match when both are provided.",
+                 )
+         except HTTPException:
+             raise
+         except Exception as exc:
+             raise HTTPException(
+                 status_code=400,
+                 detail="'max_tokens' and 'max_completion_tokens' must be integers.",
+             ) from exc
+     elif max_completion_tokens is not None:
+         max_tokens = max_completion_tokens
+
+     try:
+         max_tokens = int(max_tokens)
+     except Exception as exc:
+         raise HTTPException(status_code=400, detail="'max_tokens' must be an integer.") from exc
+     if max_tokens <= 0:
+         raise HTTPException(status_code=400, detail="'max_tokens' must be > 0.")
+
+     if http_max_completion_tokens is not None and max_tokens > http_max_completion_tokens:
+         raise HTTPException(
+             status_code=400,
+             detail=f"'max_tokens' too large: {max_tokens} (cap={http_max_completion_tokens}).",
+         )
+
+     try:
+         temperature = float(payload.get("temperature", 0.1) or 0.1)
+     except Exception as exc:
+         raise HTTPException(status_code=400, detail="'temperature' must be a number.") from exc
+
+     try:
+         top_p = float(payload.get("top_p", 0.95) or 0.95)
+     except Exception as exc:
+         raise HTTPException(status_code=400, detail="'top_p' must be a number.") from exc
+
+     stop = payload.get("stop") or []
+     if isinstance(stop, str):
+         stop = [stop]
+     if not isinstance(stop, list):
+         raise HTTPException(status_code=400, detail="'stop' must be a string or list of strings.")
+     stop = [s for s in stop if isinstance(s, str)]
+
+     stream = bool(payload.get("stream", False))
+
+     stream_options = payload.get("stream_options") or {}
+     if not isinstance(stream_options, dict):
+         raise HTTPException(status_code=400, detail="'stream_options' must be an object.")
+
+     try:
+         flush_every_n_tokens = int(stream_options.get("flush_every_n_tokens", 8))
+         flush_every_ms = int(stream_options.get("flush_every_ms", 50))
+     except Exception as exc:
+         raise HTTPException(
+             status_code=400,
+             detail="'stream_options.flush_every_n_tokens' and 'stream_options.flush_every_ms' must be integers.",
+         ) from exc
+
+     # Parse chat_template_kwargs (optional, vLLM-compatible).
+     chat_template_kwargs = payload.get("chat_template_kwargs")
+     if chat_template_kwargs is not None and not isinstance(chat_template_kwargs, dict):
+         raise HTTPException(status_code=400, detail="'chat_template_kwargs' must be an object.")
+
+     # Parse reasoning_budget (optional, Superlinear-specific).
+     reasoning_budget = payload.get("reasoning_budget")
+     if reasoning_budget is not None:
+         try:
+             reasoning_budget = int(reasoning_budget)
+         except (ValueError, TypeError) as exc:
+             raise HTTPException(status_code=400, detail="'reasoning_budget' must be an integer.") from exc
+         if reasoning_budget <= 0:
+             raise HTTPException(status_code=400, detail="'reasoning_budget' must be > 0.")
+
+     # Parse discard_thinking (optional, Superlinear-specific).
+     discard_thinking = payload.get("discard_thinking")
+     if discard_thinking is not None and not isinstance(discard_thinking, bool):
+         raise HTTPException(status_code=400, detail="'discard_thinking' must be a boolean.")
+
+     # Parse stream_thinking (optional, Superlinear-specific).
+     stream_thinking = payload.get("stream_thinking")
+     if stream_thinking is not None and not isinstance(stream_thinking, bool):
+         raise HTTPException(status_code=400, detail="'stream_thinking' must be a boolean.")
+
+     # Parse session_id (optional, for stateful chat).
+     session_id = payload.get("session_id")
+     if session_id is not None and not isinstance(session_id, str):
+         raise HTTPException(status_code=400, detail="'session_id' must be a string.")
+
+     # Parse extra (optional, engine-specific).
+     extra = payload.get("extra")
+     if extra is None:
+         extra = {}
+     if not isinstance(extra, dict):
+         raise HTTPException(status_code=400, detail="'extra' must be an object.")
+
+     # Convenience alias: allow top-level repetition_detection to be passed through.
+     repetition_detection = payload.get("repetition_detection")
+     if repetition_detection is not None and "repetition_detection" not in extra:
+         extra = dict(extra)
+         extra["repetition_detection"] = repetition_detection
+
+     return ChatRequest(
+         messages=messages,
+         tools=tools,
+         tool_choice=tool_choice,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         stop=stop,
+         stream=stream,
+         stream_options=StreamOptions(
+             flush_every_n_tokens=flush_every_n_tokens,
+             flush_every_ms=flush_every_ms,
+         ),
+         chat_template_kwargs=chat_template_kwargs,
+         reasoning_budget=reasoning_budget,
+         discard_thinking=discard_thinking,
+         stream_thinking=stream_thinking,
+         session_id=session_id,
+         extra=extra,
+     )
+
+
+ def _coerce_content(content: Any) -> str:
+     if content is None:
+         return ""
+     if isinstance(content, str):
+         return content
+
+     # Minimal support for OpenAI "content parts" format (text-only).
+     if isinstance(content, list):
+         parts: list[str] = []
+         for part in content:
+             if not isinstance(part, dict):
+                 continue
+             if part.get("type") != "text":
+                 continue
+             text = part.get("text")
+             if isinstance(text, str):
+                 parts.append(text)
+         return "".join(parts)
+
+     raise HTTPException(status_code=400, detail="Unsupported message content type.")
+
+
+ def _parse_assistant_tool_call(tc: Any) -> ToolCall:
+     if not isinstance(tc, dict):
+         raise HTTPException(status_code=400, detail="Each tool_call must be an object.")
+
+     fn = tc.get("function")
+     if not isinstance(fn, dict):
+         raise HTTPException(status_code=400, detail="tool_call.function must be an object.")
+
+     name = fn.get("name")
+     if not isinstance(name, str) or not name:
+         raise HTTPException(status_code=400, detail="tool_call.function.name must be a string.")
+
+     arguments = fn.get("arguments")
+     args_dict: dict[str, Any] = {}
+     if isinstance(arguments, str) and arguments.strip():
+         try:
+             parsed = json.loads(arguments)
+             if isinstance(parsed, dict):
+                 args_dict = parsed
+         except Exception:
+             # Best-effort fallback: preserve raw payload under a reserved key.
+             args_dict = {"__raw__": arguments}
+     elif isinstance(arguments, dict):
+         args_dict = arguments
+
+     tool_call_id = tc.get("id")
+     if not isinstance(tool_call_id, str) or not tool_call_id:
+         tool_call_id = f"call_{uuid.uuid4().hex}"
+
+     return ToolCall(id=tool_call_id, name=name, arguments=args_dict)
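+
+ # Round-trip example (added in review; values hypothetical): an assistant
+ # message replayed by a client as
+ #
+ #   {"id": "call_abc", "type": "function",
+ #    "function": {"name": "lookup", "arguments": "{\"q\": \"cats\"}"}}
+ #
+ # parses to ToolCall(id="call_abc", name="lookup", arguments={"q": "cats"});
+ # malformed argument JSON is preserved as {"__raw__": "..."} rather than dropped.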