superlinear 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apps/__init__.py +4 -0
- apps/cli/__init__.py +8 -0
- apps/cli/bm25_rag.py +471 -0
- apps/cli/chat_repl.py +1497 -0
- apps/cli/client.py +195 -0
- apps/cli/docs_repl.py +2275 -0
- apps/cli/light_rag.py +729 -0
- apps/cli/local_snapshots.py +139 -0
- apps/cli/locks.py +214 -0
- apps/cli/main.py +457 -0
- apps/cli/output.py +32 -0
- apps/cli/server_cmds.py +516 -0
- apps/cli/session_cmds.py +491 -0
- apps/cli/snapshot_cmds.py +303 -0
- apps/cli/state.py +265 -0
- apps/server/__init__.py +4 -0
- apps/server/app.py +1363 -0
- apps/server/main.py +313 -0
- superlinear/__init__.py +114 -0
- superlinear/_version.py +3 -0
- superlinear/engine/__init__.py +10 -0
- superlinear/engine/adapters/__init__.py +12 -0
- superlinear/engine/adapters/base.py +91 -0
- superlinear/engine/adapters/superlinear.py +1233 -0
- superlinear/engine/chat_engine.py +1173 -0
- superlinear/engine/chat_types.py +130 -0
- superlinear/engine/registry.py +51 -0
- superlinear/engine/repetition.py +203 -0
- superlinear/engine/session_snapshots.py +451 -0
- superlinear/engine/tool_parser.py +83 -0
- superlinear/engine/types.py +42 -0
- superlinear/kernels/__init__.py +2 -0
- superlinear/kernels/common/__init__.py +21 -0
- superlinear/kernels/common/adjustment.py +106 -0
- superlinear/kernels/common/power.py +154 -0
- superlinear/kernels/superlinear/__init__.py +10 -0
- superlinear/kernels/superlinear/attention/__init__.py +78 -0
- superlinear/kernels/superlinear/attention/_prefill.py +940 -0
- superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
- superlinear/kernels/superlinear/attention/api.py +433 -0
- superlinear/kernels/superlinear/search/__init__.py +33 -0
- superlinear/kernels/superlinear/search/_reference.py +204 -0
- superlinear/kernels/superlinear/search/_triton.py +488 -0
- superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
- superlinear/kernels/superlinear/search/api.py +200 -0
- superlinear/kernels/superlinear/span/__init__.py +41 -0
- superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
- superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
- superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
- superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
- superlinear/kernels/superlinear/span/api.py +296 -0
- superlinear/kernels/superlinear/span/masks.py +187 -0
- superlinear/py.typed +0 -0
- superlinear/runtime.py +71 -0
- superlinear-0.1.0.dist-info/METADATA +469 -0
- superlinear-0.1.0.dist-info/RECORD +62 -0
- superlinear-0.1.0.dist-info/WHEEL +5 -0
- superlinear-0.1.0.dist-info/entry_points.txt +2 -0
- superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
- superlinear-0.1.0.dist-info/top_level.txt +2 -0
apps/server/app.py
ADDED
|
@@ -0,0 +1,1363 @@
|
|
|
1
|
+
"""FastAPI app for OpenAI-style Chat Completions.
|
|
2
|
+
|
|
3
|
+
The HTTP layer lives under `apps/` and can depend on heavier deps (FastAPI, uvicorn).
|
|
4
|
+
All model execution is delegated to the core engine (`superlinear/engine`).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import os
|
|
11
|
+
import threading
|
|
12
|
+
import json
|
|
13
|
+
import time
|
|
14
|
+
import uuid
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import Any, AsyncIterator
|
|
17
|
+
|
|
18
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
19
|
+
from starlette.responses import JSONResponse, StreamingResponse
|
|
20
|
+
|
|
21
|
+
from superlinear.engine.chat_engine import ChatEngine
|
|
22
|
+
from superlinear.engine.chat_types import ChatMessage, ChatRequest, StreamOptions, Timing, ToolCall, Usage
|
|
23
|
+
from superlinear.engine.session_snapshots import (
|
|
24
|
+
SnapshotCompatibilityError,
|
|
25
|
+
SnapshotStoreV1,
|
|
26
|
+
compute_model_compatibility,
|
|
27
|
+
export_hybrid_mamba_attention_static_cache,
|
|
28
|
+
import_hybrid_mamba_attention_static_cache,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def create_app(
|
|
33
|
+
*,
|
|
34
|
+
engine: ChatEngine,
|
|
35
|
+
model_id: str,
|
|
36
|
+
http_max_concurrency: int | None = None,
|
|
37
|
+
http_max_completion_tokens: int | None = None,
|
|
38
|
+
) -> FastAPI:
|
|
39
|
+
app = FastAPI(title="Superlinear Inference Server", version="0.1.0")
|
|
40
|
+
|
|
41
|
+
default_max_seq_len = 131_072
|
|
42
|
+
|
|
43
|
+
@dataclass
class _HttpSession:
    """Per-session bookkeeping kept by the HTTP layer next to the adapter's cache."""

    # Capacity the adapter session was created (or last resized) with; reused
    # when the adapter session is recreated (e.g. by rollback).
    max_seq_len: int
    # Stored chat history for the session; returned by the /history endpoint
    # and replayed on rollback. Presumably OpenAI-style message dicts
    # ({"role": ..., "content": ...}) — TODO confirm against the writer.
    messages: list[dict[str, Any]] = field(default_factory=list)
|
|
47
|
+
|
|
48
|
+
_sessions_lock = threading.Lock()
|
|
49
|
+
_sessions: dict[str, _HttpSession] = {}
|
|
50
|
+
|
|
51
|
+
_engine_lock = getattr(engine, "_lock", threading.Lock())
|
|
52
|
+
|
|
53
|
+
_snapshot_store_lock = threading.Lock()
|
|
54
|
+
_snapshot_store: SnapshotStoreV1 | None = None
|
|
55
|
+
|
|
56
|
+
http_semaphore: asyncio.Semaphore | None = None
|
|
57
|
+
if http_max_concurrency is not None:
|
|
58
|
+
try:
|
|
59
|
+
http_max_concurrency = int(http_max_concurrency)
|
|
60
|
+
except Exception as exc:
|
|
61
|
+
raise ValueError("http_max_concurrency must be an integer") from exc
|
|
62
|
+
if http_max_concurrency > 0:
|
|
63
|
+
http_semaphore = asyncio.Semaphore(http_max_concurrency)
|
|
64
|
+
elif http_max_concurrency < 0:
|
|
65
|
+
raise ValueError("http_max_concurrency must be >= 0")
|
|
66
|
+
|
|
67
|
+
if http_max_completion_tokens is not None:
|
|
68
|
+
try:
|
|
69
|
+
http_max_completion_tokens = int(http_max_completion_tokens)
|
|
70
|
+
except Exception as exc:
|
|
71
|
+
raise ValueError("http_max_completion_tokens must be an integer") from exc
|
|
72
|
+
if http_max_completion_tokens <= 0:
|
|
73
|
+
raise ValueError("http_max_completion_tokens must be > 0")
|
|
74
|
+
|
|
75
|
+
async def _wait_for_disconnect(request: Request, poll_s: float = 0.1) -> None:
    """Block until the client behind *request* has disconnected.

    Polls ``request.is_disconnected()`` and sleeps *poll_s* seconds between
    checks. Returns only once a disconnect is observed; callers that want to
    stop waiting must cancel the awaiting task.
    """
    while not await request.is_disconnected():
        await asyncio.sleep(poll_s)
|
|
80
|
+
|
|
81
|
+
async def _run_with_disconnect_cancellation(request: Request, coro: Any) -> Any:
    """Await *coro*, aborting it if the HTTP client disconnects first.

    *coro* runs as its own task next to a watcher task that polls the request
    for a disconnect (``_wait_for_disconnect``). Whichever finishes first
    decides the outcome:

    - the work finishes first: the watcher is cancelled and the work's result
      is returned (or its exception re-raised);
    - the client disconnects first: the work task is cancelled and
      ``HTTPException(status_code=499)`` is raised.

    Any task still pending on exit is cancelled and reaped. This includes the
    case where the caller itself is cancelled, so neither task can outlive
    this call.
    """
    task = asyncio.create_task(coro)
    disconnect_task = asyncio.create_task(_wait_for_disconnect(request))
    try:
        # Yield to let both tasks start (handles coroutines that return synchronously).
        await asyncio.sleep(0)
        done, _pending = await asyncio.wait(
            {task, disconnect_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if disconnect_task in done:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            raise HTTPException(status_code=499, detail="Client disconnected")
        return task.result()
    finally:
        # Without this, cancelling the caller while it is suspended in
        # asyncio.wait() would leave both tasks running in the background.
        leftovers = [t for t in (task, disconnect_task) if not t.done()]
        for t in leftovers:
            t.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)
|
|
104
|
+
|
|
105
|
+
async def _try_acquire_semaphore() -> None:
    """Take one HTTP concurrency slot, or fail fast with HTTP 429.

    Does nothing when no concurrency limit is configured (``http_semaphore``
    is ``None``). Otherwise it waits at most ~1 ms for a free slot and raises
    ``HTTPException(status_code=429)`` on timeout. Releasing the slot is
    presumably the caller's job; that code is not visible here.
    """
    if http_semaphore is None:
        return
    try:
        await asyncio.wait_for(http_semaphore.acquire(), timeout=0.001)
    except (TimeoutError, asyncio.TimeoutError) as exc:
        # Before Python 3.11, asyncio.wait_for raises asyncio.TimeoutError,
        # which is NOT the builtin TimeoutError. Catching only the builtin
        # turned "server busy" into an unhandled 500 on those versions.
        raise HTTPException(status_code=429, detail="Server is busy") from exc
|
|
112
|
+
|
|
113
|
+
def _get_snapshot_store() -> SnapshotStoreV1:
|
|
114
|
+
nonlocal _snapshot_store
|
|
115
|
+
with _snapshot_store_lock:
|
|
116
|
+
if _snapshot_store is not None:
|
|
117
|
+
return _snapshot_store
|
|
118
|
+
|
|
119
|
+
adapter = getattr(engine, "adapter", None)
|
|
120
|
+
if adapter is None:
|
|
121
|
+
raise HTTPException(status_code=500, detail="Engine does not expose an adapter.")
|
|
122
|
+
|
|
123
|
+
xdg_cache = os.environ.get("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
|
|
124
|
+
default_snapshot_dir = os.path.join(xdg_cache, "spl", "snapshots")
|
|
125
|
+
root_dir = os.environ.get("SUPERLINEAR_SNAPSHOT_DIR", default_snapshot_dir)
|
|
126
|
+
compat = compute_model_compatibility(adapter=adapter, model_id=model_id)
|
|
127
|
+
_snapshot_store = SnapshotStoreV1(root_dir=root_dir, model_id=model_id, compat=compat)
|
|
128
|
+
return _snapshot_store
|
|
129
|
+
|
|
130
|
+
async def _json_dict_or_empty(request: Request) -> dict[str, Any]:
    """Decode the request body as a JSON object, tolerating absent bodies.

    Any failure while reading or decoding the body yields ``{}``. This is
    deliberate, so optional-body endpoints accept an empty POST. A body that
    decodes to a non-object (list, string, number, null) is rejected with
    HTTP 400.
    """
    try:
        body = await request.json()
    except Exception:
        return {}
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
    return body
|
|
138
|
+
|
|
139
|
+
# -------------------------------------------------------------------------
|
|
140
|
+
# Health & Models
|
|
141
|
+
# -------------------------------------------------------------------------
|
|
142
|
+
|
|
143
|
+
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe; always answers ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
|
|
146
|
+
|
|
147
|
+
@app.get("/v1/models")
async def list_models() -> dict[str, Any]:
    """OpenAI-style model list containing only the served model.

    ``created`` is the current time, not the model's real creation time.
    """
    model_card = {
        "id": model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "superlinear",
    }
    return {"object": "list", "data": [model_card]}
|
|
161
|
+
|
|
162
|
+
# -------------------------------------------------------------------------
|
|
163
|
+
# Session Management
|
|
164
|
+
# -------------------------------------------------------------------------
|
|
165
|
+
|
|
166
|
+
@app.post("/v1/sessions")
async def create_session(request: Request) -> Any:
    """Create a new stateful session for multi-turn conversations.

    Body:
    - session_id: str (required)
    - max_seq_len: int (optional; defaults to ``default_max_seq_len``), must be > 0

    Errors:
    - 400 for a malformed body or invalid fields;
    - 409 when the adapter rejects creation with ``ValueError``
      (presumably a duplicate id — confirm in the adapter);
    - 500 for any other adapter failure.
    """
    # request.json() used to let empty/invalid JSON (JSONDecodeError) and
    # non-object bodies (AttributeError on .get) escape as 500s. The shared
    # helper maps the former to {} and rejects the latter with 400.
    payload = await _json_dict_or_empty(request)
    session_id = payload.get("session_id")
    if not session_id or not isinstance(session_id, str):
        raise HTTPException(status_code=400, detail="'session_id' is required and must be a string.")

    max_seq_len = payload.get("max_seq_len", default_max_seq_len)
    try:
        max_seq_len = int(max_seq_len)
    except (ValueError, TypeError) as exc:
        raise HTTPException(status_code=400, detail="'max_seq_len' must be an integer.") from exc
    # Same bound the resize endpoints enforce; a non-positive capacity can
    # never hold a token.
    if max_seq_len <= 0:
        raise HTTPException(status_code=400, detail="'max_seq_len' must be > 0.")

    try:
        with _engine_lock:
            engine.adapter.create_session(
                cache_id=session_id,
                max_seq_len=max_seq_len,
            )
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    with _sessions_lock:
        _sessions[session_id] = _HttpSession(max_seq_len=max_seq_len)

    return JSONResponse({"status": "created", "session_id": session_id})
|
|
195
|
+
|
|
196
|
+
@app.get("/v1/sessions")
async def list_sessions() -> Any:
    """List every session the engine adapter currently holds."""
    with _engine_lock:
        active = engine.adapter.list_sessions()
    body = {"sessions": active}
    return JSONResponse(body)
|
|
202
|
+
|
|
203
|
+
@app.get("/v1/sessions/{session_id}")
async def get_session_info(session_id: str) -> Any:
    """Describe one session.

    The adapter's info dict is extended with a ``cache_position`` alias of
    ``current_pos``. When the HTTP layer tracks a transcript for the session,
    a ``message_count`` field is added too. Unknown session gives 404; any
    other adapter failure gives 500.
    """
    try:
        with _engine_lock:
            raw_info = engine.adapter.get_session_info(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    info = dict(raw_info)
    # Convenience alias for client code/tests.
    info["cache_position"] = info.get("current_pos")
    with _sessions_lock:
        tracked = _sessions.get(session_id)
        if tracked is not None:
            info["message_count"] = len(tracked.messages)
    return JSONResponse(info)
|
|
222
|
+
|
|
223
|
+
@app.delete("/v1/sessions/{session_id}")
async def close_session(session_id: str) -> Any:
    """Close a session and free its resources.

    Returns 404 if the adapter does not know the session. The HTTP-layer
    transcript for the session, if any, is discarded as well.
    """
    # The existence check and the close used to run under two separate lock
    # acquisitions. A concurrent close in between made close_session() raise
    # an unhandled KeyError (500). Doing both under one acquisition, and
    # mapping KeyError to 404, closes that window.
    try:
        with _engine_lock:
            engine.adapter.get_session_info(session_id)
            engine.adapter.close_session(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    with _sessions_lock:
        _sessions.pop(session_id, None)
    return JSONResponse({"status": "closed", "session_id": session_id})
|
|
238
|
+
|
|
239
|
+
@app.get("/v1/sessions/{session_id}/history")
async def get_session_history(session_id: str) -> Any:
    """Return the chat transcript the HTTP layer stored for a session.

    Returns 404 if the adapter does not know the session. If the session
    exists but no transcript is tracked, the message list is empty.
    """
    try:
        with _engine_lock:
            engine.adapter.get_session_info(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    with _sessions_lock:
        tracked = _sessions.get(session_id)
        history = [] if tracked is None else tracked.messages
        return JSONResponse({"session_id": session_id, "messages": history})
|
|
253
|
+
|
|
254
|
+
@app.post("/v1/sessions/{session_id}/rollback")
async def rollback_session(session_id: str, request: Request) -> Any:
    """Rollback a session to an earlier message index by replaying history.

    Body:
    - keep_messages: int (number of messages to keep from the start)

    Steps:
    1. Truncate the stored transcript to its first ``keep_messages`` entries.
    2. Recreate the adapter session with the same ``max_seq_len``.
    3. Re-encode the kept history (without a generation prompt) into the
       fresh session.

    Raises:
        HTTPException 400: the body is not a JSON object, or ``keep_messages``
            is not a non-negative integer.
        HTTPException 404: the adapter does not know the session, or no
            history is stored for it.
        HTTPException 500: the tokenizer lacks ``apply_chat_template()``.
    """
    # request.json() used to turn malformed/non-object bodies into unhandled
    # 500s (JSONDecodeError / AttributeError). A missing body now yields
    # keep_messages=None, which fails int() below with a clean 400.
    payload = await _json_dict_or_empty(request)
    keep_messages = payload.get("keep_messages")
    try:
        keep_messages = int(keep_messages)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="'keep_messages' must be an integer.") from exc
    if keep_messages < 0:
        raise HTTPException(status_code=400, detail="'keep_messages' must be >= 0.")

    # Ensure the session exists before mutating any state.
    try:
        with _engine_lock:
            engine.adapter.get_session_info(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    with _sessions_lock:
        meta = _sessions.get(session_id)
        if meta is None:
            raise HTTPException(status_code=404, detail=f"No stored history for session: {session_id}")
        meta.messages = meta.messages[:keep_messages]
        max_seq_len = meta.max_seq_len
        history_msgs = list(meta.messages)

    # Recreate adapter session and replay history prompt.
    with _engine_lock:
        engine.adapter.close_session(session_id)
        engine.adapter.create_session(cache_id=session_id, max_seq_len=max_seq_len)

    if history_msgs:
        # Build prompt from stored history WITHOUT adding a generation prompt.
        # The next /chat/completions call will add user msg + generation prompt.
        chat_req = _parse_chat_request({"messages": history_msgs, "max_tokens": 1})
        tokenizer = getattr(engine.adapter, "tokenizer", None)
        apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
        if not callable(apply_chat_template):
            raise HTTPException(status_code=500, detail="Tokenizer does not support apply_chat_template().")

        injected = engine._inject_tool_choice(list(chat_req.messages), chat_req.tool_choice)  # type: ignore[attr-defined]
        template_messages = engine._messages_for_template(injected)  # type: ignore[attr-defined]
        kwargs = engine._chat_template_kwargs(  # type: ignore[attr-defined]
            chat_req, enable_thinking=engine._effective_enable_thinking(chat_req)  # type: ignore[attr-defined]
        )
        input_ids = apply_chat_template(
            template_messages,
            add_generation_prompt=False,
            **kwargs,
        )
        with _engine_lock:
            engine.adapter.append_to_session(cache_id=session_id, input_ids=input_ids)

    with _engine_lock:
        new_info = engine.adapter.get_session_info(session_id)
    new_info = dict(new_info)
    new_info["cache_position"] = new_info.get("current_pos")
    # Report what was actually kept. keep_messages may exceed the transcript
    # length, and slicing then keeps everything.
    new_info["message_count"] = len(history_msgs)
    return JSONResponse({"status": "ok", "session_id": session_id, "session": new_info})
|
|
318
|
+
|
|
319
|
+
def _parse_resize_strategy(payload: dict[str, Any]) -> str:
    """Read and normalize the optional ``strategy`` field of a resize request.

    A missing or ``null`` strategy means ``"auto"``. Otherwise the value must
    be a string. After lowercasing and stripping it must be one of ``auto``,
    ``gpu`` or ``disk``; anything else is rejected with HTTP 400.
    """
    raw = payload.get("strategy", "auto")
    if raw is None:
        return "auto"
    if not isinstance(raw, str):
        raise HTTPException(status_code=400, detail="'strategy' must be a string.")
    normalized = raw.lower().strip()
    if normalized in {"auto", "gpu", "disk"}:
        return normalized
    raise HTTPException(status_code=400, detail="'strategy' must be one of: auto, gpu, disk.")
|
|
329
|
+
|
|
330
|
+
def _next_pow2_strictly_greater(n: int) -> int:
    """Return the smallest power of two that is strictly greater than *n*.

    Non-positive inputs map to 1. For n >= 1 the answer is
    ``2 ** n.bit_length()``; e.g. 3 -> 4, 4 -> 8, 131072 -> 262144.
    """
    return 1 if n <= 0 else 1 << n.bit_length()
|
|
337
|
+
|
|
338
|
+
def _resize_session_to(*, session_id: str, target_max_seq_len: int, strategy: str) -> dict[str, Any]:
    """Resize a session's KV cache to ``target_max_seq_len`` tokens.

    The work runs synchronously while ``_engine_lock`` is held. The session is
    exported, its cache contents are copied into a freshly allocated static
    cache of the new size, and the session is restored in place.

    Args:
        session_id: Adapter cache id of the session.
        target_max_seq_len: New capacity. Must be > 0 and >= the session's
            current position.
        strategy: One of:
            - ``"gpu"``: allocate the new cache while the old one is resident;
            - ``"disk"``: close the old session first to lower peak memory;
            - ``"auto"``: try ``"gpu"``, and on CUDA OOM retry as ``"disk"``.

    Returns:
        A JSON-serializable dict. ``status`` is ``"noop"`` (size unchanged) or
        ``"resized"``, and the dict includes the session info.

    Raises:
        HTTPException: 400 for an invalid target size, 404 for an unknown
            session (``KeyError``), 500 for any other failure.
    """
    if target_max_seq_len <= 0:
        raise HTTPException(status_code=400, detail="'max_seq_len' must be > 0.")

    def _allocate_and_restore(*, close_first: bool) -> dict[str, Any]:
        # Everything here runs under _engine_lock.
        exported = engine.adapter.export_session(session_id)
        current_pos = int(exported.get("current_pos") or 0)
        old_max = int(exported.get("max_seq_len") or 0)

        if current_pos < 0:
            raise HTTPException(status_code=500, detail="Invalid session current_pos.")
        if old_max <= 0:
            raise HTTPException(status_code=500, detail="Invalid session max_seq_len.")
        if target_max_seq_len < current_pos:
            raise HTTPException(
                status_code=400,
                detail=f"'max_seq_len' ({target_max_seq_len}) must be >= current_pos ({current_pos}).",
            )
        if target_max_seq_len == old_max:
            info = engine.adapter.get_session_info(session_id)
            info = dict(info)
            info["cache_position"] = info.get("current_pos")
            return {"status": "noop", "session_id": session_id, "session": info}

        cache_payload = export_hybrid_mamba_attention_static_cache(
            cache=exported["past_key_values"],
            current_pos=current_pos,
        )

        model = getattr(engine.adapter, "model", None)
        if model is None or not hasattr(model, "create_static_cache"):
            raise HTTPException(status_code=500, detail="Adapter does not expose create_static_cache().")

        if close_first:
            # Free the old cache to reduce peak VRAM.
            engine.adapter.close_session(session_id)

        past_key_values = model.create_static_cache(batch_size=1, max_seq_len=target_max_seq_len)
        restored_pos = import_hybrid_mamba_attention_static_cache(cache=past_key_values, payload=cache_payload)
        engine.adapter.restore_session(
            cache_id=session_id,
            past_key_values=past_key_values,
            current_pos=restored_pos,
            max_seq_len=target_max_seq_len,
            next_token_logits=None,
            overwrite=True,
        )

        info = engine.adapter.get_session_info(session_id)
        info = dict(info)
        info["cache_position"] = info.get("current_pos")
        return {
            "status": "resized",
            "session_id": session_id,
            "old_max_seq_len": old_max,
            "max_seq_len": target_max_seq_len,
            "current_pos": restored_pos,
            "session": info,
        }

    def _is_cuda_oom(exc: BaseException) -> bool:
        # True only for torch's CUDA OOM error. Returns False, instead of
        # raising and masking *exc*, when torch is not importable or is too
        # old to define torch.cuda.OutOfMemoryError.
        try:
            import torch  # type: ignore
        except ImportError:
            return False
        oom_type = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
        return isinstance(oom_type, type) and isinstance(exc, oom_type)

    try:
        with _engine_lock:
            if strategy == "gpu":
                result = _allocate_and_restore(close_first=False)
            elif strategy == "disk":
                result = _allocate_and_restore(close_first=True)
            else:  # auto
                retry_after_oom = False
                try:
                    result = _allocate_and_restore(close_first=False)
                except Exception as exc:
                    if not _is_cuda_oom(exc):
                        raise
                    retry_after_oom = True
                # Retry OUTSIDE the except clause. While an exception is being
                # handled, its traceback keeps the failed attempt's frames,
                # and any tensors they allocated, alive. That can make the
                # low-memory retry OOM as well.
                if retry_after_oom:
                    result = _allocate_and_restore(close_first=True)
    except HTTPException:
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    with _sessions_lock:
        meta = _sessions.get(session_id)
        if meta is not None:
            meta.max_seq_len = int(target_max_seq_len)
            if isinstance(result, dict) and isinstance(result.get("session"), dict):
                result["session"]["message_count"] = len(meta.messages)

    return result
|
|
436
|
+
|
|
437
|
+
@app.post("/v1/sessions/{session_id}/resize")
async def resize_session(session_id: str, request: Request) -> Any:
    """Resize a session KV cache to a new max sequence length.

    Body:
    - max_seq_len: int (required)
    - strategy: "auto" | "gpu" | "disk" (optional; default: "auto")

    Strategies:
    - gpu: allocate the new cache while the old one is still resident
      (higher peak VRAM)
    - disk: free the old cache before allocating the new one (lower peak VRAM)
    - auto: try gpu, fall back to disk on OOM
    """
    payload = await _json_dict_or_empty(request)
    if payload.get("max_seq_len") is None:
        raise HTTPException(status_code=400, detail="'max_seq_len' is required.")
    try:
        target = int(payload["max_seq_len"])
    except Exception as exc:
        raise HTTPException(status_code=400, detail="'max_seq_len' must be an integer.") from exc

    strategy = _parse_resize_strategy(payload)
    outcome = _resize_session_to(session_id=session_id, target_max_seq_len=target, strategy=strategy)
    return JSONResponse(outcome)
|
|
464
|
+
|
|
465
|
+
@app.post("/v1/sessions/{session_id}/resize/next_pow2")
async def resize_session_next_pow2(session_id: str, request: Request) -> Any:
    """Grow a session KV cache to the next power-of-two max sequence length.

    Body (optional):
    - strategy: "auto" | "gpu" | "disk" (default: "auto")

    Example:
    - 131072 -> 262144
    - 262144 -> 524288

    The response is the resize result plus ``"mode": "next_pow2"``.
    """
    payload = await _json_dict_or_empty(request)
    strategy = _parse_resize_strategy(payload)

    try:
        with _engine_lock:
            current = engine.adapter.get_session_info(session_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    target = _next_pow2_strictly_greater(int(current.get("max_seq_len") or 0))
    outcome = _resize_session_to(session_id=session_id, target_max_seq_len=target, strategy=strategy)
    if isinstance(outcome, dict):
        outcome = {**outcome, "mode": "next_pow2"}
    return JSONResponse(outcome)
|
|
496
|
+
|
|
497
|
+
# -------------------------------------------------------------------------
|
|
498
|
+
# Snapshot Management (v1)
|
|
499
|
+
# -------------------------------------------------------------------------
|
|
500
|
+
|
|
501
|
+
@app.post("/v1/sessions/{session_id}/save")
async def save_session_snapshot(session_id: str, request: Request) -> Any:
    """Save a session to an immutable on-disk snapshot.

    Body (all optional):
    - title: str
    - description: str
    - tags: list (elements are stringified)

    A non-string title or description is ignored. Unknown session gives 404;
    export or store failures give 500.
    """
    payload = await _json_dict_or_empty(request)

    title = payload.get("title")
    description = payload.get("description")
    tags = payload.get("tags")
    if not (tags is None or isinstance(tags, list)):
        raise HTTPException(status_code=400, detail="'tags' must be a list of strings.")

    try:
        with _engine_lock:
            exported = engine.adapter.export_session(session_id)
            with _sessions_lock:
                tracked = _sessions.get(session_id)
                transcript: list[dict[str, Any]] = [] if tracked is None else list(tracked.messages)
            cache_payload = export_hybrid_mamba_attention_static_cache(
                cache=exported["past_key_values"],
                current_pos=int(exported["current_pos"]),
            )
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    store = _get_snapshot_store()
    try:
        session_meta = {"max_seq_len": int(exported["max_seq_len"]), "current_pos": int(exported["current_pos"])}
        manifest = store.create_snapshot(
            transcript=transcript,
            cache_payload=cache_payload,
            session=session_meta,
            title=title if isinstance(title, str) else None,
            description=description if isinstance(description, str) else None,
            tags=[str(tag) for tag in tags] if isinstance(tags, list) else None,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return JSONResponse({"status": "saved", "snapshot_id": manifest.snapshot_id})
|
|
542
|
+
|
|
543
|
+
@app.get("/v1/snapshots")
async def list_snapshots() -> Any:
    """List the manifests of all snapshots in the snapshot store."""
    manifests = _get_snapshot_store().list_snapshots()
    return JSONResponse({"snapshots": [manifest.to_dict() for manifest in manifests]})
|
|
548
|
+
|
|
549
|
+
@app.get("/v1/snapshots/{snapshot_id}")
async def get_snapshot(snapshot_id: str) -> Any:
    """Return one snapshot's manifest.

    Missing snapshot (``FileNotFoundError``) gives 404. A ``ValueError`` from
    the store gives 400.
    """
    store = _get_snapshot_store()
    try:
        manifest_dict = store.get_manifest(snapshot_id).to_dict()
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return JSONResponse(manifest_dict)
|
|
559
|
+
|
|
560
|
+
@app.patch("/v1/snapshots/{snapshot_id}")
async def patch_snapshot(snapshot_id: str, request: Request) -> Any:
    """Update a snapshot's user-facing metadata.

    Body (JSON object; all fields optional):
    - title: str
    - description: str
    - tags: list (elements are stringified)

    Non-string title/description values are passed to the store as ``None``.

    Errors:
    - 400 for an invalid body or a ``ValueError`` from the store;
    - 404 if the snapshot does not exist.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError.
        # Previously a malformed body escaped as an unhandled 500.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    title = payload.get("title")
    description = payload.get("description")
    tags = payload.get("tags")
    if tags is not None and not isinstance(tags, list):
        raise HTTPException(status_code=400, detail="'tags' must be a list of strings.")

    store = _get_snapshot_store()
    try:
        updated = store.patch_metadata(
            snapshot_id,
            title=title if isinstance(title, str) else None,
            description=description if isinstance(description, str) else None,
            tags=[str(t) for t in tags] if isinstance(tags, list) else None,
        )
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return JSONResponse(updated.to_dict())
|
|
585
|
+
|
|
586
|
+
@app.delete("/v1/snapshots/{snapshot_id}")
async def delete_snapshot(snapshot_id: str) -> Any:
    """Delete one snapshot from the store.

    Missing snapshot (``FileNotFoundError``) gives 404. A ``ValueError`` from
    the store gives 400.
    """
    store = _get_snapshot_store()
    try:
        store.delete_snapshot(snapshot_id)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    body = {"status": "deleted", "snapshot_id": snapshot_id}
    return JSONResponse(body)
|
|
596
|
+
|
|
597
|
+
@app.post("/v1/snapshots/{snapshot_id}/load")
async def load_snapshot(snapshot_id: str, request: Request) -> Any:
    """Restore a stored snapshot into an adapter session and the HTTP transcript store.

    Optional JSON body fields:
        session_id: target session id (string). When missing or empty, a new
            ``sess_<hex>`` id is generated.
        force: overwrite an existing session with the same id. Non-string values use
            normal truthiness; strings count as true only for "1", "true", "yes" or
            "on" (case-insensitive), so ``"false"`` does not trigger an overwrite.

    Returns a JSON object with the session id, the restored cache position,
    ``max_seq_len`` and the number of restored transcript messages.

    Raises:
        HTTPException(400): invalid ``session_id`` or the store raises ``ValueError``.
        HTTPException(404): the snapshot was not found.
        HTTPException(409): the snapshot is incompatible, or the target session already
            exists and ``force`` is not set.
        HTTPException(500): any other failure while loading or restoring.
    """
    payload = await _json_dict_or_empty(request)

    target_session_id = payload.get("session_id")
    if target_session_id is not None and not isinstance(target_session_id, str):
        raise HTTPException(status_code=400, detail="'session_id' must be a string.")

    raw_force = payload.get("force", False)
    if isinstance(raw_force, str):
        # bool("false") is True; interpret string spellings explicitly so a client
        # sending "false" does not destroy an existing session.
        force = raw_force.strip().lower() in {"1", "true", "yes", "on"}
    else:
        force = bool(raw_force)

    if not target_session_id:
        target_session_id = f"sess_{uuid.uuid4().hex}"

    store = _get_snapshot_store()
    try:
        manifest, transcript, cache_payload = store.load_snapshot_payload(snapshot_id)
    except SnapshotCompatibilityError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    max_seq_len = int(manifest.session.get("max_seq_len") or default_max_seq_len)

    try:
        with _engine_lock:
            # Avoid accidental overwrite unless explicitly forced.
            try:
                engine.adapter.get_session_info(target_session_id)
                if not force:
                    raise HTTPException(
                        status_code=409,
                        detail=f"Session already exists: {target_session_id} (use force=true to overwrite).",
                    )
                engine.adapter.close_session(target_session_id)
            except KeyError:
                # No existing session with this id: nothing to overwrite.
                pass

            model = getattr(engine.adapter, "model", None)
            if model is None or not hasattr(model, "create_static_cache"):
                raise HTTPException(status_code=500, detail="Adapter does not expose create_static_cache().")

            # Allocate a fresh cache and fill it from the snapshot payload; the import
            # helper reports the cache position actually restored.
            past_key_values = model.create_static_cache(batch_size=1, max_seq_len=max_seq_len)
            restored_pos = import_hybrid_mamba_attention_static_cache(
                cache=past_key_values, payload=cache_payload
            )
            engine.adapter.restore_session(
                cache_id=target_session_id,
                past_key_values=past_key_values,
                current_pos=restored_pos,
                max_seq_len=max_seq_len,
                next_token_logits=None,
                overwrite=False,
            )
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Restore the HTTP-side transcript so delta-append chat works on this session.
    with _sessions_lock:
        _sessions[target_session_id] = _HttpSession(max_seq_len=max_seq_len, messages=transcript)

    return JSONResponse(
        {
            "status": "loaded",
            "snapshot_id": snapshot_id,
            "session_id": target_session_id,
            "session": {
                "current_pos": restored_pos,
                "max_seq_len": max_seq_len,
                "message_count": len(transcript),
            },
        }
    )
|
|
674
|
+
|
|
675
|
+
# -------------------------------------------------------------------------
|
|
676
|
+
# Chat Completions
|
|
677
|
+
# -------------------------------------------------------------------------
|
|
678
|
+
|
|
679
|
+
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any:
    """OpenAI-compatible chat completion endpoint (streaming and non-streaming).

    Without ``session_id`` the parsed request is passed to the engine as-is. With a
    ``session_id`` the incoming ``messages`` are treated as a delta: they are appended
    to the server-side transcript of that session, and the engine is asked to append
    only new tokens starting at the session's current KV-cache position.

    Streaming requests return a ``text/event-stream`` response. Non-streaming requests
    return a ``chat.completion`` object and persist the assistant reply to the session
    transcript when a session is used.

    Raises:
        HTTPException(400): the body is not valid JSON / not a JSON object, a parameter
            is invalid, or the engine raises ``ValueError`` (non-streaming path).
        HTTPException(404): unknown ``model`` or unknown ``session_id``.
        HTTPException(409): the session's KV cache is non-empty but its HTTP transcript
            is empty.
        HTTPException(500): any other engine failure (non-streaming path).
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError subclass; report it as a client error.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from exc
    # Validate the top-level shape before calling .get(): a JSON array or scalar body
    # would otherwise raise AttributeError and surface as a 500.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    req_model = payload.get("model")
    if req_model is not None and req_model != model_id:
        raise HTTPException(status_code=404, detail=f"Unknown model: {req_model}")

    chat_req = _parse_chat_request(payload, http_max_completion_tokens=http_max_completion_tokens)

    # Session chat: maintain server-side message history and only append delta tokens.
    if chat_req.session_id:
        session_id = chat_req.session_id
        # Ensure the session exists in the adapter.
        try:
            with _engine_lock:
                sess_info = engine.adapter.get_session_info(session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=str(exc)) from exc

        try:
            current_pos = int(sess_info.get("current_pos", 0) or 0)
        except Exception:
            current_pos = 0

        # Hold the sessions lock across the whole read-modify-write of the transcript so
        # concurrent requests on the same session cannot interleave their appends.
        with _sessions_lock:
            meta = _sessions.get(session_id)
            if meta is None:
                # Session exists in adapter but not in HTTP store (e.g., server restarted).
                # `or` also covers an explicit None value, which int() would reject.
                meta = _HttpSession(
                    max_seq_len=int(sess_info.get("max_seq_len") or default_max_seq_len)
                )
                _sessions[session_id] = meta

            # Safety: a non-empty KV cache with an empty HTTP transcript means the server
            # cannot correctly compute delta tokens to append. Proceeding would cause the
            # model to ignore new user input or append mismatched tokens.
            if current_pos > 0 and not meta.messages:
                raise HTTPException(
                    status_code=409,
                    detail=(
                        "Session KV cache is non-empty but HTTP transcript is empty. "
                        "This indicates a corrupted/incomplete session state. "
                        "Start a new session or restore from a snapshot."
                    ),
                )

            # Treat incoming messages as delta and append to stored history.
            incoming_raw = payload.get("messages")
            if isinstance(incoming_raw, list):
                incoming_msgs = [m for m in incoming_raw if isinstance(m, dict)]

                # Keep at most one leading system message in the stored transcript.
                # The CLI may send the same system prompt every turn; accumulating
                # duplicates wastes context and can degrade multi-turn coherence.
                if incoming_msgs and incoming_msgs[0].get("role") == "system":
                    stored_has_system = (
                        bool(meta.messages)
                        and isinstance(meta.messages[0], dict)
                        and meta.messages[0].get("role") == "system"
                    )
                    if stored_has_system:
                        # Never mutate the prompt prefix once the KV cache has advanced.
                        # Even a small change to the system message would invalidate the
                        # cached token stream and corrupt append-from slicing.
                        if current_pos <= 0:
                            meta.messages[0] = incoming_msgs[0]
                        # Drop the incoming system message to avoid duplicates.
                        incoming_msgs = incoming_msgs[1:]

                meta.messages.extend(incoming_msgs)

            full_messages = list(meta.messages)

        # Build a new ChatRequest with full history messages.
        # Append-from position is the current KV cache position.
        chat_req = ChatRequest(
            messages=_parse_chat_request({"messages": full_messages}).messages,
            tools=chat_req.tools,
            tool_choice=chat_req.tool_choice,
            max_tokens=chat_req.max_tokens,
            temperature=chat_req.temperature,
            top_p=chat_req.top_p,
            stop=chat_req.stop,
            stream=chat_req.stream,
            stream_options=chat_req.stream_options,
            chat_template_kwargs=chat_req.chat_template_kwargs,
            reasoning_budget=chat_req.reasoning_budget,
            discard_thinking=chat_req.discard_thinking,
            stream_thinking=chat_req.stream_thinking,
            session_id=session_id,
            session_append_from_pos=current_pos,
            extra=chat_req.extra,
        )

    created = int(time.time())
    chatcmpl_id = f"chatcmpl-{uuid.uuid4().hex}"

    if chat_req.stream:
        # The semaphore is acquired here and released by the stream generator.
        await _try_acquire_semaphore()
        event_iter = _stream_chat_completions(
            engine=engine,
            chat_request=chat_req,
            model_id=model_id,
            created=created,
            chatcmpl_id=chatcmpl_id,
            sessions=_sessions,
            sessions_lock=_sessions_lock,
            request=request,
            http_semaphore=http_semaphore,
            semaphore_already_acquired=True,
        )
        return StreamingResponse(event_iter, media_type="text/event-stream")

    await _try_acquire_semaphore()
    try:
        result = await _run_with_disconnect_cancellation(request, engine.generate_chat(chat_req))
    except HTTPException:
        raise
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        if http_semaphore is not None:
            http_semaphore.release()

    usage = result.get("usage")
    finish_reason = result.get("finish_reason") or "stop"
    content = result.get("content")
    tool_calls = result.get("tool_calls") or []
    raw_content = result.get("raw_content")

    message: dict[str, Any] = {"role": "assistant", "content": content}
    if tool_calls:
        message["content"] = None
        message["tool_calls"] = [_openai_tool_call(tc) for tc in tool_calls]

    # Persist assistant message to session history (non-streaming)
    if chat_req.session_id:
        # When discard_thinking=True, persist only the stripped content (no <think> blocks).
        # Otherwise, persist raw_content (with thinking) if available.
        if not chat_req.discard_thinking and raw_content is not None:
            history_content = raw_content
        else:
            history_content = content if content is not None else ""
        history_msg: dict[str, Any] = {"role": "assistant", "content": history_content}
        if tool_calls:
            history_msg["content"] = None
            history_msg["tool_calls"] = [_openai_tool_call(tc) for tc in tool_calls]

        # Persist empty-string assistant messages too (never null).
        if tool_calls or isinstance(history_content, str):
            with _sessions_lock:
                meta = _sessions.get(chat_req.session_id)
                if meta is not None:
                    meta.messages.append(history_msg)

    resp: dict[str, Any] = {
        "id": chatcmpl_id,
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": message,
                "finish_reason": finish_reason,
            }
        ],
    }

    if isinstance(usage, Usage):
        resp["usage"] = {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    return JSONResponse(resp)
|
|
855
|
+
|
|
856
|
+
return app
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
def _sse(data: str) -> str:
|
|
860
|
+
return f"data: {data}\n\n"
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def _openai_tool_call(tc: ToolCall) -> dict[str, Any]:
    """Render a ToolCall in the OpenAI ``tool_calls`` wire format.

    The arguments mapping is serialized to a JSON string (non-ASCII characters kept
    verbatim), since the OpenAI schema carries function arguments as a string.
    """
    encoded_arguments = json.dumps(tc.arguments, ensure_ascii=False)
    function_part = {"name": tc.name, "arguments": encoded_arguments}
    return {"id": tc.id, "type": "function", "function": function_part}
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
def _openai_stream_tool_call(tc: ToolCall, *, index: int) -> dict[str, Any]:
    """Render a ToolCall for a streaming delta: the regular wire format plus ``index``."""
    return {**_openai_tool_call(tc), "index": index}
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
async def _stream_chat_completions(
    *,
    engine: ChatEngine,
    chat_request: ChatRequest,
    model_id: str,
    created: int,
    chatcmpl_id: str,
    sessions: dict[str, Any],
    sessions_lock: threading.Lock,
    request: Request,
    http_semaphore: asyncio.Semaphore | None,
    semaphore_already_acquired: bool = False,
) -> AsyncIterator[str]:
    """Consume ``engine.astream_chat`` and re-emit its events as OpenAI SSE chunks.

    Emits a role chunk, then content / thinking / tool-call delta chunks. Unless the
    client disconnected or the task was cancelled, it finishes with a terminal chunk
    (finish reason, plus usage and timing when the engine reported them) followed by
    ``data: [DONE]``. Errors are reported in-band as an SSE ``error`` object.

    ``http_semaphore`` is released when the stream ends. If
    ``semaphore_already_acquired`` is False, it is acquired here first, and a busy
    server raises HTTPException(429).

    When ``chat_request.session_id`` is set, the assistant turn is appended to
    ``sessions[session_id].messages`` even on cancellation/failure (empty string when no
    text was produced) so the HTTP transcript stays aligned with the adapter KV state.
    """
    if http_semaphore is not None and not semaphore_already_acquired:
        try:
            await asyncio.wait_for(http_semaphore.acquire(), timeout=0.001)
        except (asyncio.TimeoutError, TimeoutError) as exc:
            # asyncio.TimeoutError is only an alias of the builtin TimeoutError from
            # Python 3.11 on; catching both keeps the 429 path working on 3.10.
            raise HTTPException(status_code=429, detail="Server is busy") from exc

    def _delta_chunk(delta: dict[str, Any]) -> str:
        # One non-terminal chat.completion.chunk carrying `delta`.
        return _sse(
            json.dumps(
                {
                    "id": chatcmpl_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
                },
                ensure_ascii=False,
            )
        )

    def _error_chunk(message: Any, error_type: str) -> str:
        # OpenAI-style error object delivered as an SSE payload.
        return _sse(
            json.dumps(
                {
                    "error": {
                        "message": message,
                        "type": error_type,
                        "param": None,
                        "code": None,
                    }
                },
                ensure_ascii=False,
            )
        )

    final_finish_reason: str | None = None
    final_usage: Usage | None = None
    final_timing: Timing | None = None
    assistant_text_parts: list[str] = []
    assistant_tool_calls: list[ToolCall] = []
    raw_content_for_history: str | None = None  # Raw content with thinking for discard_thinking=False
    cancelled = False
    event_stream: Any = None

    try:
        # Initial chunk announces the role. It is yielded inside the try so that a
        # disconnect at this first yield still releases the semaphore in `finally`.
        yield _delta_chunk({"role": "assistant"})

        # Imported once per stream rather than once per event.
        from superlinear.engine.chat_types import (
            DeltaEvent,
            ErrorEvent,
            FinalEvent,
            ThinkingDeltaEvent,
            ToolCallEvent,
        )

        event_stream = engine.astream_chat(chat_request)
        async for event in event_stream:
            # If the client disconnects mid-stream, stop consuming promptly; the engine
            # stream is closed explicitly in `finally`.
            if await request.is_disconnected():
                cancelled = True
                final_finish_reason = "cancelled"
                break

            if isinstance(event, DeltaEvent):
                if event.text:
                    assistant_text_parts.append(event.text)
                    yield _delta_chunk({"content": event.text})
                continue

            if isinstance(event, ThinkingDeltaEvent):
                if event.text:
                    yield _delta_chunk({"thinking": event.text})
                continue

            if isinstance(event, ToolCallEvent):
                # Tool call detected - set finish reason to tool_calls
                final_finish_reason = "tool_calls"
                assistant_tool_calls = list(event.tool_calls)
                yield _delta_chunk(
                    {
                        "tool_calls": [
                            _openai_stream_tool_call(tc, index=i)
                            for i, tc in enumerate(event.tool_calls)
                        ]
                    }
                )
                continue

            if isinstance(event, FinalEvent):
                # Don't override tool_calls finish reason
                if final_finish_reason != "tool_calls":
                    final_finish_reason = event.finish_reason
                # Capture raw content if provided (for discard_thinking=False sessions)
                if event.raw_content is not None:
                    raw_content_for_history = event.raw_content
                final_usage = event.usage
                final_timing = event.timing
                continue

            if isinstance(event, ErrorEvent):
                yield _error_chunk(event.message, "server_error")
                final_finish_reason = "error"
                break
    except asyncio.CancelledError:
        cancelled = True
        final_finish_reason = "cancelled"
        raise
    except ValueError as exc:
        yield _error_chunk(str(exc), "invalid_request_error")
        final_finish_reason = "error"
    except Exception as exc:
        yield _error_chunk(str(exc), "server_error")
        final_finish_reason = "error"
    finally:
        if http_semaphore is not None:
            http_semaphore.release()

        # Persist assistant message to session history (best-effort).
        # Important: on cancelled/incomplete streams, persist an *empty string* (not null)
        # so HTTP transcript stays aligned with the adapter session KV state.
        if chat_request.session_id:
            if chat_request.discard_thinking or raw_content_for_history is None:
                history_content: str | None = "".join(assistant_text_parts)
            else:
                history_content = raw_content_for_history

            msg: dict[str, Any] = {"role": "assistant", "content": history_content}
            if assistant_tool_calls:
                msg["content"] = None
                msg["tool_calls"] = [_openai_tool_call(tc) for tc in assistant_tool_calls]

            # Avoid persisting pure-null assistant messages.
            if assistant_tool_calls or history_content is not None:
                with sessions_lock:
                    meta = sessions.get(chat_request.session_id)
                    if meta is not None and hasattr(meta, "messages"):
                        meta.messages.append(msg)  # type: ignore[attr-defined]

        # Close the engine stream deterministically. Leaving `async for` via break or an
        # exception does not close an async generator by itself, so without this the
        # engine could keep producing tokens until the generator is garbage-collected.
        aclose = getattr(event_stream, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception:
                # Best-effort cleanup: the stream is already being abandoned, and a
                # failure while closing must not mask the original outcome.
                pass

    # Terminal chunk + DONE (skip if the request was cancelled/disconnected).
    if not cancelled:
        if final_finish_reason is None:
            final_finish_reason = "stop"

        terminal: dict[str, Any] = {
            "id": chatcmpl_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": final_finish_reason}],
        }
        if isinstance(final_usage, Usage):
            terminal["usage"] = {
                "prompt_tokens": final_usage.prompt_tokens,
                "completion_tokens": final_usage.completion_tokens,
                "total_tokens": final_usage.total_tokens,
            }
        if isinstance(final_timing, Timing):
            terminal["x_superlinear_timing"] = {
                "prefill_s": final_timing.prefill_s,
                "decode_s": final_timing.decode_s,
                "total_s": final_timing.total_s,
                "tok_per_s": final_timing.tok_per_s,
            }
        yield _sse(json.dumps(terminal, ensure_ascii=False))
        yield "data: [DONE]\n\n"
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
def _parse_chat_request(payload: Any, *, http_max_completion_tokens: int | None = None) -> ChatRequest:
    """Validate an OpenAI-style chat completion body and build a ChatRequest.

    Args:
        payload: Decoded JSON request body; must be a dict with a non-empty
            ``messages`` list.
        http_max_completion_tokens: Optional server-side cap; a larger requested
            ``max_tokens`` is rejected.

    Returns:
        A ChatRequest. Defaults applied here: ``max_tokens`` 4096 (when neither
        ``max_tokens`` nor ``max_completion_tokens`` is given), ``temperature`` 0.1,
        ``top_p`` 0.95, ``stream`` False, stream flushing every 8 tokens / 50 ms.

    Raises:
        HTTPException(400): for any malformed, mistyped or out-of-range field.
    """
    import math  # only needed for the finiteness checks below

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    raw_messages = payload.get("messages")
    if not isinstance(raw_messages, list) or not raw_messages:
        raise HTTPException(status_code=400, detail="'messages' must be a non-empty list.")

    messages: list[ChatMessage] = []
    for msg in raw_messages:
        if not isinstance(msg, dict):
            raise HTTPException(status_code=400, detail="Each message must be an object.")

        role = msg.get("role")
        if role not in {"system", "user", "assistant", "tool"}:
            raise HTTPException(status_code=400, detail=f"Invalid message role: {role!r}.")

        content = _coerce_content(msg.get("content"))

        tool_call_id = msg.get("tool_call_id") if role == "tool" else None
        tool_calls: list[ToolCall] = []

        if role == "assistant" and msg.get("tool_calls") is not None:
            raw_tool_calls = msg.get("tool_calls")
            if not isinstance(raw_tool_calls, list):
                raise HTTPException(status_code=400, detail="'tool_calls' must be a list.")
            tool_calls.extend(_parse_assistant_tool_call(tc) for tc in raw_tool_calls)

        messages.append(
            ChatMessage(
                role=role,
                content=content,
                tool_calls=tool_calls,
                tool_call_id=tool_call_id,
            )
        )

    # `or []` already maps None (and other falsy values) to an empty list.
    tools = payload.get("tools") or []
    if not isinstance(tools, list):
        raise HTTPException(status_code=400, detail="'tools' must be a list.")

    tool_choice = payload.get("tool_choice")

    max_tokens = payload.get("max_tokens")
    max_completion_tokens = payload.get("max_completion_tokens")

    if max_tokens is None and max_completion_tokens is None:
        max_tokens = 4096
    elif max_tokens is not None and max_completion_tokens is not None:
        try:
            if int(max_tokens) != int(max_completion_tokens):
                raise HTTPException(
                    status_code=400,
                    detail="'max_tokens' and 'max_completion_tokens' must match when both are provided.",
                )
        except HTTPException:
            raise
        except Exception as exc:
            raise HTTPException(
                status_code=400,
                detail="'max_tokens' and 'max_completion_tokens' must be integers.",
            ) from exc
    elif max_completion_tokens is not None:
        max_tokens = max_completion_tokens

    try:
        max_tokens = int(max_tokens)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="'max_tokens' must be an integer.") from exc
    if max_tokens <= 0:
        raise HTTPException(status_code=400, detail="'max_tokens' must be > 0.")

    if http_max_completion_tokens is not None and max_tokens > http_max_completion_tokens:
        raise HTTPException(
            status_code=400,
            detail=f"'max_tokens' too large: {max_tokens} (cap={http_max_completion_tokens}).",
        )

    try:
        # Falsy values (None, 0, "") fall back to the default.
        # NOTE(review): this maps an explicit temperature=0 to 0.1 — presumably a
        # deliberate guard for the sampler; confirm before changing.
        temperature = float(payload.get("temperature", 0.1) or 0.1)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="'temperature' must be a number.") from exc
    # float() accepts "nan"/"inf" strings, and Python's json module accepts NaN/Infinity
    # literals; reject non-finite values so they cannot reach sampling.
    if not math.isfinite(temperature):
        raise HTTPException(status_code=400, detail="'temperature' must be a finite number.")

    try:
        top_p = float(payload.get("top_p", 0.95) or 0.95)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="'top_p' must be a number.") from exc
    if not math.isfinite(top_p):
        raise HTTPException(status_code=400, detail="'top_p' must be a finite number.")

    stop = payload.get("stop") or []
    if isinstance(stop, str):
        stop = [stop]
    if not isinstance(stop, list):
        raise HTTPException(status_code=400, detail="'stop' must be a string or list of strings.")
    # Non-string entries are dropped rather than rejected.
    stop = [s for s in stop if isinstance(s, str)]

    stream = bool(payload.get("stream", False))

    # `or {}` already maps None (and other falsy values) to an empty dict.
    stream_options = payload.get("stream_options") or {}
    if not isinstance(stream_options, dict):
        raise HTTPException(status_code=400, detail="'stream_options' must be an object.")

    try:
        flush_every_n_tokens = int(stream_options.get("flush_every_n_tokens", 8))
        flush_every_ms = int(stream_options.get("flush_every_ms", 50))
    except Exception as exc:
        raise HTTPException(
            status_code=400,
            detail="'stream_options.flush_every_n_tokens' and 'stream_options.flush_every_ms' must be integers.",
        ) from exc

    # Parse chat_template_kwargs (optional, vLLM-compatible)
    chat_template_kwargs = payload.get("chat_template_kwargs")
    if chat_template_kwargs is not None and not isinstance(chat_template_kwargs, dict):
        raise HTTPException(status_code=400, detail="'chat_template_kwargs' must be an object.")

    # Parse reasoning_budget (optional, Superlinear-specific)
    reasoning_budget = payload.get("reasoning_budget")
    if reasoning_budget is not None:
        try:
            reasoning_budget = int(reasoning_budget)
        except (ValueError, TypeError) as exc:
            raise HTTPException(status_code=400, detail="'reasoning_budget' must be an integer.") from exc
        if reasoning_budget <= 0:
            raise HTTPException(status_code=400, detail="'reasoning_budget' must be > 0.")

    # Parse discard_thinking (optional, Superlinear-specific)
    discard_thinking = payload.get("discard_thinking")
    if discard_thinking is not None and not isinstance(discard_thinking, bool):
        raise HTTPException(status_code=400, detail="'discard_thinking' must be a boolean.")

    # Parse stream_thinking (optional, Superlinear-specific)
    stream_thinking = payload.get("stream_thinking")
    if stream_thinking is not None and not isinstance(stream_thinking, bool):
        raise HTTPException(status_code=400, detail="'stream_thinking' must be a boolean.")

    # Parse session_id (optional, for stateful chat)
    session_id = payload.get("session_id")
    if session_id is not None and not isinstance(session_id, str):
        raise HTTPException(status_code=400, detail="'session_id' must be a string.")

    # Parse extra (optional, engine-specific)
    extra = payload.get("extra")
    if extra is None:
        extra = {}
    if not isinstance(extra, dict):
        raise HTTPException(status_code=400, detail="'extra' must be an object.")

    # Convenience alias: allow top-level repetition_detection to be passed through.
    # Copy first so the caller's payload dict is not mutated.
    repetition_detection = payload.get("repetition_detection")
    if repetition_detection is not None and "repetition_detection" not in extra:
        extra = dict(extra)
        extra["repetition_detection"] = repetition_detection

    return ChatRequest(
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
        stream=stream,
        stream_options=StreamOptions(
            flush_every_n_tokens=flush_every_n_tokens,
            flush_every_ms=flush_every_ms,
        ),
        chat_template_kwargs=chat_template_kwargs,
        reasoning_budget=reasoning_budget,
        discard_thinking=discard_thinking,
        stream_thinking=stream_thinking,
        session_id=session_id,
        extra=extra,
    )
|
|
1310
|
+
|
|
1311
|
+
|
|
1312
|
+
def _coerce_content(content: Any) -> str:
|
|
1313
|
+
if content is None:
|
|
1314
|
+
return ""
|
|
1315
|
+
if isinstance(content, str):
|
|
1316
|
+
return content
|
|
1317
|
+
|
|
1318
|
+
# Minimal support for OpenAI "content parts" format (text-only).
|
|
1319
|
+
if isinstance(content, list):
|
|
1320
|
+
parts: list[str] = []
|
|
1321
|
+
for part in content:
|
|
1322
|
+
if not isinstance(part, dict):
|
|
1323
|
+
continue
|
|
1324
|
+
if part.get("type") != "text":
|
|
1325
|
+
continue
|
|
1326
|
+
text = part.get("text")
|
|
1327
|
+
if isinstance(text, str):
|
|
1328
|
+
parts.append(text)
|
|
1329
|
+
return "".join(parts)
|
|
1330
|
+
|
|
1331
|
+
raise HTTPException(status_code=400, detail="Unsupported message content type.")
|
|
1332
|
+
|
|
1333
|
+
|
|
1334
|
+
def _parse_assistant_tool_call(tc: Any) -> ToolCall:
    """Parse one entry of an assistant message's OpenAI ``tool_calls`` list.

    ``function.arguments`` may be a dict (used as-is) or a JSON string. A string that
    decodes to a JSON object becomes the arguments dict. Any other non-blank string,
    whether invalid JSON or valid JSON that is not an object, is preserved verbatim
    under the reserved key ``"__raw__"``. Missing or blank arguments become ``{}``.
    A missing or empty ``id`` is replaced by a generated ``call_<hex>`` id.

    Raises:
        HTTPException(400): ``tc`` or ``tc["function"]`` is not an object, or the
            function name is not a non-empty string.
    """
    if not isinstance(tc, dict):
        raise HTTPException(status_code=400, detail="Each tool_call must be an object.")

    fn = tc.get("function")
    if not isinstance(fn, dict):
        raise HTTPException(status_code=400, detail="tool_call.function must be an object.")

    name = fn.get("name")
    if not isinstance(name, str) or not name:
        raise HTTPException(status_code=400, detail="tool_call.function.name must be a string.")

    arguments = fn.get("arguments")
    args_dict: dict[str, Any] = {}
    if isinstance(arguments, dict):
        args_dict = arguments
    elif isinstance(arguments, str) and arguments.strip():
        try:
            parsed = json.loads(arguments)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError covers pathological nesting.
            parsed = None
        # Keep the raw text (best-effort) rather than silently dropping it when it does
        # not decode to a JSON object, e.g. "[1, 2]", "42", "null" or invalid JSON.
        args_dict = parsed if isinstance(parsed, dict) else {"__raw__": arguments}

    tool_call_id = tc.get("id")
    if not isinstance(tool_call_id, str) or not tool_call_id:
        tool_call_id = f"call_{uuid.uuid4().hex}"

    return ToolCall(id=tool_call_id, name=name, arguments=args_dict)
|