superlinear 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apps/__init__.py +4 -0
- apps/cli/__init__.py +8 -0
- apps/cli/bm25_rag.py +471 -0
- apps/cli/chat_repl.py +1497 -0
- apps/cli/client.py +195 -0
- apps/cli/docs_repl.py +2275 -0
- apps/cli/light_rag.py +729 -0
- apps/cli/local_snapshots.py +139 -0
- apps/cli/locks.py +214 -0
- apps/cli/main.py +457 -0
- apps/cli/output.py +32 -0
- apps/cli/server_cmds.py +516 -0
- apps/cli/session_cmds.py +491 -0
- apps/cli/snapshot_cmds.py +303 -0
- apps/cli/state.py +265 -0
- apps/server/__init__.py +4 -0
- apps/server/app.py +1363 -0
- apps/server/main.py +313 -0
- superlinear/__init__.py +114 -0
- superlinear/_version.py +3 -0
- superlinear/engine/__init__.py +10 -0
- superlinear/engine/adapters/__init__.py +12 -0
- superlinear/engine/adapters/base.py +91 -0
- superlinear/engine/adapters/superlinear.py +1233 -0
- superlinear/engine/chat_engine.py +1173 -0
- superlinear/engine/chat_types.py +130 -0
- superlinear/engine/registry.py +51 -0
- superlinear/engine/repetition.py +203 -0
- superlinear/engine/session_snapshots.py +451 -0
- superlinear/engine/tool_parser.py +83 -0
- superlinear/engine/types.py +42 -0
- superlinear/kernels/__init__.py +2 -0
- superlinear/kernels/common/__init__.py +21 -0
- superlinear/kernels/common/adjustment.py +106 -0
- superlinear/kernels/common/power.py +154 -0
- superlinear/kernels/superlinear/__init__.py +10 -0
- superlinear/kernels/superlinear/attention/__init__.py +78 -0
- superlinear/kernels/superlinear/attention/_prefill.py +940 -0
- superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
- superlinear/kernels/superlinear/attention/api.py +433 -0
- superlinear/kernels/superlinear/search/__init__.py +33 -0
- superlinear/kernels/superlinear/search/_reference.py +204 -0
- superlinear/kernels/superlinear/search/_triton.py +488 -0
- superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
- superlinear/kernels/superlinear/search/api.py +200 -0
- superlinear/kernels/superlinear/span/__init__.py +41 -0
- superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
- superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
- superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
- superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
- superlinear/kernels/superlinear/span/api.py +296 -0
- superlinear/kernels/superlinear/span/masks.py +187 -0
- superlinear/py.typed +0 -0
- superlinear/runtime.py +71 -0
- superlinear-0.1.0.dist-info/METADATA +469 -0
- superlinear-0.1.0.dist-info/RECORD +62 -0
- superlinear-0.1.0.dist-info/WHEEL +5 -0
- superlinear-0.1.0.dist-info/entry_points.txt +2 -0
- superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
- superlinear-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1173 @@
|
|
|
1
|
+
"""Chunked async chat inference engine (single-GPU, single-flight).
|
|
2
|
+
|
|
3
|
+
This module provides the core, reusable engine:
|
|
4
|
+
- request normalization -> tokenizer prompt
|
|
5
|
+
- serialized adapter execution
|
|
6
|
+
- chunk-level async streaming
|
|
7
|
+
- tool call detection + parsing
|
|
8
|
+
|
|
9
|
+
It deliberately contains no HTTP/FastAPI code.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
from collections import deque
|
|
16
|
+
import logging
|
|
17
|
+
import threading
|
|
18
|
+
import time
|
|
19
|
+
import uuid
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from typing import Any, AsyncIterator, Sequence
|
|
22
|
+
|
|
23
|
+
from .chat_types import (
|
|
24
|
+
ChatMessage,
|
|
25
|
+
ChatRequest,
|
|
26
|
+
DeltaEvent,
|
|
27
|
+
ErrorEvent,
|
|
28
|
+
FinalEvent,
|
|
29
|
+
ThinkingDeltaEvent,
|
|
30
|
+
StreamEvent,
|
|
31
|
+
Timing,
|
|
32
|
+
ToolCall,
|
|
33
|
+
ToolCallEvent,
|
|
34
|
+
Usage,
|
|
35
|
+
)
|
|
36
|
+
from .repetition import RepetitionDetectionConfig, detect_repetition_kmp_tail
|
|
37
|
+
from .tool_parser import ToolCallParseError, parse_tool_call_block
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class EngineConfig:
    """Immutable engine-wide defaults and limits.

    Several of these act only as fallbacks: individual requests may override
    thinking behaviour, and per-request repetition-detection settings are merged
    on top of ``repetition_detection``.
    """

    # Default backend identifier.
    default_backend: str = "custom"

    # Default for rendering chat templates with thinking enabled.
    enable_thinking: bool = True

    # Default for dropping thinking output from session state (only takes effect
    # when thinking is enabled for the request).
    discard_thinking: bool = True

    # Hard upper bound on the tokenized prompt length (2**18 == 262_144 tokens).
    max_prompt_tokens: int = 1 << 18

    # Maximum number of tool calls the output parser accepts in one turn.
    max_tool_calls_per_turn: int = 8

    # Baseline repetition-detection settings; request-level overrides merge onto these.
    repetition_detection: RepetitionDetectionConfig = field(default_factory=RepetitionDetectionConfig)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class _ModelOutputParser:
    """Incremental parser for model output.

    Responsibilities:
    - Remove <think>...</think> blocks, or stream their contents as
      ``ThinkingDeltaEvent`` when ``emit_thinking`` is enabled.
    - Detect and buffer a complete <tool_call>...</tool_call>.
    - Apply stop sequences (string-based) to normal text.

    Text arrives in arbitrary chunks via :meth:`feed`. A short tail of the buffer
    is held back between calls so that tags and stop sequences split across chunk
    boundaries are still recognized. :meth:`finish` flushes whatever remains.
    Once ``stopped`` is set, further input is ignored.
    """

    _THINK_OPEN = "<think>"
    _THINK_CLOSE = "</think>"
    _TOOL_OPEN = "<tool_call>"
    _TOOL_CLOSE = "</tool_call>"

    def __init__(
        self,
        *,
        stop_sequences: Sequence[str],
        valid_tool_names: set[str],
        max_tool_calls: int,
        allow_tool_calls: bool,
        start_in_think: bool = False,
        emit_thinking: bool = False,
    ) -> None:
        """Create a parser for a single generation.

        Args:
            stop_sequences: Strings that end generation when they appear in normal
                (non-think, non-tool) text. Empty strings are ignored.
            valid_tool_names: Allowed tool names. An empty set disables name checking.
            max_tool_calls: Maximum number of tool calls accepted in the turn.
            allow_tool_calls: If False, any complete tool call is reported as an error.
            start_in_think: Begin inside a think block (the prompt already opened one).
            emit_thinking: Stream think-block text as ``ThinkingDeltaEvent`` instead
                of dropping it.
        """
        self._stop_sequences = [s for s in stop_sequences if s]
        self._max_stop_len = max((len(s) for s in self._stop_sequences), default=0)

        self._valid_tool_names = valid_tool_names
        self._max_tool_calls = max_tool_calls
        self._allow_tool_calls = allow_tool_calls

        self._buffer = ""  # pending text not yet classified/emitted
        self._tool_buffer = ""  # text of the currently open <tool_call> block
        # When enable_thinking=True, the generation prompt already includes <think>,
        # so we start inside the think block and wait for </think> before emitting.
        self._in_think = start_in_think
        self._in_tool = False

        self._emit_thinking = bool(emit_thinking)
        self._emitted_think_open = False

        self.stopped: bool = False
        self.stop_reason: str | None = None  # "stop" | "tool_calls" | "error"
        self.tool_calls: list[ToolCall] = []

        # Number of trailing characters held back between feeds: one less than the
        # longest tag or stop sequence, so a partially received match is never
        # emitted prematurely.
        self._tail_keep = max(
            len(self._THINK_OPEN) - 1,
            len(self._THINK_CLOSE) - 1,
            len(self._TOOL_OPEN) - 1,
            len(self._TOOL_CLOSE) - 1,
            max(self._max_stop_len - 1, 0),
        )

    def feed(self, text: str) -> list[StreamEvent]:
        """Consume one chunk of model output and return the events it makes available.

        Returns an empty list for empty input or once the parser has stopped.
        """
        if not text or self.stopped:
            return []

        self._buffer += text
        events: list[StreamEvent] = []

        # When streaming thinking and the prompt already opened a think block,
        # emit the opening tag once so clients see where thinking starts.
        if self._emit_thinking and self._in_think and not self._emitted_think_open:
            events.append(ThinkingDeltaEvent(self._THINK_OPEN))
            self._emitted_think_open = True

        while self._buffer and not self.stopped:
            if self._in_think:
                end = self._buffer.find(self._THINK_CLOSE)
                if end == -1:
                    if not self._emit_thinking:
                        # Drop accumulated thinking content; keep small tail in case the close tag is split.
                        if self._tail_keep:
                            self._buffer = self._buffer[-self._tail_keep :]
                        else:
                            self._buffer = ""
                        break

                    # Emit thinking content, keeping a small tail in case the close tag is split.
                    if self._tail_keep and len(self._buffer) > self._tail_keep:
                        emit_text = self._buffer[: -self._tail_keep]
                        self._buffer = self._buffer[-self._tail_keep :]
                    else:
                        emit_text = self._buffer
                        self._buffer = ""

                    if emit_text:
                        events.append(ThinkingDeltaEvent(emit_text))
                    break

                if self._emit_thinking:
                    think_text = self._buffer[:end]
                    if think_text:
                        events.append(ThinkingDeltaEvent(think_text))
                    events.append(ThinkingDeltaEvent(self._THINK_CLOSE))
                    self._buffer = self._buffer[end + len(self._THINK_CLOSE) :]
                    self._in_think = False
                    self._emitted_think_open = False
                    continue

                # Drop everything through the closing tag.
                self._buffer = self._buffer[end + len(self._THINK_CLOSE) :]
                self._in_think = False
                continue

            if self._in_tool:
                self._tool_buffer += self._buffer
                self._buffer = ""

                end = self._tool_buffer.find(self._TOOL_CLOSE)
                if end == -1:
                    break

                block = self._tool_buffer[: end + len(self._TOOL_CLOSE)]
                trailing = self._tool_buffer[end + len(self._TOOL_CLOSE) :]
                self._tool_buffer = ""
                self._in_tool = False

                try:
                    parsed = parse_tool_call_block(block)
                except ToolCallParseError as exc:
                    self.stopped = True
                    self.stop_reason = "error"
                    events.append(ErrorEvent(f"Failed to parse tool call: {exc}"))
                    break

                if not self._allow_tool_calls:
                    self.stopped = True
                    self.stop_reason = "error"
                    events.append(ErrorEvent("Tool calls are disabled for this request (tool_choice='none')."))
                    break

                if self._valid_tool_names and parsed.name not in self._valid_tool_names:
                    self.stopped = True
                    self.stop_reason = "error"
                    events.append(
                        ErrorEvent(
                            f"Model called unknown tool {parsed.name!r}. "
                            "Ensure the tool is present in the request 'tools' list."
                        )
                    )
                    break

                if len(self.tool_calls) >= self._max_tool_calls:
                    self.stopped = True
                    self.stop_reason = "error"
                    events.append(
                        ErrorEvent(
                            f"Too many tool calls in one turn (max={self._max_tool_calls})."
                        )
                    )
                    break

                self.tool_calls.append(
                    ToolCall(id=f"call_{uuid.uuid4().hex}", name=parsed.name, arguments=parsed.arguments)
                )
                events.append(ToolCallEvent(tool_calls=list(self.tool_calls)))

                # Tool call completes the turn.
                self.stopped = True
                self.stop_reason = "tool_calls"

                # Any trailing content after </tool_call> is ignored for v0.
                _ = trailing
                break

            # Normal (non-think, non-tool) mode.
            next_think = self._buffer.find(self._THINK_OPEN)
            next_tool = self._buffer.find(self._TOOL_OPEN)

            next_special = -1
            if next_think != -1 and next_tool != -1:
                next_special = min(next_think, next_tool)
            elif next_think != -1:
                next_special = next_think
            elif next_tool != -1:
                next_special = next_tool

            if next_special != -1:
                # Emit content before the special tag.
                before = self._buffer[:next_special]
                self._buffer = self._buffer[next_special:]
                events.extend(self._emit_text(before))

                if self.stopped:
                    break

                if self._buffer.startswith(self._THINK_OPEN):
                    self._buffer = self._buffer[len(self._THINK_OPEN) :]
                    self._in_think = True
                    if self._emit_thinking:
                        events.append(ThinkingDeltaEvent(self._THINK_OPEN))
                        self._emitted_think_open = True
                    continue
                if self._buffer.startswith(self._TOOL_OPEN):
                    self._tool_buffer = self._TOOL_OPEN
                    self._buffer = self._buffer[len(self._TOOL_OPEN) :]
                    self._in_tool = True
                    continue

            # No special tags found in the current buffer.
            events.extend(self._emit_available_text())
            break

        return [e for e in events if not isinstance(e, DeltaEvent) or e.text]

    def finish(self) -> list[StreamEvent]:
        """Flush any remaining buffered content at end-of-generation.

        - An unterminated <tool_call> is reported as an error.
        - Inside an unterminated think block, the held-back thinking tail is
          emitted as a final ``ThinkingDeltaEvent`` when ``emit_thinking`` is on,
          and dropped otherwise.
        - Remaining normal text is emitted with stop sequences applied.
        """
        if self.stopped:
            return []

        if self._in_tool:
            self.stopped = True
            self.stop_reason = "error"
            return [ErrorEvent("Incomplete <tool_call> block in model output.")]

        if self._in_think:
            leftover = self._buffer
            self._buffer = ""
            # feed() holds back a short tail in case "</think>" is split across
            # chunks. No close tag is coming at end-of-generation, so when thinking
            # is being streamed, release that tail instead of truncating it.
            if self._emit_thinking and leftover:
                return [ThinkingDeltaEvent(leftover)]
            return []

        # Emit remaining buffer (apply stop sequences if needed).
        remaining = self._buffer
        self._buffer = ""
        return self._emit_text(remaining)

    def _emit_text(self, text: str) -> list[StreamEvent]:
        """Wrap ``text`` in a DeltaEvent, truncating at (and stopping on) the earliest stop sequence."""
        if not text:
            return []

        if self._stop_sequences:
            idx = self._find_earliest_stop(text)
            if idx is not None:
                before = text[:idx]
                self.stopped = True
                self.stop_reason = "stop"
                return [DeltaEvent(before)] if before else []

        return [DeltaEvent(text)]

    def _emit_available_text(self) -> list[StreamEvent]:
        """Emit buffered normal text that can no longer begin a tag or stop sequence."""
        if not self._buffer:
            return []

        # Look for stop sequences in the WHOLE buffer before splitting off the
        # held-back tail. Checking only the releasable prefix misses a stop
        # sequence that straddles the prefix/tail boundary: its first characters
        # would be emitted and the remainder could never match later.
        if self._stop_sequences and self._find_earliest_stop(self._buffer) is not None:
            text = self._buffer
            self._buffer = ""
            return self._emit_text(text)

        if self._tail_keep <= 0:
            text = self._buffer
            self._buffer = ""
            return self._emit_text(text)

        if len(self._buffer) <= self._tail_keep:
            return []

        # No complete stop sequence is present, so any future match must start
        # within the last (max_stop_len - 1) <= tail_keep characters.
        safe_end = len(self._buffer) - self._tail_keep
        safe = self._buffer[:safe_end]
        self._buffer = self._buffer[safe_end:]
        return self._emit_text(safe)

    def _find_earliest_stop(self, text: str) -> int | None:
        """Return the index of the earliest stop-sequence occurrence in ``text``, or None."""
        earliest: int | None = None
        for s in self._stop_sequences:
            idx = text.find(s)
            if idx == -1:
                continue
            if earliest is None or idx < earliest:
                earliest = idx
        return earliest
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class ChatEngine:
|
|
323
|
+
"""Core chat inference engine.
|
|
324
|
+
|
|
325
|
+
Thread-safety:
|
|
326
|
+
The underlying adapter is not thread-safe. This engine serializes access
|
|
327
|
+
with a global lock (single-flight).
|
|
328
|
+
"""
|
|
329
|
+
|
|
330
|
+
def __init__(self, adapter: Any, *, config: EngineConfig | None = None) -> None:
    """Wrap ``adapter``, using ``config`` (or a default ``EngineConfig``) for limits."""
    # Mutex used to serialize access to the adapter, which is not thread-safe.
    self._lock = threading.Lock()
    self._config = config or EngineConfig()
    self._adapter = adapter
|
|
334
|
+
|
|
335
|
+
@property
def adapter(self) -> Any:
    """The wrapped adapter, exposed so callers can manage sessions on it directly."""
    wrapped = self._adapter
    return wrapped
|
|
339
|
+
|
|
340
|
+
@property
def model_info(self) -> dict[str, Any]:
    """The adapter's ``model_info`` attribute, or a fresh empty dict if it has none."""
    source = self._adapter
    try:
        return source.model_info
    except AttributeError:
        return {}
|
|
343
|
+
|
|
344
|
+
def shutdown(self) -> None:
    """Call the adapter's ``unload()`` hook if it exposes a callable one; otherwise do nothing."""
    unload_hook = getattr(self._adapter, "unload", None)
    if not callable(unload_hook):
        return
    unload_hook()
|
|
348
|
+
|
|
349
|
+
def _tool_names(self, tools: Sequence[dict[str, Any]]) -> set[str]:
|
|
350
|
+
names: set[str] = set()
|
|
351
|
+
for tool in tools:
|
|
352
|
+
if not isinstance(tool, dict):
|
|
353
|
+
continue
|
|
354
|
+
if tool.get("type") != "function":
|
|
355
|
+
continue
|
|
356
|
+
fn = tool.get("function")
|
|
357
|
+
if isinstance(fn, dict):
|
|
358
|
+
name = fn.get("name")
|
|
359
|
+
if isinstance(name, str) and name:
|
|
360
|
+
names.add(name)
|
|
361
|
+
return names
|
|
362
|
+
|
|
363
|
+
def _inject_tool_choice(self, messages: list[ChatMessage], tool_choice: Any) -> list[ChatMessage]:
|
|
364
|
+
if tool_choice is None:
|
|
365
|
+
return messages
|
|
366
|
+
|
|
367
|
+
instruction: str | None = None
|
|
368
|
+
if tool_choice == "none":
|
|
369
|
+
instruction = "Tool choice: do not call any tools."
|
|
370
|
+
elif tool_choice == "required":
|
|
371
|
+
instruction = "Tool choice: you must call a tool for this response."
|
|
372
|
+
elif isinstance(tool_choice, dict):
|
|
373
|
+
# OpenAI format: {"type":"function","function":{"name":"..."}}
|
|
374
|
+
fn = tool_choice.get("function") if tool_choice.get("type") == "function" else None
|
|
375
|
+
name = fn.get("name") if isinstance(fn, dict) else None
|
|
376
|
+
if isinstance(name, str) and name:
|
|
377
|
+
instruction = f"Tool choice: call only the tool named {name!r}."
|
|
378
|
+
|
|
379
|
+
if instruction is None:
|
|
380
|
+
return messages
|
|
381
|
+
|
|
382
|
+
if messages and messages[0].role == "system":
|
|
383
|
+
updated = ChatMessage(
|
|
384
|
+
role="system",
|
|
385
|
+
content=((messages[0].content or "").rstrip() + "\n\n" + instruction).strip(),
|
|
386
|
+
tool_calls=messages[0].tool_calls,
|
|
387
|
+
tool_call_id=messages[0].tool_call_id,
|
|
388
|
+
)
|
|
389
|
+
return [updated, *messages[1:]]
|
|
390
|
+
|
|
391
|
+
return [ChatMessage(role="system", content=instruction), *messages]
|
|
392
|
+
|
|
393
|
+
def _messages_for_template(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
|
|
394
|
+
out: list[dict[str, Any]] = []
|
|
395
|
+
for m in messages:
|
|
396
|
+
msg: dict[str, Any] = {"role": m.role}
|
|
397
|
+
if m.content is not None:
|
|
398
|
+
msg["content"] = m.content
|
|
399
|
+
else:
|
|
400
|
+
msg["content"] = ""
|
|
401
|
+
|
|
402
|
+
if m.role == "assistant" and m.tool_calls:
|
|
403
|
+
# The model chat template expects tool_call.function.arguments as a mapping (not a JSON string).
|
|
404
|
+
tool_calls: list[dict[str, Any]] = []
|
|
405
|
+
for tc in m.tool_calls:
|
|
406
|
+
tool_calls.append(
|
|
407
|
+
{
|
|
408
|
+
"id": tc.id,
|
|
409
|
+
"type": "function",
|
|
410
|
+
"function": {"name": tc.name, "arguments": tc.arguments},
|
|
411
|
+
}
|
|
412
|
+
)
|
|
413
|
+
msg["tool_calls"] = tool_calls
|
|
414
|
+
|
|
415
|
+
if m.role == "tool" and m.tool_call_id:
|
|
416
|
+
msg["tool_call_id"] = m.tool_call_id
|
|
417
|
+
|
|
418
|
+
out.append(msg)
|
|
419
|
+
return out
|
|
420
|
+
|
|
421
|
+
def _validate_tool_choice(self, request: ChatRequest) -> None:
|
|
422
|
+
tc = request.tool_choice
|
|
423
|
+
if tc is None or tc == "auto":
|
|
424
|
+
return
|
|
425
|
+
|
|
426
|
+
tool_names = self._tool_names(request.tools)
|
|
427
|
+
if tc == "none":
|
|
428
|
+
return
|
|
429
|
+
|
|
430
|
+
if not request.tools:
|
|
431
|
+
raise ValueError("tool_choice was provided but no tools were supplied in the request.")
|
|
432
|
+
|
|
433
|
+
if tc == "required":
|
|
434
|
+
return
|
|
435
|
+
|
|
436
|
+
if isinstance(tc, dict) and tc.get("type") == "function":
|
|
437
|
+
fn = tc.get("function")
|
|
438
|
+
name = fn.get("name") if isinstance(fn, dict) else None
|
|
439
|
+
if isinstance(name, str) and name and tool_names and name not in tool_names:
|
|
440
|
+
raise ValueError(f"tool_choice requested unknown tool {name!r}.")
|
|
441
|
+
|
|
442
|
+
def _single_token_stop_ids(self, stop: Sequence[str]) -> tuple[list[int], list[str]]:
|
|
443
|
+
"""Split stops into (single-token stop ids, string stops).
|
|
444
|
+
|
|
445
|
+
We only early-stop on single-token sequences. Anything else is enforced
|
|
446
|
+
by the output parser (string scan).
|
|
447
|
+
"""
|
|
448
|
+
stop_token_ids: list[int] = []
|
|
449
|
+
stop_strings: list[str] = []
|
|
450
|
+
tokenizer = getattr(self._adapter, "tokenizer", None)
|
|
451
|
+
|
|
452
|
+
# Dynamically get special tokens from the tokenizer to filter from output.
|
|
453
|
+
# This keeps the engine model-agnostic.
|
|
454
|
+
if tokenizer is not None:
|
|
455
|
+
all_special = getattr(tokenizer, "all_special_tokens", None)
|
|
456
|
+
if all_special:
|
|
457
|
+
stop_strings.extend(s for s in all_special if s)
|
|
458
|
+
|
|
459
|
+
encode = getattr(tokenizer, "encode", None)
|
|
460
|
+
if callable(encode):
|
|
461
|
+
for s in stop:
|
|
462
|
+
if not s:
|
|
463
|
+
continue
|
|
464
|
+
try:
|
|
465
|
+
ids = encode(s, add_special_tokens=False)
|
|
466
|
+
except Exception:
|
|
467
|
+
stop_strings.append(s)
|
|
468
|
+
continue
|
|
469
|
+
if isinstance(ids, list) and len(ids) == 1 and isinstance(ids[0], int):
|
|
470
|
+
stop_token_ids.append(ids[0])
|
|
471
|
+
else:
|
|
472
|
+
stop_strings.append(s)
|
|
473
|
+
return stop_token_ids, stop_strings
|
|
474
|
+
|
|
475
|
+
return [], [s for s in stop if s]
|
|
476
|
+
|
|
477
|
+
def _build_input_ids(self, request: ChatRequest) -> Any:
    """Tokenize ``request`` into prompt ids that end with the assistant generation prompt."""
    template_fn, rendered_messages, template_kwargs = self._prepare_chat_template(request)
    input_ids = template_fn(rendered_messages, add_generation_prompt=True, **template_kwargs)
    return input_ids
|
|
480
|
+
|
|
481
|
+
def _effective_enable_thinking(self, request: ChatRequest) -> bool:
|
|
482
|
+
enable_thinking = bool(self._config.enable_thinking)
|
|
483
|
+
if request.reasoning_budget is not None:
|
|
484
|
+
enable_thinking = True
|
|
485
|
+
if request.chat_template_kwargs and "enable_thinking" in request.chat_template_kwargs:
|
|
486
|
+
enable_thinking = bool(request.chat_template_kwargs["enable_thinking"])
|
|
487
|
+
return enable_thinking
|
|
488
|
+
|
|
489
|
+
def _effective_discard_thinking(self, request: ChatRequest) -> bool:
|
|
490
|
+
discard_thinking = bool(self._config.discard_thinking)
|
|
491
|
+
if request.discard_thinking is not None:
|
|
492
|
+
discard_thinking = bool(request.discard_thinking)
|
|
493
|
+
return discard_thinking
|
|
494
|
+
|
|
495
|
+
def _chat_template_kwargs(self, request: ChatRequest, *, enable_thinking: bool) -> dict[str, Any]:
|
|
496
|
+
"""Build kwargs for tokenizer.apply_chat_template with strict parity.
|
|
497
|
+
|
|
498
|
+
Important: This preserves the distinction between "tools omitted" vs `tools=[]`.
|
|
499
|
+
"""
|
|
500
|
+
kwargs: dict[str, Any] = {
|
|
501
|
+
"return_tensors": "pt",
|
|
502
|
+
"enable_thinking": bool(enable_thinking),
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
# Merge request-level chat_template_kwargs (can override enable_thinking)
|
|
506
|
+
if request.chat_template_kwargs:
|
|
507
|
+
kwargs.update(request.chat_template_kwargs)
|
|
508
|
+
# The engine controls add_generation_prompt explicitly.
|
|
509
|
+
kwargs.pop("add_generation_prompt", None)
|
|
510
|
+
|
|
511
|
+
# For tool_choice="none", omit tool definitions from the prompt to reduce the chance of
|
|
512
|
+
# accidental tool calls (in addition to the injected instruction).
|
|
513
|
+
tools = request.tools
|
|
514
|
+
if request.tool_choice == "none":
|
|
515
|
+
tools = []
|
|
516
|
+
|
|
517
|
+
# Preserve "tools omitted" when tools==[].
|
|
518
|
+
if tools:
|
|
519
|
+
kwargs["tools"] = tools
|
|
520
|
+
|
|
521
|
+
return kwargs
|
|
522
|
+
|
|
523
|
+
def _prepare_chat_template(self, request: ChatRequest) -> tuple[Any, list[dict[str, Any]], dict[str, Any]]:
    """Validate ``request`` and gather the pieces needed to call ``apply_chat_template``.

    Returns:
        ``(apply_chat_template, template_messages, kwargs)``. These are the
        tokenizer's ``apply_chat_template`` callable, the request messages as
        plain dicts (with any tool-choice instruction injected), and the
        template kwargs.

    Raises:
        RuntimeError: if the adapter has no tokenizer, or the tokenizer has no
            callable ``apply_chat_template``.
        ValueError: propagated from tool-choice validation.
    """
    tokenizer = getattr(self._adapter, "tokenizer", None)
    if tokenizer is None:
        raise RuntimeError("Adapter has no tokenizer loaded.")

    self._validate_tool_choice(request)

    render = getattr(tokenizer, "apply_chat_template", None)
    if not callable(render):
        raise RuntimeError("Tokenizer does not support apply_chat_template().")

    with_choice = self._inject_tool_choice(list(request.messages), request.tool_choice)
    rendered_messages = self._messages_for_template(with_choice)
    thinking = self._effective_enable_thinking(request)
    template_kwargs = self._chat_template_kwargs(request, enable_thinking=thinking)
    return render, rendered_messages, template_kwargs
|
|
538
|
+
|
|
539
|
+
def _compute_generation_boundary(
|
|
540
|
+
self,
|
|
541
|
+
*,
|
|
542
|
+
apply_chat_template: Any,
|
|
543
|
+
template_messages: list[dict[str, Any]],
|
|
544
|
+
kwargs: dict[str, Any],
|
|
545
|
+
) -> tuple[int, Any, Any, Any]:
|
|
546
|
+
"""Compute the (end-of-user) boundary and generation prompt suffix.
|
|
547
|
+
|
|
548
|
+
Returns:
|
|
549
|
+
(boundary_pos, ids_no_gen, ids_with_gen, gen_prompt_ids)
|
|
550
|
+
|
|
551
|
+
Raises:
|
|
552
|
+
ValueError if the strict-prefix invariant is violated.
|
|
553
|
+
"""
|
|
554
|
+
ids_no_gen = apply_chat_template(template_messages, add_generation_prompt=False, **kwargs)
|
|
555
|
+
ids_with_gen = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
|
|
556
|
+
|
|
557
|
+
try:
|
|
558
|
+
boundary_pos = int(getattr(ids_no_gen, "shape")[1])
|
|
559
|
+
with_len = int(getattr(ids_with_gen, "shape")[1])
|
|
560
|
+
except Exception as exc: # pragma: no cover
|
|
561
|
+
raise ValueError("apply_chat_template() must return a tensor with shape (1, seq_len).") from exc
|
|
562
|
+
|
|
563
|
+
if boundary_pos < 0 or with_len < 0 or with_len < boundary_pos:
|
|
564
|
+
raise ValueError("Invalid chat template boundary lengths.")
|
|
565
|
+
if with_len == boundary_pos:
|
|
566
|
+
raise ValueError("Chat template boundary is not a strict prefix (no generation prompt suffix).")
|
|
567
|
+
|
|
568
|
+
prefix_ok = False
|
|
569
|
+
try:
|
|
570
|
+
import torch
|
|
571
|
+
|
|
572
|
+
if (
|
|
573
|
+
isinstance(ids_no_gen, torch.Tensor)
|
|
574
|
+
and isinstance(ids_with_gen, torch.Tensor)
|
|
575
|
+
and ids_no_gen.ndim == 2
|
|
576
|
+
and ids_with_gen.ndim == 2
|
|
577
|
+
and ids_no_gen.shape[0] == 1
|
|
578
|
+
and ids_with_gen.shape[0] == 1
|
|
579
|
+
):
|
|
580
|
+
prefix_ok = torch.equal(ids_with_gen[:, :boundary_pos], ids_no_gen)
|
|
581
|
+
except Exception:
|
|
582
|
+
prefix_ok = False
|
|
583
|
+
|
|
584
|
+
if not prefix_ok:
|
|
585
|
+
# Best-effort diagnostic suffix to help debug template drift.
|
|
586
|
+
suffix = None
|
|
587
|
+
try:
|
|
588
|
+
tokenizer = getattr(self._adapter, "tokenizer", None)
|
|
589
|
+
decode = getattr(tokenizer, "decode", None)
|
|
590
|
+
if callable(decode):
|
|
591
|
+
suffix_ids = ids_with_gen[0, max(boundary_pos - 16, 0) : boundary_pos + 16].tolist()
|
|
592
|
+
suffix = decode(suffix_ids, skip_special_tokens=False)
|
|
593
|
+
except Exception:
|
|
594
|
+
suffix = None
|
|
595
|
+
|
|
596
|
+
if suffix:
|
|
597
|
+
logger.warning("apply_chat_template prefix invariant failed near boundary: %r", suffix)
|
|
598
|
+
raise ValueError("Chat template strict-prefix boundary invariant failed.")
|
|
599
|
+
|
|
600
|
+
gen_prompt_ids = ids_with_gen[:, boundary_pos:]
|
|
601
|
+
return boundary_pos, ids_no_gen, ids_with_gen, gen_prompt_ids
|
|
602
|
+
|
|
603
|
+
async def astream_chat(self, request: ChatRequest) -> AsyncIterator[StreamEvent]:
|
|
604
|
+
"""Async iterator streaming internal events."""
|
|
605
|
+
loop = asyncio.get_running_loop()
|
|
606
|
+
queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
|
|
607
|
+
cancel = threading.Event()
|
|
608
|
+
|
|
609
|
+
repetition_cfg = self._config.repetition_detection.merged(
|
|
610
|
+
request.extra.get("repetition_detection") if request.extra else None
|
|
611
|
+
)
|
|
612
|
+
|
|
613
|
+
enable_thinking = self._effective_enable_thinking(request)
|
|
614
|
+
discard_thinking = self._effective_discard_thinking(request)
|
|
615
|
+
emit_thinking = bool(request.stream_thinking) and enable_thinking
|
|
616
|
+
|
|
617
|
+
apply_chat_template, template_messages, kwargs = self._prepare_chat_template(request)
|
|
618
|
+
|
|
619
|
+
boundary_pos: int | None = None
|
|
620
|
+
ids_no_gen = None
|
|
621
|
+
gen_prompt_ids = None
|
|
622
|
+
|
|
623
|
+
if request.session_id and enable_thinking and discard_thinking:
|
|
624
|
+
try:
|
|
625
|
+
boundary_pos, ids_no_gen, input_ids, gen_prompt_ids = self._compute_generation_boundary(
|
|
626
|
+
apply_chat_template=apply_chat_template,
|
|
627
|
+
template_messages=template_messages,
|
|
628
|
+
kwargs=kwargs,
|
|
629
|
+
)
|
|
630
|
+
except ValueError as exc:
|
|
631
|
+
# Safety-over-performance fallback: keep discard-thinking enabled but use
|
|
632
|
+
# checkpoint-before-user (Option A) to avoid relying on a potentially-wrong boundary.
|
|
633
|
+
logger.warning("Falling back to checkpoint-before-user: %s", exc)
|
|
634
|
+
input_ids = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
|
|
635
|
+
else:
|
|
636
|
+
input_ids = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
|
|
637
|
+
|
|
638
|
+
try:
|
|
639
|
+
prompt_tokens = int(getattr(input_ids, "shape")[1])
|
|
640
|
+
except Exception:
|
|
641
|
+
prompt_tokens = 0
|
|
642
|
+
|
|
643
|
+
if prompt_tokens and prompt_tokens > self._config.max_prompt_tokens:
|
|
644
|
+
yield ErrorEvent(
|
|
645
|
+
f"Prompt too long: {prompt_tokens} tokens (max={self._config.max_prompt_tokens})."
|
|
646
|
+
)
|
|
647
|
+
return
|
|
648
|
+
|
|
649
|
+
stop_token_ids, stop_strings = self._single_token_stop_ids(request.stop)
|
|
650
|
+
tool_names = self._tool_names(request.tools)
|
|
651
|
+
parser = _ModelOutputParser(
|
|
652
|
+
stop_sequences=stop_strings,
|
|
653
|
+
valid_tool_names=tool_names,
|
|
654
|
+
max_tool_calls=self._config.max_tool_calls_per_turn,
|
|
655
|
+
allow_tool_calls=request.tool_choice != "none",
|
|
656
|
+
start_in_think=enable_thinking,
|
|
657
|
+
emit_thinking=emit_thinking,
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
def worker() -> None:
|
|
661
|
+
started = time.monotonic()
|
|
662
|
+
first_token_at: float | None = None
|
|
663
|
+
completion_tokens = 0
|
|
664
|
+
finish_reason: str = "stop"
|
|
665
|
+
|
|
666
|
+
normalized_raw_content_for_history: str | None = None
|
|
667
|
+
|
|
668
|
+
# Per-request stream policy.
|
|
669
|
+
flush_n = max(int(request.stream_options.flush_every_n_tokens), 1)
|
|
670
|
+
flush_ms = max(int(request.stream_options.flush_every_ms), 1)
|
|
671
|
+
flush_s = flush_ms / 1000.0
|
|
672
|
+
|
|
673
|
+
token_buffer: list[int] = []
|
|
674
|
+
last_flush = time.monotonic()
|
|
675
|
+
|
|
676
|
+
assistant_text_parts: list[str] = []
|
|
677
|
+
assistant_raw_text_parts: list[str] = [] # Raw text including <think> blocks
|
|
678
|
+
assistant_tool_calls: list[ToolCall] = []
|
|
679
|
+
|
|
680
|
+
repetition_tail: deque[int] | None = None
|
|
681
|
+
if repetition_cfg.enabled:
|
|
682
|
+
repetition_tail = deque(maxlen=repetition_cfg.tail_len)
|
|
683
|
+
|
|
684
|
+
def _emit_event(event: StreamEvent) -> None:
|
|
685
|
+
nonlocal assistant_tool_calls
|
|
686
|
+
if isinstance(event, DeltaEvent):
|
|
687
|
+
if event.text:
|
|
688
|
+
assistant_text_parts.append(event.text)
|
|
689
|
+
elif isinstance(event, ThinkingDeltaEvent):
|
|
690
|
+
# Thinking deltas are streamed but never counted as assistant output content.
|
|
691
|
+
pass
|
|
692
|
+
elif isinstance(event, ToolCallEvent):
|
|
693
|
+
assistant_tool_calls = list(event.tool_calls)
|
|
694
|
+
loop.call_soon_threadsafe(queue.put_nowait, event)
|
|
695
|
+
|
|
696
|
+
try:
|
|
697
|
+
with self._lock:
|
|
698
|
+
|
|
699
|
+
# Session-based generation path
|
|
700
|
+
if request.session_id:
|
|
701
|
+
append_from = request.session_append_from_pos
|
|
702
|
+
if append_from is None:
|
|
703
|
+
append_from = 0
|
|
704
|
+
try:
|
|
705
|
+
append_from = int(append_from)
|
|
706
|
+
except Exception:
|
|
707
|
+
append_from = 0
|
|
708
|
+
if append_from < 0:
|
|
709
|
+
append_from = 0
|
|
710
|
+
|
|
711
|
+
discard_session_thinking = enable_thinking and discard_thinking
|
|
712
|
+
|
|
713
|
+
checkpoint = None
|
|
714
|
+
commit_from_pos = None
|
|
715
|
+
fallback_checkpoint = None
|
|
716
|
+
fallback_commit_from_pos = None
|
|
717
|
+
|
|
718
|
+
if discard_session_thinking:
|
|
719
|
+
if not hasattr(self._adapter, "checkpoint_session") or not hasattr(
|
|
720
|
+
self._adapter, "restore_session_checkpoint"
|
|
721
|
+
):
|
|
722
|
+
raise RuntimeError(
|
|
723
|
+
"Adapter does not support checkpoint/restore required for discard_thinking."
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
# Option B (fast path): checkpoint after user boundary.
|
|
727
|
+
if boundary_pos is not None and ids_no_gen is not None and gen_prompt_ids is not None:
|
|
728
|
+
# Keep a fallback checkpoint in case boundary invariants drift mid-flight.
|
|
729
|
+
fallback_checkpoint = self._adapter.checkpoint_session(request.session_id)
|
|
730
|
+
fallback_commit_from_pos = int(append_from)
|
|
731
|
+
|
|
732
|
+
if append_from > int(boundary_pos):
|
|
733
|
+
raise ValueError(
|
|
734
|
+
f"session_append_from_pos={append_from} exceeds boundary_pos={boundary_pos}."
|
|
735
|
+
)
|
|
736
|
+
|
|
737
|
+
try:
|
|
738
|
+
delta_user_ids = ids_no_gen[:, append_from:boundary_pos]
|
|
739
|
+
except Exception:
|
|
740
|
+
delta_user_ids = ids_no_gen
|
|
741
|
+
|
|
742
|
+
if getattr(delta_user_ids, "numel", lambda: 0)() > 0:
|
|
743
|
+
self._adapter.append_to_session(
|
|
744
|
+
cache_id=request.session_id,
|
|
745
|
+
input_ids=delta_user_ids,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
sess_info = self._adapter.get_session_info(request.session_id)
|
|
749
|
+
cur = int(sess_info.get("current_pos", -1))
|
|
750
|
+
if cur != int(boundary_pos):
|
|
751
|
+
raise ValueError(
|
|
752
|
+
f"Session cursor mismatch after user prefill: current_pos={cur} "
|
|
753
|
+
f"!= boundary_pos={boundary_pos}."
|
|
754
|
+
)
|
|
755
|
+
|
|
756
|
+
checkpoint = self._adapter.checkpoint_session(request.session_id)
|
|
757
|
+
commit_from_pos = int(boundary_pos)
|
|
758
|
+
|
|
759
|
+
if getattr(gen_prompt_ids, "numel", lambda: 0)() > 0:
|
|
760
|
+
self._adapter.append_to_session(
|
|
761
|
+
cache_id=request.session_id,
|
|
762
|
+
input_ids=gen_prompt_ids,
|
|
763
|
+
)
|
|
764
|
+
else:
|
|
765
|
+
# Option A (fallback): checkpoint before appending user.
|
|
766
|
+
checkpoint = self._adapter.checkpoint_session(request.session_id)
|
|
767
|
+
commit_from_pos = int(append_from)
|
|
768
|
+
|
|
769
|
+
# Append full delta prompt tokens (includes user + gen prompt).
|
|
770
|
+
try:
|
|
771
|
+
delta_input_ids = input_ids[:, append_from:]
|
|
772
|
+
except Exception:
|
|
773
|
+
delta_input_ids = input_ids
|
|
774
|
+
|
|
775
|
+
if getattr(delta_input_ids, "numel", lambda: 0)() > 0:
|
|
776
|
+
self._adapter.append_to_session(
|
|
777
|
+
cache_id=request.session_id,
|
|
778
|
+
input_ids=delta_input_ids,
|
|
779
|
+
)
|
|
780
|
+
else:
|
|
781
|
+
# Default: append full delta prompt tokens (server may provide full-history messages).
|
|
782
|
+
try:
|
|
783
|
+
delta_input_ids = input_ids[:, append_from:]
|
|
784
|
+
except Exception:
|
|
785
|
+
delta_input_ids = input_ids
|
|
786
|
+
|
|
787
|
+
if getattr(delta_input_ids, "numel", lambda: 0)() > 0:
|
|
788
|
+
self._adapter.append_to_session(
|
|
789
|
+
cache_id=request.session_id,
|
|
790
|
+
input_ids=delta_input_ids,
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
# Generate from the session
|
|
794
|
+
token_iter = self._adapter.stream_generate_session(
|
|
795
|
+
cache_id=request.session_id,
|
|
796
|
+
max_new_tokens=int(request.max_tokens),
|
|
797
|
+
temperature=float(request.temperature or 0.0),
|
|
798
|
+
stop_token_ids=stop_token_ids or None,
|
|
799
|
+
)
|
|
800
|
+
else:
|
|
801
|
+
# Stateless generation path
|
|
802
|
+
token_iter = self._adapter.stream_generate(
|
|
803
|
+
input_ids,
|
|
804
|
+
max_new_tokens=int(request.max_tokens),
|
|
805
|
+
temperature=float(request.temperature or 0.0),
|
|
806
|
+
stop_token_ids=stop_token_ids or None,
|
|
807
|
+
backend=self._config.default_backend,
|
|
808
|
+
reasoning_budget=request.reasoning_budget,
|
|
809
|
+
enable_thinking=enable_thinking,
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
# Track if we've bailed out of thinking due to repetition
|
|
813
|
+
thinking_bailout_done = False
|
|
814
|
+
|
|
815
|
+
try:
|
|
816
|
+
for token in token_iter:
|
|
817
|
+
if cancel.is_set():
|
|
818
|
+
finish_reason = "cancelled"
|
|
819
|
+
break
|
|
820
|
+
|
|
821
|
+
completion_tokens += 1
|
|
822
|
+
if first_token_at is None:
|
|
823
|
+
first_token_at = time.monotonic()
|
|
824
|
+
|
|
825
|
+
try:
|
|
826
|
+
token_id = int(token.item())
|
|
827
|
+
except Exception:
|
|
828
|
+
# Fall back to best-effort stringification.
|
|
829
|
+
token_id = int(token) # type: ignore[arg-type]
|
|
830
|
+
|
|
831
|
+
token_buffer.append(token_id)
|
|
832
|
+
|
|
833
|
+
if repetition_tail is not None:
|
|
834
|
+
repetition_tail.append(token_id)
|
|
835
|
+
if (
|
|
836
|
+
completion_tokens >= repetition_cfg.min_generated_tokens
|
|
837
|
+
and (completion_tokens % repetition_cfg.check_every) == 0
|
|
838
|
+
):
|
|
839
|
+
hit = detect_repetition_kmp_tail(
|
|
840
|
+
list(repetition_tail),
|
|
841
|
+
tail_len=repetition_cfg.tail_len,
|
|
842
|
+
min_generated_tokens=0,
|
|
843
|
+
min_repeats=repetition_cfg.min_repeats,
|
|
844
|
+
max_period=repetition_cfg.max_period,
|
|
845
|
+
min_unique_tokens=repetition_cfg.min_unique_tokens,
|
|
846
|
+
)
|
|
847
|
+
if hit is not None:
|
|
848
|
+
# Flush buffer before checking parser state
|
|
849
|
+
if token_buffer:
|
|
850
|
+
_flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
|
|
851
|
+
token_buffer.clear()
|
|
852
|
+
|
|
853
|
+
# If we're in thinking mode and haven't bailed out yet,
|
|
854
|
+
# inject </think> and continue instead of stopping
|
|
855
|
+
if parser._in_think and not thinking_bailout_done:
|
|
856
|
+
logger.debug(
|
|
857
|
+
"Repetition in thinking - injecting </think>: period=%d repeats=%d completion_tokens=%d",
|
|
858
|
+
hit.period,
|
|
859
|
+
hit.repeats,
|
|
860
|
+
completion_tokens,
|
|
861
|
+
)
|
|
862
|
+
# Feed </think> to parser to exit thinking mode
|
|
863
|
+
for event in parser.feed("</think>"):
|
|
864
|
+
_emit_event(event)
|
|
865
|
+
# Clear repetition tail to give fresh start
|
|
866
|
+
repetition_tail.clear()
|
|
867
|
+
thinking_bailout_done = True
|
|
868
|
+
# Continue generating (don't break)
|
|
869
|
+
continue
|
|
870
|
+
|
|
871
|
+
logger.debug(
|
|
872
|
+
"Repetition early-stop: period=%d repeats=%d checked_tail_len=%d completion_tokens=%d",
|
|
873
|
+
hit.period,
|
|
874
|
+
hit.repeats,
|
|
875
|
+
hit.checked_tail_len,
|
|
876
|
+
completion_tokens,
|
|
877
|
+
)
|
|
878
|
+
finish_reason = "repetition"
|
|
879
|
+
break
|
|
880
|
+
|
|
881
|
+
now = time.monotonic()
|
|
882
|
+
if len(token_buffer) < flush_n and (now - last_flush) < flush_s:
|
|
883
|
+
continue
|
|
884
|
+
|
|
885
|
+
last_flush = now
|
|
886
|
+
_flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
|
|
887
|
+
token_buffer.clear()
|
|
888
|
+
|
|
889
|
+
if parser.stopped and parser.stop_reason == "tool_calls":
|
|
890
|
+
finish_reason = "tool_calls"
|
|
891
|
+
break
|
|
892
|
+
if parser.stopped and parser.stop_reason == "stop":
|
|
893
|
+
finish_reason = "stop"
|
|
894
|
+
break
|
|
895
|
+
if parser.stopped and parser.stop_reason == "error":
|
|
896
|
+
finish_reason = "error"
|
|
897
|
+
break
|
|
898
|
+
finally:
|
|
899
|
+
# Explicitly close the generator to ensure its finally block runs.
|
|
900
|
+
# This is critical for session-based generation where the generator's
|
|
901
|
+
# finally block persists the KV cache state.
|
|
902
|
+
if hasattr(token_iter, 'close'):
|
|
903
|
+
token_iter.close()
|
|
904
|
+
|
|
905
|
+
# Final flush.
|
|
906
|
+
if token_buffer and not parser.stopped:
|
|
907
|
+
_flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
|
|
908
|
+
token_buffer.clear()
|
|
909
|
+
|
|
910
|
+
# Flush parser tail.
|
|
911
|
+
if not parser.stopped:
|
|
912
|
+
for event in parser.finish():
|
|
913
|
+
_emit_event(event)
|
|
914
|
+
if isinstance(event, ErrorEvent):
|
|
915
|
+
finish_reason = "error"
|
|
916
|
+
|
|
917
|
+
# Discard-thinking commit: restore to a checkpoint and append tokens for persisted history.
|
|
918
|
+
if request.session_id and enable_thinking and discard_thinking:
|
|
919
|
+
if checkpoint is None or commit_from_pos is None:
|
|
920
|
+
raise RuntimeError("Discard-thinking flow missing checkpoint state.")
|
|
921
|
+
|
|
922
|
+
self._adapter.restore_session_checkpoint(
|
|
923
|
+
cache_id=request.session_id,
|
|
924
|
+
checkpoint=checkpoint,
|
|
925
|
+
)
|
|
926
|
+
|
|
927
|
+
# Persist tool calls as a structured assistant message.
|
|
928
|
+
assistant_msg = None
|
|
929
|
+
if assistant_tool_calls:
|
|
930
|
+
assistant_msg = ChatMessage(
|
|
931
|
+
role="assistant",
|
|
932
|
+
content=None,
|
|
933
|
+
tool_calls=list(assistant_tool_calls),
|
|
934
|
+
)
|
|
935
|
+
else:
|
|
936
|
+
assistant_msg = ChatMessage(
|
|
937
|
+
role="assistant",
|
|
938
|
+
content="".join(assistant_text_parts),
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
persisted_messages = self._inject_tool_choice(
|
|
942
|
+
[*list(request.messages), assistant_msg],
|
|
943
|
+
request.tool_choice,
|
|
944
|
+
)
|
|
945
|
+
template_persisted = self._messages_for_template(persisted_messages)
|
|
946
|
+
|
|
947
|
+
ids_persisted = apply_chat_template(
|
|
948
|
+
template_persisted,
|
|
949
|
+
add_generation_prompt=False,
|
|
950
|
+
**kwargs,
|
|
951
|
+
)
|
|
952
|
+
|
|
953
|
+
# Sanity check: persisted prompt should start with the end-of-user prefix.
|
|
954
|
+
if boundary_pos is not None and ids_no_gen is not None:
|
|
955
|
+
import torch
|
|
956
|
+
|
|
957
|
+
if (
|
|
958
|
+
isinstance(ids_persisted, torch.Tensor)
|
|
959
|
+
and isinstance(ids_no_gen, torch.Tensor)
|
|
960
|
+
and ids_persisted.ndim == 2
|
|
961
|
+
and ids_no_gen.ndim == 2
|
|
962
|
+
and ids_persisted.shape[0] == 1
|
|
963
|
+
and ids_no_gen.shape[0] == 1
|
|
964
|
+
and ids_persisted.shape[1] >= int(boundary_pos)
|
|
965
|
+
and not torch.equal(ids_persisted[:, : int(boundary_pos)], ids_no_gen)
|
|
966
|
+
):
|
|
967
|
+
if fallback_checkpoint is not None and fallback_commit_from_pos is not None:
|
|
968
|
+
logger.warning(
|
|
969
|
+
"Persisted prompt prefix mismatch; falling back to checkpoint-before-user."
|
|
970
|
+
)
|
|
971
|
+
self._adapter.restore_session_checkpoint(
|
|
972
|
+
cache_id=request.session_id,
|
|
973
|
+
checkpoint=fallback_checkpoint,
|
|
974
|
+
)
|
|
975
|
+
commit_from_pos = int(fallback_commit_from_pos)
|
|
976
|
+
else:
|
|
977
|
+
raise ValueError(
|
|
978
|
+
"Persisted prompt no longer matches the end-of-user prefix."
|
|
979
|
+
)
|
|
980
|
+
|
|
981
|
+
try:
|
|
982
|
+
delta_commit_ids = ids_persisted[:, int(commit_from_pos) :]
|
|
983
|
+
except Exception:
|
|
984
|
+
delta_commit_ids = ids_persisted
|
|
985
|
+
|
|
986
|
+
if getattr(delta_commit_ids, "numel", lambda: 0)() > 0:
|
|
987
|
+
self._adapter.append_to_session(
|
|
988
|
+
cache_id=request.session_id,
|
|
989
|
+
input_ids=delta_commit_ids,
|
|
990
|
+
)
|
|
991
|
+
|
|
992
|
+
# discard_thinking=False session path: the model often stops *before* emitting <|im_end|>,
|
|
993
|
+
# but apply_chat_template(add_generation_prompt=False) will include it for assistant messages.
|
|
994
|
+
# To keep KV/history in sync, append any missing tail tokens after generation.
|
|
995
|
+
if request.session_id and enable_thinking and not discard_thinking and not assistant_tool_calls:
|
|
996
|
+
# Build normalized raw content that matches the template's generation prompt prefix.
|
|
997
|
+
normalized_raw_content_for_history = (
|
|
998
|
+
"".join(assistant_raw_text_parts) if assistant_raw_text_parts else None
|
|
999
|
+
)
|
|
1000
|
+
if normalized_raw_content_for_history:
|
|
1001
|
+
if (
|
|
1002
|
+
_ModelOutputParser._THINK_CLOSE in normalized_raw_content_for_history
|
|
1003
|
+
and _ModelOutputParser._THINK_OPEN not in normalized_raw_content_for_history[:64]
|
|
1004
|
+
):
|
|
1005
|
+
normalized_raw_content_for_history = (
|
|
1006
|
+
_ModelOutputParser._THINK_OPEN + "\n" + normalized_raw_content_for_history
|
|
1007
|
+
)
|
|
1008
|
+
|
|
1009
|
+
assistant_msg = ChatMessage(
|
|
1010
|
+
role="assistant",
|
|
1011
|
+
content=normalized_raw_content_for_history,
|
|
1012
|
+
)
|
|
1013
|
+
persisted_messages = self._inject_tool_choice(
|
|
1014
|
+
[*list(request.messages), assistant_msg],
|
|
1015
|
+
request.tool_choice,
|
|
1016
|
+
)
|
|
1017
|
+
template_persisted = self._messages_for_template(persisted_messages)
|
|
1018
|
+
ids_persisted = apply_chat_template(
|
|
1019
|
+
template_persisted,
|
|
1020
|
+
add_generation_prompt=False,
|
|
1021
|
+
**kwargs,
|
|
1022
|
+
)
|
|
1023
|
+
|
|
1024
|
+
sess_info = self._adapter.get_session_info(request.session_id)
|
|
1025
|
+
cur = int(sess_info.get("current_pos", 0))
|
|
1026
|
+
try:
|
|
1027
|
+
expected_total = int(getattr(ids_persisted, "shape")[1])
|
|
1028
|
+
except Exception:
|
|
1029
|
+
expected_total = cur
|
|
1030
|
+
|
|
1031
|
+
if expected_total > cur:
|
|
1032
|
+
try:
|
|
1033
|
+
delta_tail_ids = ids_persisted[:, cur:expected_total]
|
|
1034
|
+
except Exception:
|
|
1035
|
+
delta_tail_ids = ids_persisted
|
|
1036
|
+
if getattr(delta_tail_ids, "numel", lambda: 0)() > 0:
|
|
1037
|
+
self._adapter.append_to_session(
|
|
1038
|
+
cache_id=request.session_id,
|
|
1039
|
+
input_ids=delta_tail_ids,
|
|
1040
|
+
)
|
|
1041
|
+
|
|
1042
|
+
# If we exhausted the token budget without an explicit stop/tool call/cancel,
|
|
1043
|
+
# report a length stop (best-effort; adapter doesn't expose stop reason).
|
|
1044
|
+
if (
|
|
1045
|
+
finish_reason == "stop"
|
|
1046
|
+
and not parser.stopped
|
|
1047
|
+
and completion_tokens >= int(request.max_tokens)
|
|
1048
|
+
):
|
|
1049
|
+
finish_reason = "length"
|
|
1050
|
+
|
|
1051
|
+
except Exception as exc:
|
|
1052
|
+
finish_reason = "error"
|
|
1053
|
+
loop.call_soon_threadsafe(queue.put_nowait, ErrorEvent(f"Generation failed: {exc}"))
|
|
1054
|
+
finally:
|
|
1055
|
+
ended = time.monotonic()
|
|
1056
|
+
prefill_s = None if first_token_at is None else max(first_token_at - started, 0.0)
|
|
1057
|
+
decode_s = None
|
|
1058
|
+
if first_token_at is not None:
|
|
1059
|
+
decode_s = max(ended - first_token_at, 0.0)
|
|
1060
|
+
|
|
1061
|
+
tok_per_s = None
|
|
1062
|
+
if decode_s and decode_s > 0 and completion_tokens > 0:
|
|
1063
|
+
tok_per_s = completion_tokens / decode_s
|
|
1064
|
+
|
|
1065
|
+
# Include raw content (with thinking) when discard_thinking=False for sessions
|
|
1066
|
+
raw_content = None
|
|
1067
|
+
if request.session_id and enable_thinking and not discard_thinking:
|
|
1068
|
+
raw_content = normalized_raw_content_for_history
|
|
1069
|
+
if raw_content is None:
|
|
1070
|
+
raw_content = "".join(assistant_raw_text_parts) if assistant_raw_text_parts else None
|
|
1071
|
+
# When enable_thinking=True, the generation prompt already includes <think>.
|
|
1072
|
+
# The model output stream may therefore omit the opening tag and begin directly
|
|
1073
|
+
# with the thinking text, later emitting only </think>. For history/KV sync,
|
|
1074
|
+
# normalize by re-introducing the opening tag when needed.
|
|
1075
|
+
if raw_content:
|
|
1076
|
+
if (
|
|
1077
|
+
_ModelOutputParser._THINK_CLOSE in raw_content
|
|
1078
|
+
and _ModelOutputParser._THINK_OPEN not in raw_content[:64]
|
|
1079
|
+
):
|
|
1080
|
+
raw_content = _ModelOutputParser._THINK_OPEN + "\n" + raw_content
|
|
1081
|
+
|
|
1082
|
+
final = FinalEvent(
|
|
1083
|
+
finish_reason=finish_reason
|
|
1084
|
+
if finish_reason in {"stop", "length", "tool_calls", "cancelled", "error", "repetition"}
|
|
1085
|
+
else "stop",
|
|
1086
|
+
usage=Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
|
|
1087
|
+
timing=Timing(
|
|
1088
|
+
prefill_s=prefill_s,
|
|
1089
|
+
decode_s=decode_s,
|
|
1090
|
+
total_s=max(ended - started, 0.0),
|
|
1091
|
+
tok_per_s=tok_per_s,
|
|
1092
|
+
),
|
|
1093
|
+
raw_content=raw_content,
|
|
1094
|
+
)
|
|
1095
|
+
loop.call_soon_threadsafe(queue.put_nowait, final)
|
|
1096
|
+
loop.call_soon_threadsafe(queue.put_nowait, None)
|
|
1097
|
+
|
|
1098
|
+
def _flush_token_buffer(
    token_ids: list[int],
    parser_: _ModelOutputParser,
    emit: Any,
    raw_text_parts: list[str] | None = None,
) -> None:
    """Decode a batch of generated token ids and route the text through the parser.

    Args:
        token_ids: Token ids accumulated since the previous flush. This helper
            does not modify the list; the generation loop clears its buffer
            after calling it.
        parser_: Output parser. Every event yielded by ``parser_.feed`` for the
            decoded text is forwarded, in order, to ``emit``.
        emit: Callable invoked once per parser event.
        raw_text_parts: Optional accumulator. When provided, the decoded text
            (decoded with ``skip_special_tokens=False``) is appended to it
            before parsing.

    Raises:
        RuntimeError: If the adapter has no tokenizer exposing a callable
            ``decode``.
    """
    tok = getattr(self._adapter, "tokenizer", None)
    decoder = getattr(tok, "decode", None)
    if callable(decoder):
        # Special tokens are kept so their text reaches the parser unchanged.
        decoded = decoder(token_ids, skip_special_tokens=False)
    else:
        raise RuntimeError("Tokenizer does not support decode().")

    # NOTE(review): each flush is decoded on its own. With byte-level
    # tokenizers, a character whose bytes straddle two flushes may decode to
    # U+FFFD. Confirm against the tokenizer in use.
    if raw_text_parts is not None:
        raw_text_parts.append(decoded)

    for ev in parser_.feed(decoded):
        emit(ev)
|
|
1114
|
+
|
|
1115
|
+
thread = threading.Thread(target=worker, name=f"superlinear-gen-{uuid.uuid4().hex}", daemon=True)
|
|
1116
|
+
thread.start()
|
|
1117
|
+
|
|
1118
|
+
try:
|
|
1119
|
+
while True:
|
|
1120
|
+
event = await queue.get()
|
|
1121
|
+
if event is None:
|
|
1122
|
+
break
|
|
1123
|
+
yield event
|
|
1124
|
+
except asyncio.CancelledError:
|
|
1125
|
+
cancel.set()
|
|
1126
|
+
raise
|
|
1127
|
+
finally:
|
|
1128
|
+
# If the consumer stops early (disconnect / generator close), cancel generation promptly.
|
|
1129
|
+
cancel.set()
|
|
1130
|
+
|
|
1131
|
+
async def generate_chat(self, request: ChatRequest) -> dict[str, Any]:
    """Non-streaming chat completion.

    Drains ``self.astream_chat(request)`` and aggregates its events into a
    single result dict.

    Returns:
        Dict containing:
        - content: str | None (concatenated DeltaEvent text; None if no deltas)
        - tool_calls: list[ToolCall] (from the last ToolCallEvent; empty if none)
        - finish_reason: str ("tool_calls" once a ToolCallEvent is seen,
          otherwise the FinalEvent's finish_reason; "stop" if no FinalEvent)
        - usage: Usage (None if no FinalEvent arrives)
        - timing: Timing (None if no FinalEvent arrives)
        - raw_content: str | None (if discard_thinking=False)

    Raises:
        RuntimeError: If the stream yields an ErrorEvent; the event's message
            is used as the exception message.
    """
    content_parts: list[str] = []
    tool_calls: list[ToolCall] = []
    usage: Usage | None = None
    timing: Timing | None = None
    finish_reason = "stop"
    raw_content: str | None = None

    stream = self.astream_chat(request)
    try:
        async for event in stream:
            if isinstance(event, DeltaEvent):
                # Guard against a None payload so the final join cannot raise.
                if event.text is not None:
                    content_parts.append(event.text)
            elif isinstance(event, ToolCallEvent):
                tool_calls = event.tool_calls
                finish_reason = "tool_calls"
            elif isinstance(event, FinalEvent):
                usage = event.usage
                timing = event.timing
                raw_content = event.raw_content
                # Don't override tool_calls finish reason
                if finish_reason != "tool_calls":
                    finish_reason = event.finish_reason
            elif isinstance(event, ErrorEvent):
                raise RuntimeError(event.message)
    finally:
        # Raising out of `async for` leaves the async generator suspended until
        # it is garbage-collected. Close it explicitly so its `finally` cleanup
        # (which sets the generation cancel flag) runs now rather than later.
        # Closing an already-exhausted generator is a no-op.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()

    return {
        "content": "".join(content_parts) if content_parts else None,
        "tool_calls": tool_calls,
        "finish_reason": finish_reason,
        "usage": usage,
        "timing": timing,
        "raw_content": raw_content,
    }
|