superlinear-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. apps/__init__.py +4 -0
  2. apps/cli/__init__.py +8 -0
  3. apps/cli/bm25_rag.py +471 -0
  4. apps/cli/chat_repl.py +1497 -0
  5. apps/cli/client.py +195 -0
  6. apps/cli/docs_repl.py +2275 -0
  7. apps/cli/light_rag.py +729 -0
  8. apps/cli/local_snapshots.py +139 -0
  9. apps/cli/locks.py +214 -0
  10. apps/cli/main.py +457 -0
  11. apps/cli/output.py +32 -0
  12. apps/cli/server_cmds.py +516 -0
  13. apps/cli/session_cmds.py +491 -0
  14. apps/cli/snapshot_cmds.py +303 -0
  15. apps/cli/state.py +265 -0
  16. apps/server/__init__.py +4 -0
  17. apps/server/app.py +1363 -0
  18. apps/server/main.py +313 -0
  19. superlinear/__init__.py +114 -0
  20. superlinear/_version.py +3 -0
  21. superlinear/engine/__init__.py +10 -0
  22. superlinear/engine/adapters/__init__.py +12 -0
  23. superlinear/engine/adapters/base.py +91 -0
  24. superlinear/engine/adapters/superlinear.py +1233 -0
  25. superlinear/engine/chat_engine.py +1173 -0
  26. superlinear/engine/chat_types.py +130 -0
  27. superlinear/engine/registry.py +51 -0
  28. superlinear/engine/repetition.py +203 -0
  29. superlinear/engine/session_snapshots.py +451 -0
  30. superlinear/engine/tool_parser.py +83 -0
  31. superlinear/engine/types.py +42 -0
  32. superlinear/kernels/__init__.py +2 -0
  33. superlinear/kernels/common/__init__.py +21 -0
  34. superlinear/kernels/common/adjustment.py +106 -0
  35. superlinear/kernels/common/power.py +154 -0
  36. superlinear/kernels/superlinear/__init__.py +10 -0
  37. superlinear/kernels/superlinear/attention/__init__.py +78 -0
  38. superlinear/kernels/superlinear/attention/_prefill.py +940 -0
  39. superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
  40. superlinear/kernels/superlinear/attention/api.py +433 -0
  41. superlinear/kernels/superlinear/search/__init__.py +33 -0
  42. superlinear/kernels/superlinear/search/_reference.py +204 -0
  43. superlinear/kernels/superlinear/search/_triton.py +488 -0
  44. superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
  45. superlinear/kernels/superlinear/search/api.py +200 -0
  46. superlinear/kernels/superlinear/span/__init__.py +41 -0
  47. superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
  48. superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
  49. superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
  50. superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
  51. superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
  52. superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
  53. superlinear/kernels/superlinear/span/api.py +296 -0
  54. superlinear/kernels/superlinear/span/masks.py +187 -0
  55. superlinear/py.typed +0 -0
  56. superlinear/runtime.py +71 -0
  57. superlinear-0.1.0.dist-info/METADATA +469 -0
  58. superlinear-0.1.0.dist-info/RECORD +62 -0
  59. superlinear-0.1.0.dist-info/WHEEL +5 -0
  60. superlinear-0.1.0.dist-info/entry_points.txt +2 -0
  61. superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
  62. superlinear-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1173 @@
1
+ """Chunked async chat inference engine (single-GPU, single-flight).
2
+
3
+ This module provides the core, reusable engine:
4
+ - request normalization -> tokenizer prompt
5
+ - serialized adapter execution
6
+ - chunk-level async streaming
7
+ - tool call detection + parsing
8
+
9
+ It deliberately contains no HTTP/FastAPI code.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import asyncio
15
+ from collections import deque
16
+ import logging
17
+ import threading
18
+ import time
19
+ import uuid
20
+ from dataclasses import dataclass, field
21
+ from typing import Any, AsyncIterator, Sequence
22
+
23
+ from .chat_types import (
24
+ ChatMessage,
25
+ ChatRequest,
26
+ DeltaEvent,
27
+ ErrorEvent,
28
+ FinalEvent,
29
+ ThinkingDeltaEvent,
30
+ StreamEvent,
31
+ Timing,
32
+ ToolCall,
33
+ ToolCallEvent,
34
+ Usage,
35
+ )
36
+ from .repetition import RepetitionDetectionConfig, detect_repetition_kmp_tail
37
+ from .tool_parser import ToolCallParseError, parse_tool_call_block
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ @dataclass(frozen=True)
43
+ class EngineConfig:
44
+ """Engine-wide defaults and limits."""
45
+
46
+ default_backend: str = "custom"
47
+ enable_thinking: bool = True
48
+ discard_thinking: bool = True
49
+ max_prompt_tokens: int = 262_144
50
+ max_tool_calls_per_turn: int = 8
51
+ repetition_detection: RepetitionDetectionConfig = field(
52
+ default_factory=RepetitionDetectionConfig
53
+ )
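# Editor's illustrative sketch (not part of the package): overriding the
# engine-wide defaults above; the values below are arbitrary examples.
#
#     config = EngineConfig(max_prompt_tokens=32_768, discard_thinking=False)
#     engine = ChatEngine(adapter, config=config)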
54
+
55
+
56
+ class _ModelOutputParser:
57
+ """Incremental parser for model output.
58
+
59
+ Responsibilities:
60
+ - Remove <think>...</think> blocks (if the model emits them).
61
+ - Detect and buffer a complete <tool_call>...</tool_call>.
62
+ - Apply stop sequences (string-based) to normal text.
63
+ """
64
+
65
+ _THINK_OPEN = "<think>"
66
+ _THINK_CLOSE = "</think>"
67
+ _TOOL_OPEN = "<tool_call>"
68
+ _TOOL_CLOSE = "</tool_call>"
69
+
70
+ def __init__(
71
+ self,
72
+ *,
73
+ stop_sequences: Sequence[str],
74
+ valid_tool_names: set[str],
75
+ max_tool_calls: int,
76
+ allow_tool_calls: bool,
77
+ start_in_think: bool = False,
78
+ emit_thinking: bool = False,
79
+ ) -> None:
80
+ self._stop_sequences = [s for s in stop_sequences if s]
81
+ self._max_stop_len = max((len(s) for s in self._stop_sequences), default=0)
82
+
83
+ self._valid_tool_names = valid_tool_names
84
+ self._max_tool_calls = max_tool_calls
85
+ self._allow_tool_calls = allow_tool_calls
86
+
87
+ self._buffer = ""
88
+ self._tool_buffer = ""
89
+ # When enable_thinking=True, the generation prompt already includes <think>,
90
+ # so we start inside the think block and wait for </think> before emitting.
91
+ self._in_think = start_in_think
92
+ self._in_tool = False
93
+
94
+ self._emit_thinking = bool(emit_thinking)
95
+ self._emitted_think_open = False
96
+
97
+ self.stopped: bool = False
98
+ self.stop_reason: str | None = None # "stop" | "tool_calls" | "error"
99
+ self.tool_calls: list[ToolCall] = []
100
+
101
+ self._tail_keep = max(
102
+ len(self._THINK_OPEN) - 1,
103
+ len(self._THINK_CLOSE) - 1,
104
+ len(self._TOOL_OPEN) - 1,
105
+ len(self._TOOL_CLOSE) - 1,
106
+ max(self._max_stop_len - 1, 0),
107
+ )
108
+
109
+ def feed(self, text: str) -> list[StreamEvent]:
110
+ if not text or self.stopped:
111
+ return []
112
+
113
+ self._buffer += text
114
+ events: list[StreamEvent] = []
115
+
116
+ if self._emit_thinking and self._in_think and not self._emitted_think_open:
117
+ events.append(ThinkingDeltaEvent(self._THINK_OPEN))
118
+ self._emitted_think_open = True
119
+
120
+ while self._buffer and not self.stopped:
121
+ if self._in_think:
122
+ end = self._buffer.find(self._THINK_CLOSE)
123
+ if end == -1:
124
+ if not self._emit_thinking:
125
+ # Drop accumulated thinking content; keep a small tail in case the close tag is split.
126
+ if self._tail_keep:
127
+ self._buffer = self._buffer[-self._tail_keep :]
128
+ else:
129
+ self._buffer = ""
130
+ break
131
+
132
+ # Emit thinking content, keeping a small tail in case the close tag is split.
133
+ if self._tail_keep and len(self._buffer) > self._tail_keep:
134
+ emit_text = self._buffer[: -self._tail_keep]
135
+ self._buffer = self._buffer[-self._tail_keep :]
136
+ else:
137
+ emit_text = self._buffer
138
+ self._buffer = ""
139
+
140
+ if emit_text:
141
+ events.append(ThinkingDeltaEvent(emit_text))
142
+ break
143
+
144
+ if self._emit_thinking:
145
+ think_text = self._buffer[:end]
146
+ if think_text:
147
+ events.append(ThinkingDeltaEvent(think_text))
148
+ events.append(ThinkingDeltaEvent(self._THINK_CLOSE))
149
+ self._buffer = self._buffer[end + len(self._THINK_CLOSE) :]
150
+ self._in_think = False
151
+ self._emitted_think_open = False
152
+ continue
153
+
154
+ # Drop everything through the closing tag.
155
+ self._buffer = self._buffer[end + len(self._THINK_CLOSE) :]
156
+ self._in_think = False
157
+ continue
158
+
159
+ if self._in_tool:
160
+ self._tool_buffer += self._buffer
161
+ self._buffer = ""
162
+
163
+ end = self._tool_buffer.find(self._TOOL_CLOSE)
164
+ if end == -1:
165
+ break
166
+
167
+ block = self._tool_buffer[: end + len(self._TOOL_CLOSE)]
168
+ trailing = self._tool_buffer[end + len(self._TOOL_CLOSE) :]
169
+ self._tool_buffer = ""
170
+ self._in_tool = False
171
+
172
+ try:
173
+ parsed = parse_tool_call_block(block)
174
+ except ToolCallParseError as exc:
175
+ self.stopped = True
176
+ self.stop_reason = "error"
177
+ events.append(ErrorEvent(f"Failed to parse tool call: {exc}"))
178
+ break
179
+
180
+ if not self._allow_tool_calls:
181
+ self.stopped = True
182
+ self.stop_reason = "error"
183
+ events.append(ErrorEvent("Tool calls are disabled for this request (tool_choice='none')."))
184
+ break
185
+
186
+ if self._valid_tool_names and parsed.name not in self._valid_tool_names:
187
+ self.stopped = True
188
+ self.stop_reason = "error"
189
+ events.append(
190
+ ErrorEvent(
191
+ f"Model called unknown tool {parsed.name!r}. "
192
+ "Ensure the tool is present in the request 'tools' list."
193
+ )
194
+ )
195
+ break
196
+
197
+ if len(self.tool_calls) >= self._max_tool_calls:
198
+ self.stopped = True
199
+ self.stop_reason = "error"
200
+ events.append(
201
+ ErrorEvent(
202
+ f"Too many tool calls in one turn (max={self._max_tool_calls})."
203
+ )
204
+ )
205
+ break
206
+
207
+ self.tool_calls.append(
208
+ ToolCall(id=f"call_{uuid.uuid4().hex}", name=parsed.name, arguments=parsed.arguments)
209
+ )
210
+ events.append(ToolCallEvent(tool_calls=list(self.tool_calls)))
211
+
212
+ # Tool call completes the turn.
213
+ self.stopped = True
214
+ self.stop_reason = "tool_calls"
215
+
216
+ # Any trailing content after </tool_call> is ignored for v0.
217
+ _ = trailing
218
+ break
219
+
220
+ # Normal (non-think, non-tool) mode.
221
+ next_think = self._buffer.find(self._THINK_OPEN)
222
+ next_tool = self._buffer.find(self._TOOL_OPEN)
223
+
224
+ next_special = -1
225
+ if next_think != -1 and next_tool != -1:
226
+ next_special = min(next_think, next_tool)
227
+ elif next_think != -1:
228
+ next_special = next_think
229
+ elif next_tool != -1:
230
+ next_special = next_tool
231
+
232
+ if next_special != -1:
233
+ # Emit content before the special tag.
234
+ before = self._buffer[:next_special]
235
+ self._buffer = self._buffer[next_special:]
236
+ events.extend(self._emit_text(before))
237
+
238
+ if self.stopped:
239
+ break
240
+
241
+ if self._buffer.startswith(self._THINK_OPEN):
242
+ self._buffer = self._buffer[len(self._THINK_OPEN) :]
243
+ self._in_think = True
244
+ if self._emit_thinking:
245
+ events.append(ThinkingDeltaEvent(self._THINK_OPEN))
246
+ self._emitted_think_open = True
247
+ continue
248
+ if self._buffer.startswith(self._TOOL_OPEN):
249
+ self._tool_buffer = self._TOOL_OPEN
250
+ self._buffer = self._buffer[len(self._TOOL_OPEN) :]
251
+ self._in_tool = True
252
+ continue
253
+
254
+ # No special tags found in the current buffer.
255
+ events.extend(self._emit_available_text())
256
+ break
257
+
258
+ return [e for e in events if not isinstance(e, DeltaEvent) or e.text]
259
+
260
+ def finish(self) -> list[StreamEvent]:
261
+ """Flush any remaining buffered content at end-of-generation."""
262
+ if self.stopped:
263
+ return []
264
+
265
+ if self._in_tool:
266
+ self.stopped = True
267
+ self.stop_reason = "error"
268
+ return [ErrorEvent("Incomplete <tool_call> block in model output.")]
269
+
270
+ # If we're still inside a think block, drop it.
271
+ if self._in_think:
272
+ self._buffer = ""
273
+ return []
274
+
275
+ # Emit remaining buffer (apply stop sequences if needed).
276
+ remaining = self._buffer
277
+ self._buffer = ""
278
+ return self._emit_text(remaining)
279
+
280
+ def _emit_text(self, text: str) -> list[StreamEvent]:
281
+ if not text:
282
+ return []
283
+
284
+ if self._stop_sequences:
285
+ idx = self._find_earliest_stop(text)
286
+ if idx is not None:
287
+ before = text[:idx]
288
+ self.stopped = True
289
+ self.stop_reason = "stop"
290
+ return [DeltaEvent(before)] if before else []
291
+
292
+ return [DeltaEvent(text)]
293
+
294
+ def _emit_available_text(self) -> list[StreamEvent]:
295
+ if not self._buffer:
296
+ return []
297
+
298
+ if self._tail_keep <= 0:
299
+ text = self._buffer
300
+ self._buffer = ""
301
+ return self._emit_text(text)
302
+
303
+ if len(self._buffer) <= self._tail_keep:
304
+ return []
305
+
306
+ safe_end = len(self._buffer) - self._tail_keep
307
+ safe = self._buffer[:safe_end]
308
+ self._buffer = self._buffer[safe_end:]
309
+ return self._emit_text(safe)
310
+
311
+ def _find_earliest_stop(self, text: str) -> int | None:
312
+ earliest: int | None = None
313
+ for s in self._stop_sequences:
314
+ idx = text.find(s)
315
+ if idx == -1:
316
+ continue
317
+ if earliest is None or idx < earliest:
318
+ earliest = idx
319
+ return earliest
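# Editor's illustrative sketch (not part of the package): how the incremental
# parser drops thinking content and enforces a string stop sequence. The stop
# string "<|im_end|>" is only an example; any non-empty string works here.
def _demo_parser_usage() -> None:
    parser = _ModelOutputParser(
        stop_sequences=["<|im_end|>"],
        valid_tool_names=set(),
        max_tool_calls=8,
        allow_tool_calls=True,
        start_in_think=True,   # the generation prompt already opened <think>
        emit_thinking=False,   # thinking is dropped rather than streamed
    )
    events = parser.feed("let me reason step by step...</think>The answer is 4.<|im_end|>")
    events += parser.finish()
    text = "".join(e.text for e in events if isinstance(e, DeltaEvent))
    assert text == "The answer is 4."
    assert parser.stopped and parser.stop_reason == "stop"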
320
+
321
+
322
+ class ChatEngine:
323
+ """Core chat inference engine.
324
+
325
+ Thread-safety:
326
+ The underlying adapter is not thread-safe. This engine serializes access
327
+ with a global lock (single-flight).
328
+ """
329
+
330
+ def __init__(self, adapter: Any, *, config: EngineConfig | None = None) -> None:
331
+ self._adapter = adapter
332
+ self._config = config or EngineConfig()
333
+ self._lock = threading.Lock()
334
+
335
+ @property
336
+ def adapter(self) -> Any:
337
+ """Access to the underlying adapter for session management."""
338
+ return self._adapter
339
+
340
+ @property
341
+ def model_info(self) -> dict[str, Any]:
342
+ return getattr(self._adapter, "model_info", {})
343
+
344
+ def shutdown(self) -> None:
345
+ unload = getattr(self._adapter, "unload", None)
346
+ if callable(unload):
347
+ unload()
348
+
349
+ def _tool_names(self, tools: Sequence[dict[str, Any]]) -> set[str]:
350
+ names: set[str] = set()
351
+ for tool in tools:
352
+ if not isinstance(tool, dict):
353
+ continue
354
+ if tool.get("type") != "function":
355
+ continue
356
+ fn = tool.get("function")
357
+ if isinstance(fn, dict):
358
+ name = fn.get("name")
359
+ if isinstance(name, str) and name:
360
+ names.add(name)
361
+ return names
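# Editor's sketch (not part of the package): the OpenAI-style tool shape this
# helper expects; a hypothetical "get_weather" entry is extracted as follows.
#
#     tools = [{"type": "function",
#               "function": {"name": "get_weather", "parameters": {"type": "object"}}}]
#     self._tool_names(tools)  # -> {"get_weather"}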
362
+
363
+ def _inject_tool_choice(self, messages: list[ChatMessage], tool_choice: Any) -> list[ChatMessage]:
364
+ if tool_choice is None:
365
+ return messages
366
+
367
+ instruction: str | None = None
368
+ if tool_choice == "none":
369
+ instruction = "Tool choice: do not call any tools."
370
+ elif tool_choice == "required":
371
+ instruction = "Tool choice: you must call a tool for this response."
372
+ elif isinstance(tool_choice, dict):
373
+ # OpenAI format: {"type":"function","function":{"name":"..."}}
374
+ fn = tool_choice.get("function") if tool_choice.get("type") == "function" else None
375
+ name = fn.get("name") if isinstance(fn, dict) else None
376
+ if isinstance(name, str) and name:
377
+ instruction = f"Tool choice: call only the tool named {name!r}."
378
+
379
+ if instruction is None:
380
+ return messages
381
+
382
+ if messages and messages[0].role == "system":
383
+ updated = ChatMessage(
384
+ role="system",
385
+ content=((messages[0].content or "").rstrip() + "\n\n" + instruction).strip(),
386
+ tool_calls=messages[0].tool_calls,
387
+ tool_call_id=messages[0].tool_call_id,
388
+ )
389
+ return [updated, *messages[1:]]
390
+
391
+ return [ChatMessage(role="system", content=instruction), *messages]
392
+
393
+ def _messages_for_template(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
394
+ out: list[dict[str, Any]] = []
395
+ for m in messages:
396
+ msg: dict[str, Any] = {"role": m.role}
397
+ if m.content is not None:
398
+ msg["content"] = m.content
399
+ else:
400
+ msg["content"] = ""
401
+
402
+ if m.role == "assistant" and m.tool_calls:
403
+ # The model chat template expects tool_call.function.arguments as a mapping (not a JSON string).
404
+ tool_calls: list[dict[str, Any]] = []
405
+ for tc in m.tool_calls:
406
+ tool_calls.append(
407
+ {
408
+ "id": tc.id,
409
+ "type": "function",
410
+ "function": {"name": tc.name, "arguments": tc.arguments},
411
+ }
412
+ )
413
+ msg["tool_calls"] = tool_calls
414
+
415
+ if m.role == "tool" and m.tool_call_id:
416
+ msg["tool_call_id"] = m.tool_call_id
417
+
418
+ out.append(msg)
419
+ return out
420
+
421
+ def _validate_tool_choice(self, request: ChatRequest) -> None:
422
+ tc = request.tool_choice
423
+ if tc is None or tc == "auto":
424
+ return
425
+
426
+ tool_names = self._tool_names(request.tools)
427
+ if tc == "none":
428
+ return
429
+
430
+ if not request.tools:
431
+ raise ValueError("tool_choice was provided but no tools were supplied in the request.")
432
+
433
+ if tc == "required":
434
+ return
435
+
436
+ if isinstance(tc, dict) and tc.get("type") == "function":
437
+ fn = tc.get("function")
438
+ name = fn.get("name") if isinstance(fn, dict) else None
439
+ if isinstance(name, str) and name and tool_names and name not in tool_names:
440
+ raise ValueError(f"tool_choice requested unknown tool {name!r}.")
441
+
442
+ def _single_token_stop_ids(self, stop: Sequence[str]) -> tuple[list[int], list[str]]:
443
+ """Split stops into (single-token stop ids, string stops).
444
+
445
+ We only early-stop on single-token sequences. Anything else is enforced
446
+ by the output parser (string scan).
447
+ """
448
+ stop_token_ids: list[int] = []
449
+ stop_strings: list[str] = []
450
+ tokenizer = getattr(self._adapter, "tokenizer", None)
451
+
452
+ # Dynamically get special tokens from the tokenizer to filter from output.
453
+ # This keeps the engine model-agnostic.
454
+ if tokenizer is not None:
455
+ all_special = getattr(tokenizer, "all_special_tokens", None)
456
+ if all_special:
457
+ stop_strings.extend(s for s in all_special if s)
458
+
459
+ encode = getattr(tokenizer, "encode", None)
460
+ if callable(encode):
461
+ for s in stop:
462
+ if not s:
463
+ continue
464
+ try:
465
+ ids = encode(s, add_special_tokens=False)
466
+ except Exception:
467
+ stop_strings.append(s)
468
+ continue
469
+ if isinstance(ids, list) and len(ids) == 1 and isinstance(ids[0], int):
470
+ stop_token_ids.append(ids[0])
471
+ else:
472
+ stop_strings.append(s)
473
+ return stop_token_ids, stop_strings
474
+
475
+ return [], [s for s in stop if s]
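# Editor's note (illustrative): with a tokenizer for which
# encode("<|im_end|>", add_special_tokens=False) yields a single id, that stop
# is enforced by the adapter via stop_token_ids, while a multi-token stop such
# as "\n\nUser:" stays in stop_strings and is enforced by the output parser.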
476
+
477
+ def _build_input_ids(self, request: ChatRequest) -> Any:
478
+ apply_chat_template, template_messages, kwargs = self._prepare_chat_template(request)
479
+ return apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
480
+
481
+ def _effective_enable_thinking(self, request: ChatRequest) -> bool:
482
+ enable_thinking = bool(self._config.enable_thinking)
483
+ if request.reasoning_budget is not None:
484
+ enable_thinking = True
485
+ if request.chat_template_kwargs and "enable_thinking" in request.chat_template_kwargs:
486
+ enable_thinking = bool(request.chat_template_kwargs["enable_thinking"])
487
+ return enable_thinking
488
+
489
+ def _effective_discard_thinking(self, request: ChatRequest) -> bool:
490
+ discard_thinking = bool(self._config.discard_thinking)
491
+ if request.discard_thinking is not None:
492
+ discard_thinking = bool(request.discard_thinking)
493
+ return discard_thinking
494
+
495
+ def _chat_template_kwargs(self, request: ChatRequest, *, enable_thinking: bool) -> dict[str, Any]:
496
+ """Build kwargs for tokenizer.apply_chat_template with strict parity.
497
+
498
+ Important: This preserves the distinction between "tools omitted" vs `tools=[]`.
499
+ """
500
+ kwargs: dict[str, Any] = {
501
+ "return_tensors": "pt",
502
+ "enable_thinking": bool(enable_thinking),
503
+ }
504
+
505
+ # Merge request-level chat_template_kwargs (can override enable_thinking)
506
+ if request.chat_template_kwargs:
507
+ kwargs.update(request.chat_template_kwargs)
508
+ # The engine controls add_generation_prompt explicitly.
509
+ kwargs.pop("add_generation_prompt", None)
510
+
511
+ # For tool_choice="none", omit tool definitions from the prompt to reduce the chance of
512
+ # accidental tool calls (in addition to the injected instruction).
513
+ tools = request.tools
514
+ if request.tool_choice == "none":
515
+ tools = []
516
+
517
+ # Preserve "tools omitted" when tools==[].
518
+ if tools:
519
+ kwargs["tools"] = tools
520
+
521
+ return kwargs
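# Editor's note: because only truthy `tools` are forwarded, request.tools == []
# and tool_choice == "none" both render the template as if no tools were given,
# while request-level chat_template_kwargs may still override enable_thinking.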
522
+
523
+ def _prepare_chat_template(self, request: ChatRequest) -> tuple[Any, list[dict[str, Any]], dict[str, Any]]:
524
+ tokenizer = getattr(self._adapter, "tokenizer", None)
525
+ if tokenizer is None:
526
+ raise RuntimeError("Adapter has no tokenizer loaded.")
527
+
528
+ self._validate_tool_choice(request)
529
+
530
+ apply_chat_template = getattr(tokenizer, "apply_chat_template", None)
531
+ if not callable(apply_chat_template):
532
+ raise RuntimeError("Tokenizer does not support apply_chat_template().")
533
+
534
+ messages = self._inject_tool_choice(list(request.messages), request.tool_choice)
535
+ template_messages = self._messages_for_template(messages)
536
+ kwargs = self._chat_template_kwargs(request, enable_thinking=self._effective_enable_thinking(request))
537
+ return apply_chat_template, template_messages, kwargs
538
+
539
+ def _compute_generation_boundary(
540
+ self,
541
+ *,
542
+ apply_chat_template: Any,
543
+ template_messages: list[dict[str, Any]],
544
+ kwargs: dict[str, Any],
545
+ ) -> tuple[int, Any, Any, Any]:
546
+ """Compute the (end-of-user) boundary and generation prompt suffix.
547
+
548
+ Returns:
549
+ (boundary_pos, ids_no_gen, ids_with_gen, gen_prompt_ids)
550
+
551
+ Raises:
552
+ ValueError if the strict-prefix invariant is violated.
553
+ """
554
+ ids_no_gen = apply_chat_template(template_messages, add_generation_prompt=False, **kwargs)
555
+ ids_with_gen = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
556
+
557
+ try:
558
+ boundary_pos = int(getattr(ids_no_gen, "shape")[1])
559
+ with_len = int(getattr(ids_with_gen, "shape")[1])
560
+ except Exception as exc: # pragma: no cover
561
+ raise ValueError("apply_chat_template() must return a tensor with shape (1, seq_len).") from exc
562
+
563
+ if boundary_pos < 0 or with_len < 0 or with_len < boundary_pos:
564
+ raise ValueError("Invalid chat template boundary lengths.")
565
+ if with_len == boundary_pos:
566
+ raise ValueError("Chat template boundary is not a strict prefix (no generation prompt suffix).")
567
+
568
+ prefix_ok = False
569
+ try:
570
+ import torch
571
+
572
+ if (
573
+ isinstance(ids_no_gen, torch.Tensor)
574
+ and isinstance(ids_with_gen, torch.Tensor)
575
+ and ids_no_gen.ndim == 2
576
+ and ids_with_gen.ndim == 2
577
+ and ids_no_gen.shape[0] == 1
578
+ and ids_with_gen.shape[0] == 1
579
+ ):
580
+ prefix_ok = torch.equal(ids_with_gen[:, :boundary_pos], ids_no_gen)
581
+ except Exception:
582
+ prefix_ok = False
583
+
584
+ if not prefix_ok:
585
+ # Best-effort diagnostic suffix to help debug template drift.
586
+ suffix = None
587
+ try:
588
+ tokenizer = getattr(self._adapter, "tokenizer", None)
589
+ decode = getattr(tokenizer, "decode", None)
590
+ if callable(decode):
591
+ suffix_ids = ids_with_gen[0, max(boundary_pos - 16, 0) : boundary_pos + 16].tolist()
592
+ suffix = decode(suffix_ids, skip_special_tokens=False)
593
+ except Exception:
594
+ suffix = None
595
+
596
+ if suffix:
597
+ logger.warning("apply_chat_template prefix invariant failed near boundary: %r", suffix)
598
+ raise ValueError("Chat template strict-prefix boundary invariant failed.")
599
+
600
+ gen_prompt_ids = ids_with_gen[:, boundary_pos:]
601
+ return boundary_pos, ids_no_gen, ids_with_gen, gen_prompt_ids
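# Editor's illustrative sketch (not part of the package): for a Qwen-style chat
# template the two renderings might look like
#
#     ids_no_gen   ... <|im_start|>user ... <|im_end|>
#     ids_with_gen ... <|im_start|>user ... <|im_end|><|im_start|>assistant\n<think>
#
# so boundary_pos is the end-of-user length and gen_prompt_ids is the assistant
# preamble re-appended after a checkpoint restore. The exact tokens are
# template-specific; only the strict-prefix relationship is relied upon.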
602
+
603
+ async def astream_chat(self, request: ChatRequest) -> AsyncIterator[StreamEvent]:
604
+ """Async iterator streaming internal events."""
605
+ loop = asyncio.get_running_loop()
606
+ queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue()
607
+ cancel = threading.Event()
608
+
609
+ repetition_cfg = self._config.repetition_detection.merged(
610
+ request.extra.get("repetition_detection") if request.extra else None
611
+ )
612
+
613
+ enable_thinking = self._effective_enable_thinking(request)
614
+ discard_thinking = self._effective_discard_thinking(request)
615
+ emit_thinking = bool(request.stream_thinking) and enable_thinking
616
+
617
+ apply_chat_template, template_messages, kwargs = self._prepare_chat_template(request)
618
+
619
+ boundary_pos: int | None = None
620
+ ids_no_gen = None
621
+ gen_prompt_ids = None
622
+
623
+ if request.session_id and enable_thinking and discard_thinking:
624
+ try:
625
+ boundary_pos, ids_no_gen, input_ids, gen_prompt_ids = self._compute_generation_boundary(
626
+ apply_chat_template=apply_chat_template,
627
+ template_messages=template_messages,
628
+ kwargs=kwargs,
629
+ )
630
+ except ValueError as exc:
631
+ # Safety-over-performance fallback: keep discard-thinking enabled but use
632
+ # checkpoint-before-user (Option A) to avoid relying on a potentially-wrong boundary.
633
+ logger.warning("Falling back to checkpoint-before-user: %s", exc)
634
+ input_ids = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
635
+ else:
636
+ input_ids = apply_chat_template(template_messages, add_generation_prompt=True, **kwargs)
637
+
638
+ try:
639
+ prompt_tokens = int(getattr(input_ids, "shape")[1])
640
+ except Exception:
641
+ prompt_tokens = 0
642
+
643
+ if prompt_tokens and prompt_tokens > self._config.max_prompt_tokens:
644
+ yield ErrorEvent(
645
+ f"Prompt too long: {prompt_tokens} tokens (max={self._config.max_prompt_tokens})."
646
+ )
647
+ return
648
+
649
+ stop_token_ids, stop_strings = self._single_token_stop_ids(request.stop)
650
+ tool_names = self._tool_names(request.tools)
651
+ parser = _ModelOutputParser(
652
+ stop_sequences=stop_strings,
653
+ valid_tool_names=tool_names,
654
+ max_tool_calls=self._config.max_tool_calls_per_turn,
655
+ allow_tool_calls=request.tool_choice != "none",
656
+ start_in_think=enable_thinking,
657
+ emit_thinking=emit_thinking,
658
+ )
659
+
660
+ def worker() -> None:
661
+ started = time.monotonic()
662
+ first_token_at: float | None = None
663
+ completion_tokens = 0
664
+ finish_reason: str = "stop"
665
+
666
+ normalized_raw_content_for_history: str | None = None
667
+
668
+ # Per-request stream policy.
669
+ flush_n = max(int(request.stream_options.flush_every_n_tokens), 1)
670
+ flush_ms = max(int(request.stream_options.flush_every_ms), 1)
671
+ flush_s = flush_ms / 1000.0
672
+
673
+ token_buffer: list[int] = []
674
+ last_flush = time.monotonic()
675
+
676
+ assistant_text_parts: list[str] = []
677
+ assistant_raw_text_parts: list[str] = [] # Raw text including <think> blocks
678
+ assistant_tool_calls: list[ToolCall] = []
679
+
680
+ repetition_tail: deque[int] | None = None
681
+ if repetition_cfg.enabled:
682
+ repetition_tail = deque(maxlen=repetition_cfg.tail_len)
683
+
684
+ def _emit_event(event: StreamEvent) -> None:
685
+ nonlocal assistant_tool_calls
686
+ if isinstance(event, DeltaEvent):
687
+ if event.text:
688
+ assistant_text_parts.append(event.text)
689
+ elif isinstance(event, ThinkingDeltaEvent):
690
+ # Thinking deltas are streamed but never counted as assistant output content.
691
+ pass
692
+ elif isinstance(event, ToolCallEvent):
693
+ assistant_tool_calls = list(event.tool_calls)
694
+ loop.call_soon_threadsafe(queue.put_nowait, event)
695
+
696
+ try:
697
+ with self._lock:
698
+
699
+ # Session-based generation path
700
+ if request.session_id:
701
+ append_from = request.session_append_from_pos
702
+ if append_from is None:
703
+ append_from = 0
704
+ try:
705
+ append_from = int(append_from)
706
+ except Exception:
707
+ append_from = 0
708
+ if append_from < 0:
709
+ append_from = 0
710
+
711
+ discard_session_thinking = enable_thinking and discard_thinking
712
+
713
+ checkpoint = None
714
+ commit_from_pos = None
715
+ fallback_checkpoint = None
716
+ fallback_commit_from_pos = None
717
+
718
+ if discard_session_thinking:
719
+ if not hasattr(self._adapter, "checkpoint_session") or not hasattr(
720
+ self._adapter, "restore_session_checkpoint"
721
+ ):
722
+ raise RuntimeError(
723
+ "Adapter does not support checkpoint/restore required for discard_thinking."
724
+ )
725
+
726
+ # Option B (fast path): checkpoint after user boundary.
727
+ if boundary_pos is not None and ids_no_gen is not None and gen_prompt_ids is not None:
728
+ # Keep a fallback checkpoint in case boundary invariants drift mid-flight.
729
+ fallback_checkpoint = self._adapter.checkpoint_session(request.session_id)
730
+ fallback_commit_from_pos = int(append_from)
731
+
732
+ if append_from > int(boundary_pos):
733
+ raise ValueError(
734
+ f"session_append_from_pos={append_from} exceeds boundary_pos={boundary_pos}."
735
+ )
736
+
737
+ try:
738
+ delta_user_ids = ids_no_gen[:, append_from:boundary_pos]
739
+ except Exception:
740
+ delta_user_ids = ids_no_gen
741
+
742
+ if getattr(delta_user_ids, "numel", lambda: 0)() > 0:
743
+ self._adapter.append_to_session(
744
+ cache_id=request.session_id,
745
+ input_ids=delta_user_ids,
746
+ )
747
+
748
+ sess_info = self._adapter.get_session_info(request.session_id)
749
+ cur = int(sess_info.get("current_pos", -1))
750
+ if cur != int(boundary_pos):
751
+ raise ValueError(
752
+ f"Session cursor mismatch after user prefill: current_pos={cur} "
753
+ f"!= boundary_pos={boundary_pos}."
754
+ )
755
+
756
+ checkpoint = self._adapter.checkpoint_session(request.session_id)
757
+ commit_from_pos = int(boundary_pos)
758
+
759
+ if getattr(gen_prompt_ids, "numel", lambda: 0)() > 0:
760
+ self._adapter.append_to_session(
761
+ cache_id=request.session_id,
762
+ input_ids=gen_prompt_ids,
763
+ )
764
+ else:
765
+ # Option A (fallback): checkpoint before appending user.
766
+ checkpoint = self._adapter.checkpoint_session(request.session_id)
767
+ commit_from_pos = int(append_from)
768
+
769
+ # Append full delta prompt tokens (includes user + gen prompt).
770
+ try:
771
+ delta_input_ids = input_ids[:, append_from:]
772
+ except Exception:
773
+ delta_input_ids = input_ids
774
+
775
+ if getattr(delta_input_ids, "numel", lambda: 0)() > 0:
776
+ self._adapter.append_to_session(
777
+ cache_id=request.session_id,
778
+ input_ids=delta_input_ids,
779
+ )
780
+ else:
781
+ # Default: append full delta prompt tokens (server may provide full-history messages).
782
+ try:
783
+ delta_input_ids = input_ids[:, append_from:]
784
+ except Exception:
785
+ delta_input_ids = input_ids
786
+
787
+ if getattr(delta_input_ids, "numel", lambda: 0)() > 0:
788
+ self._adapter.append_to_session(
789
+ cache_id=request.session_id,
790
+ input_ids=delta_input_ids,
791
+ )
792
+
793
+ # Generate from the session
794
+ token_iter = self._adapter.stream_generate_session(
795
+ cache_id=request.session_id,
796
+ max_new_tokens=int(request.max_tokens),
797
+ temperature=float(request.temperature or 0.0),
798
+ stop_token_ids=stop_token_ids or None,
799
+ )
800
+ else:
801
+ # Stateless generation path
802
+ token_iter = self._adapter.stream_generate(
803
+ input_ids,
804
+ max_new_tokens=int(request.max_tokens),
805
+ temperature=float(request.temperature or 0.0),
806
+ stop_token_ids=stop_token_ids or None,
807
+ backend=self._config.default_backend,
808
+ reasoning_budget=request.reasoning_budget,
809
+ enable_thinking=enable_thinking,
810
+ )
811
+
812
+ # Track if we've bailed out of thinking due to repetition
813
+ thinking_bailout_done = False
814
+
815
+ try:
816
+ for token in token_iter:
817
+ if cancel.is_set():
818
+ finish_reason = "cancelled"
819
+ break
820
+
821
+ completion_tokens += 1
822
+ if first_token_at is None:
823
+ first_token_at = time.monotonic()
824
+
825
+ try:
826
+ token_id = int(token.item())
827
+ except Exception:
828
+ # Fall back to a best-effort integer conversion.
829
+ token_id = int(token) # type: ignore[arg-type]
830
+
831
+ token_buffer.append(token_id)
832
+
833
+ if repetition_tail is not None:
834
+ repetition_tail.append(token_id)
835
+ if (
836
+ completion_tokens >= repetition_cfg.min_generated_tokens
837
+ and (completion_tokens % repetition_cfg.check_every) == 0
838
+ ):
839
+ hit = detect_repetition_kmp_tail(
840
+ list(repetition_tail),
841
+ tail_len=repetition_cfg.tail_len,
842
+ min_generated_tokens=0,
843
+ min_repeats=repetition_cfg.min_repeats,
844
+ max_period=repetition_cfg.max_period,
845
+ min_unique_tokens=repetition_cfg.min_unique_tokens,
846
+ )
847
+ if hit is not None:
848
+ # Flush buffer before checking parser state
849
+ if token_buffer:
850
+ _flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
851
+ token_buffer.clear()
852
+
853
+ # If we're in thinking mode and haven't bailed out yet,
854
+ # inject </think> and continue instead of stopping
855
+ if parser._in_think and not thinking_bailout_done:
856
+ logger.debug(
857
+ "Repetition in thinking - injecting </think>: period=%d repeats=%d completion_tokens=%d",
858
+ hit.period,
859
+ hit.repeats,
860
+ completion_tokens,
861
+ )
862
+ # Feed </think> to parser to exit thinking mode
863
+ for event in parser.feed("</think>"):
864
+ _emit_event(event)
865
+ # Clear the repetition tail to give a fresh start.
866
+ repetition_tail.clear()
867
+ thinking_bailout_done = True
868
+ # Continue generating (don't break)
869
+ continue
870
+
871
+ logger.debug(
872
+ "Repetition early-stop: period=%d repeats=%d checked_tail_len=%d completion_tokens=%d",
873
+ hit.period,
874
+ hit.repeats,
875
+ hit.checked_tail_len,
876
+ completion_tokens,
877
+ )
878
+ finish_reason = "repetition"
879
+ break
880
+
881
+ now = time.monotonic()
882
+ if len(token_buffer) < flush_n and (now - last_flush) < flush_s:
883
+ continue
884
+
885
+ last_flush = now
886
+ _flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
887
+ token_buffer.clear()
888
+
889
+ if parser.stopped and parser.stop_reason == "tool_calls":
890
+ finish_reason = "tool_calls"
891
+ break
892
+ if parser.stopped and parser.stop_reason == "stop":
893
+ finish_reason = "stop"
894
+ break
895
+ if parser.stopped and parser.stop_reason == "error":
896
+ finish_reason = "error"
897
+ break
898
+ finally:
899
+ # Explicitly close the generator to ensure its finally block runs.
900
+ # This is critical for session-based generation where the generator's
901
+ # finally block persists the KV cache state.
902
+ if hasattr(token_iter, "close"):
903
+ token_iter.close()
904
+
905
+ # Final flush.
906
+ if token_buffer and not parser.stopped:
907
+ _flush_token_buffer(token_buffer, parser, _emit_event, assistant_raw_text_parts)
908
+ token_buffer.clear()
909
+
910
+ # Flush parser tail.
911
+ if not parser.stopped:
912
+ for event in parser.finish():
913
+ _emit_event(event)
914
+ if isinstance(event, ErrorEvent):
915
+ finish_reason = "error"
916
+
917
+ # Discard-thinking commit: restore to a checkpoint and append tokens for persisted history.
918
+ if request.session_id and enable_thinking and discard_thinking:
919
+ if checkpoint is None or commit_from_pos is None:
920
+ raise RuntimeError("Discard-thinking flow missing checkpoint state.")
921
+
922
+ self._adapter.restore_session_checkpoint(
923
+ cache_id=request.session_id,
924
+ checkpoint=checkpoint,
925
+ )
926
+
927
+ # Persist tool calls as a structured assistant message.
928
+ assistant_msg = None
929
+ if assistant_tool_calls:
930
+ assistant_msg = ChatMessage(
931
+ role="assistant",
932
+ content=None,
933
+ tool_calls=list(assistant_tool_calls),
934
+ )
935
+ else:
936
+ assistant_msg = ChatMessage(
937
+ role="assistant",
938
+ content="".join(assistant_text_parts),
939
+ )
940
+
941
+ persisted_messages = self._inject_tool_choice(
942
+ [*list(request.messages), assistant_msg],
943
+ request.tool_choice,
944
+ )
945
+ template_persisted = self._messages_for_template(persisted_messages)
946
+
947
+ ids_persisted = apply_chat_template(
948
+ template_persisted,
949
+ add_generation_prompt=False,
950
+ **kwargs,
951
+ )
952
+
953
+ # Sanity check: persisted prompt should start with the end-of-user prefix.
954
+ if boundary_pos is not None and ids_no_gen is not None:
955
+ import torch
956
+
957
+ if (
958
+ isinstance(ids_persisted, torch.Tensor)
959
+ and isinstance(ids_no_gen, torch.Tensor)
960
+ and ids_persisted.ndim == 2
961
+ and ids_no_gen.ndim == 2
962
+ and ids_persisted.shape[0] == 1
963
+ and ids_no_gen.shape[0] == 1
964
+ and ids_persisted.shape[1] >= int(boundary_pos)
965
+ and not torch.equal(ids_persisted[:, : int(boundary_pos)], ids_no_gen)
966
+ ):
967
+ if fallback_checkpoint is not None and fallback_commit_from_pos is not None:
968
+ logger.warning(
969
+ "Persisted prompt prefix mismatch; falling back to checkpoint-before-user."
970
+ )
971
+ self._adapter.restore_session_checkpoint(
972
+ cache_id=request.session_id,
973
+ checkpoint=fallback_checkpoint,
974
+ )
975
+ commit_from_pos = int(fallback_commit_from_pos)
976
+ else:
977
+ raise ValueError(
978
+ "Persisted prompt no longer matches the end-of-user prefix."
979
+ )
980
+
981
+ try:
982
+ delta_commit_ids = ids_persisted[:, int(commit_from_pos) :]
983
+ except Exception:
984
+ delta_commit_ids = ids_persisted
985
+
986
+ if getattr(delta_commit_ids, "numel", lambda: 0)() > 0:
987
+ self._adapter.append_to_session(
988
+ cache_id=request.session_id,
989
+ input_ids=delta_commit_ids,
990
+ )
991
+
992
+ # discard_thinking=False session path: the model often stops *before* emitting <|im_end|>,
993
+ # but apply_chat_template(add_generation_prompt=False) will include it for assistant messages.
994
+ # To keep KV/history in sync, append any missing tail tokens after generation.
995
+ if request.session_id and enable_thinking and not discard_thinking and not assistant_tool_calls:
996
+ # Build normalized raw content that matches the template's generation prompt prefix.
997
+ normalized_raw_content_for_history = (
998
+ "".join(assistant_raw_text_parts) if assistant_raw_text_parts else None
999
+ )
1000
+ if normalized_raw_content_for_history:
1001
+ if (
1002
+ _ModelOutputParser._THINK_CLOSE in normalized_raw_content_for_history
1003
+ and _ModelOutputParser._THINK_OPEN not in normalized_raw_content_for_history[:64]
1004
+ ):
1005
+ normalized_raw_content_for_history = (
1006
+ _ModelOutputParser._THINK_OPEN + "\n" + normalized_raw_content_for_history
1007
+ )
1008
+
1009
+ assistant_msg = ChatMessage(
1010
+ role="assistant",
1011
+ content=normalized_raw_content_for_history,
1012
+ )
1013
+ persisted_messages = self._inject_tool_choice(
1014
+ [*list(request.messages), assistant_msg],
1015
+ request.tool_choice,
1016
+ )
1017
+ template_persisted = self._messages_for_template(persisted_messages)
1018
+ ids_persisted = apply_chat_template(
1019
+ template_persisted,
1020
+ add_generation_prompt=False,
1021
+ **kwargs,
1022
+ )
1023
+
1024
+ sess_info = self._adapter.get_session_info(request.session_id)
1025
+ cur = int(sess_info.get("current_pos", 0))
1026
+ try:
1027
+ expected_total = int(getattr(ids_persisted, "shape")[1])
1028
+ except Exception:
1029
+ expected_total = cur
1030
+
1031
+ if expected_total > cur:
1032
+ try:
1033
+ delta_tail_ids = ids_persisted[:, cur:expected_total]
1034
+ except Exception:
1035
+ delta_tail_ids = ids_persisted
1036
+ if getattr(delta_tail_ids, "numel", lambda: 0)() > 0:
1037
+ self._adapter.append_to_session(
1038
+ cache_id=request.session_id,
1039
+ input_ids=delta_tail_ids,
1040
+ )
1041
+
1042
+ # If we exhausted the token budget without an explicit stop/tool call/cancel,
1043
+ # report a length stop (best-effort; adapter doesn't expose stop reason).
1044
+ if (
1045
+ finish_reason == "stop"
1046
+ and not parser.stopped
1047
+ and completion_tokens >= int(request.max_tokens)
1048
+ ):
1049
+ finish_reason = "length"
1050
+
1051
+ except Exception as exc:
1052
+ finish_reason = "error"
1053
+ loop.call_soon_threadsafe(queue.put_nowait, ErrorEvent(f"Generation failed: {exc}"))
1054
+ finally:
1055
+ ended = time.monotonic()
1056
+ prefill_s = None if first_token_at is None else max(first_token_at - started, 0.0)
1057
+ decode_s = None
1058
+ if first_token_at is not None:
1059
+ decode_s = max(ended - first_token_at, 0.0)
1060
+
1061
+ tok_per_s = None
1062
+ if decode_s and decode_s > 0 and completion_tokens > 0:
1063
+ tok_per_s = completion_tokens / decode_s
1064
+
1065
+ # Include raw content (with thinking) when discard_thinking=False for sessions
1066
+ raw_content = None
1067
+ if request.session_id and enable_thinking and not discard_thinking:
1068
+ raw_content = normalized_raw_content_for_history
1069
+ if raw_content is None:
1070
+ raw_content = "".join(assistant_raw_text_parts) if assistant_raw_text_parts else None
1071
+ # When enable_thinking=True, the generation prompt already includes <think>.
1072
+ # The model output stream may therefore omit the opening tag and begin directly
1073
+ # with the thinking text, later emitting only </think>. For history/KV sync,
1074
+ # normalize by re-introducing the opening tag when needed.
1075
+ if raw_content:
1076
+ if (
1077
+ _ModelOutputParser._THINK_CLOSE in raw_content
1078
+ and _ModelOutputParser._THINK_OPEN not in raw_content[:64]
1079
+ ):
1080
+ raw_content = _ModelOutputParser._THINK_OPEN + "\n" + raw_content
1081
+
1082
+ final = FinalEvent(
1083
+ finish_reason=finish_reason
1084
+ if finish_reason in {"stop", "length", "tool_calls", "cancelled", "error", "repetition"}
1085
+ else "stop",
1086
+ usage=Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
1087
+ timing=Timing(
1088
+ prefill_s=prefill_s,
1089
+ decode_s=decode_s,
1090
+ total_s=max(ended - started, 0.0),
1091
+ tok_per_s=tok_per_s,
1092
+ ),
1093
+ raw_content=raw_content,
1094
+ )
1095
+ loop.call_soon_threadsafe(queue.put_nowait, final)
1096
+ loop.call_soon_threadsafe(queue.put_nowait, None)
1097
+
1098
+ def _flush_token_buffer(
1099
+ token_ids: list[int],
1100
+ parser_: _ModelOutputParser,
1101
+ emit: Any,
1102
+ raw_text_parts: list[str] | None = None,
1103
+ ) -> None:
1104
+ tokenizer = getattr(self._adapter, "tokenizer", None)
1105
+ decode = getattr(tokenizer, "decode", None)
1106
+ if not callable(decode):
1107
+ raise RuntimeError("Tokenizer does not support decode().")
1108
+
1109
+ text = decode(token_ids, skip_special_tokens=False)
1110
+ if raw_text_parts is not None:
1111
+ raw_text_parts.append(text)
1112
+ for event in parser_.feed(text):
1113
+ emit(event)
1114
+
1115
+ thread = threading.Thread(target=worker, name=f"superlinear-gen-{uuid.uuid4().hex}", daemon=True)
1116
+ thread.start()
1117
+
1118
+ try:
1119
+ while True:
1120
+ event = await queue.get()
1121
+ if event is None:
1122
+ break
1123
+ yield event
1124
+ except asyncio.CancelledError:
1125
+ cancel.set()
1126
+ raise
1127
+ finally:
1128
+ # If the consumer stops early (disconnect / generator close), cancel generation promptly.
1129
+ cancel.set()
1130
+
1131
+ async def generate_chat(self, request: ChatRequest) -> dict[str, Any]:
1132
+ """Non-streaming chat completion.
1133
+
1134
+ Returns:
1135
+ Dict containing:
1136
+ - content: str | None
1137
+ - tool_calls: list[ToolCall]
1138
+ - finish_reason: str
1139
+ - usage: Usage
1140
+ - timing: Timing
1141
+ - raw_content: str | None (if discard_thinking=False)
1142
+ """
1143
+ content_parts: list[str] = []
1144
+ tool_calls: list[ToolCall] = []
1145
+ usage: Usage | None = None
1146
+ timing: Timing | None = None
1147
+ finish_reason = "stop"
1148
+ raw_content: str | None = None
1149
+
1150
+ async for event in self.astream_chat(request):
1151
+ if isinstance(event, DeltaEvent):
1152
+ content_parts.append(event.text)
1153
+ elif isinstance(event, ToolCallEvent):
1154
+ tool_calls = event.tool_calls
1155
+ finish_reason = "tool_calls"
1156
+ elif isinstance(event, FinalEvent):
1157
+ usage = event.usage
1158
+ timing = event.timing
1159
+ raw_content = event.raw_content
1160
+ # Don't override tool_calls finish reason
1161
+ if finish_reason != "tool_calls":
1162
+ finish_reason = event.finish_reason
1163
+ elif isinstance(event, ErrorEvent):
1164
+ raise RuntimeError(event.message)
1165
+
1166
+ return {
1167
+ "content": "".join(content_parts) if content_parts else None,
1168
+ "tool_calls": tool_calls,
1169
+ "finish_reason": finish_reason,
1170
+ "usage": usage,
1171
+ "timing": timing,
1172
+ "raw_content": raw_content,
1173
+ }
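# --- Editor's illustrative sketch (not part of the package) --------------------
# Minimal consumption pattern for the streaming API, assuming `engine` wraps a
# loaded adapter and `request` is a fully populated ChatRequest built elsewhere.
async def _demo_stream(engine: ChatEngine, request: ChatRequest) -> None:
    async for event in engine.astream_chat(request):
        if isinstance(event, DeltaEvent):
            print(event.text, end="", flush=True)
        elif isinstance(event, ThinkingDeltaEvent):
            pass  # only emitted when request.stream_thinking is set
        elif isinstance(event, ToolCallEvent):
            print("\n[tool calls]", [tc.name for tc in event.tool_calls])
        elif isinstance(event, ErrorEvent):
            print("\n[error]", event.message)
        elif isinstance(event, FinalEvent):
            print(f"\n[{event.finish_reason}]", event.usage)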