caudate-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- api/__init__.py +5 -0
- api/anthropic_compat.py +1518 -0
- api/artifact_viewer.py +366 -0
- api/caudate_middleware.py +618 -0
- api/forge_bootstrapper_routes.py +377 -0
- api/forge_routes.py +630 -0
- api/forge_system_routes.py +294 -0
- api/openai_compat.py +1993 -0
- api/server.py +667 -0
- api/storyboard_page.py +677 -0
- caudate_cli-0.1.0.dist-info/METADATA +354 -0
- caudate_cli-0.1.0.dist-info/RECORD +153 -0
- caudate_cli-0.1.0.dist-info/WHEEL +5 -0
- caudate_cli-0.1.0.dist-info/entry_points.txt +2 -0
- caudate_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- caudate_cli-0.1.0.dist-info/top_level.txt +14 -0
- cognos_mcp/__init__.py +4 -0
- cognos_mcp/bridge.py +41 -0
- cognos_mcp/client.py +70 -0
- cognos_mcp/config.py +49 -0
- cognos_mcp/server.py +66 -0
- config.py +82 -0
- core/__init__.py +0 -0
- core/agent.py +468 -0
- core/agentic_loop.py +731 -0
- core/anthropic_auth.py +91 -0
- core/background.py +113 -0
- core/banner.py +134 -0
- core/bootstrap.py +292 -0
- core/citations.py +131 -0
- core/compaction.py +109 -0
- core/constitution.py +198 -0
- core/diff_viewer.py +87 -0
- core/export.py +85 -0
- core/file_refs.py +119 -0
- core/files.py +199 -0
- core/hooks.py +209 -0
- core/image.py +599 -0
- core/input.py +91 -0
- core/loop.py +238 -0
- core/memory_md.py +147 -0
- core/notifications.py +99 -0
- core/ownership.py +181 -0
- core/paste.py +81 -0
- core/permissions.py +210 -0
- core/plan_mode.py +215 -0
- core/sandbox_prompt.py +185 -0
- core/scheduler.py +195 -0
- core/schemas.py +202 -0
- core/session.py +90 -0
- core/settings.py +132 -0
- core/skills.py +398 -0
- core/slash_commands.py +977 -0
- core/statusline.py +61 -0
- core/subagent.py +300 -0
- core/thinking.py +50 -0
- core/updater.py +122 -0
- core/usage.py +109 -0
- core/worktree.py +93 -0
- execution/__init__.py +0 -0
- execution/executor.py +329 -0
- execution/plugins.py +108 -0
- execution/tools/__init__.py +0 -0
- execution/tools/agent_tool.py +107 -0
- execution/tools/agentic_tool.py +297 -0
- execution/tools/artifact_tool.py +191 -0
- execution/tools/ask_user_question_tool.py +137 -0
- execution/tools/base.py +81 -0
- execution/tools/calculator_tool.py +137 -0
- execution/tools/cognos_card_tool.py +124 -0
- execution/tools/cron_tool.py +215 -0
- execution/tools/datetime_tool.py +215 -0
- execution/tools/describe_image_tool.py +161 -0
- execution/tools/draw_tool.py +164 -0
- execution/tools/edit_image_tool.py +262 -0
- execution/tools/edit_tool.py +245 -0
- execution/tools/file_tool.py +90 -0
- execution/tools/find_anywhere_tool.py +255 -0
- execution/tools/forge_feature_tools.py +377 -0
- execution/tools/glob_tool.py +59 -0
- execution/tools/grep_tool.py +89 -0
- execution/tools/http_request_tool.py +224 -0
- execution/tools/load_skill_tool.py +104 -0
- execution/tools/longcat_avatar_tool.py +384 -0
- execution/tools/mcp_tool.py +100 -0
- execution/tools/notebook_tool.py +279 -0
- execution/tools/openapi_tool.py +440 -0
- execution/tools/plan_mode_tool.py +95 -0
- execution/tools/push_notification_tool.py +157 -0
- execution/tools/python_tool.py +61 -0
- execution/tools/respond_tool.py +40 -0
- execution/tools/sandbox_tool.py +378 -0
- execution/tools/search_tool.py +153 -0
- execution/tools/semantic_search_tool.py +106 -0
- execution/tools/shell_tool.py +283 -0
- execution/tools/speak_tool.py +134 -0
- execution/tools/storyboard_tool.py +727 -0
- execution/tools/system_info_tool.py +212 -0
- execution/tools/task_tool.py +323 -0
- execution/tools/think_tool.py +49 -0
- execution/tools/transcribe_audio_tool.py +86 -0
- execution/tools/update_memory_tool.py +92 -0
- execution/tools/web_fetch_tool.py +82 -0
- execution/tools/worktree_tool.py +174 -0
- llm/__init__.py +0 -0
- llm/fallback.py +116 -0
- llm/models.py +320 -0
- llm/provider.py +1356 -0
- llm/router.py +373 -0
- main.py +1889 -0
- memory/__init__.py +0 -0
- memory/episodic.py +99 -0
- memory/procedural.py +145 -0
- memory/semantic.py +71 -0
- memory/working.py +64 -0
- nn/__init__.py +43 -0
- nn/auto_evolve.py +245 -0
- nn/caudate.py +136 -0
- nn/config.py +141 -0
- nn/consolidator.py +81 -0
- nn/data.py +1635 -0
- nn/encoder.py +258 -0
- nn/forge_advisor.py +303 -0
- nn/format.py +235 -0
- nn/heads.py +432 -0
- nn/observer.py +994 -0
- nn/policy.py +214 -0
- nn/runtime.py +343 -0
- nn/scorer.py +175 -0
- nn/trainer.py +515 -0
- nn/vision.py +352 -0
- personality/__init__.py +23 -0
- personality/engine.py +129 -0
- personality/identity.py +144 -0
- personality/inner_voice.py +100 -0
- personality/mood.py +205 -0
- planning/__init__.py +0 -0
- planning/dev_server.py +221 -0
- planning/forge_models.py +718 -0
- planning/orchestrator.py +1363 -0
- planning/planner.py +451 -0
- planning/task_graph.py +61 -0
- reflection/__init__.py +0 -0
- reflection/meta_learner.py +156 -0
- reflection/reflector.py +127 -0
- ui/__init__.py +5 -0
- ui/display.py +88 -0
- voice/__init__.py +0 -0
- voice/conversation.py +125 -0
- voice/listener.py +111 -0
- voice/speaker.py +59 -0
- voice/stt.py +126 -0
- voice/tts.py +214 -0
llm/provider.py
ADDED
|
@@ -0,0 +1,1356 @@
|
|
|
1
|
+
"""Model-agnostic LLM provider using LiteLLM.
|
|
2
|
+
|
|
3
|
+
Supports:
|
|
4
|
+
- Plain chat/completion
|
|
5
|
+
- Tool calling (native + prompt-based fallback)
|
|
6
|
+
- Streaming (async generator of StreamEvents)
|
|
7
|
+
- Structured JSON output
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import random
|
|
16
|
+
import re
|
|
17
|
+
import uuid
|
|
18
|
+
from typing import Any, AsyncIterator, Awaitable, Callable, TypeVar
|
|
19
|
+
|
|
20
|
+
import litellm
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
|
|
23
|
+
from config import LLM_MODEL, LLM_TEMPERATURE, LLM_MAX_TOKENS, PROMPT_CACHING
|
|
24
|
+
from core.schemas import StreamEvent, ToolUseBlock
|
|
25
|
+
from core.usage import get_global_tracker
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
T = TypeVar("T")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Retry policy for LLM calls — transient failures (network blips, model
|
|
33
|
+
# warmup stalls) should auto-retry with backoff. Pathological failures
|
|
34
|
+
# (bad model id, auth) are NOT retried — we raise immediately.
|
|
35
|
+
|
|
36
|
+
DEFAULT_MAX_RETRIES = 3
|
|
37
|
+
DEFAULT_INITIAL_BACKOFF_S = 1.0
|
|
38
|
+
DEFAULT_BACKOFF_MULTIPLIER = 2.5
|
|
39
|
+
DEFAULT_JITTER_S = 0.3
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Errors that are worth retrying. Keep this narrow: anything model-side
|
|
43
|
+
# (wrong model name, invalid request, context-length overflow) won't fix
|
|
44
|
+
# itself by retrying.
|
|
45
|
+
_RETRYABLE_SUBSTRINGS = (
|
|
46
|
+
"timeout", "timed out", "connection", "econnreset", "read error",
|
|
47
|
+
"temporarily", "service unavailable", "503", "502", "500",
|
|
48
|
+
"rate limit", "429", "overloaded", "busy",
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _is_retryable(err: BaseException) -> bool:
|
|
53
|
+
"""True if the error is transient and worth retrying."""
|
|
54
|
+
if isinstance(err, (asyncio.TimeoutError, ConnectionError, TimeoutError)):
|
|
55
|
+
return True
|
|
56
|
+
message = str(err).lower()
|
|
57
|
+
return any(tok in message for tok in _RETRYABLE_SUBSTRINGS)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
async def _with_retry(
    op: Callable[[], Awaitable[T]],
    label: str,
    max_retries: int = DEFAULT_MAX_RETRIES,
) -> T:
    """Run an async op with exponential-backoff retry on transient errors.

    Args:
        op: Zero-argument callable returning a fresh awaitable per attempt.
        label: Human-readable name used in the retry warning log line.
        max_retries: Number of retries after the first attempt. Values
            below zero are treated as zero, so ``op`` always runs once.

    Returns:
        Whatever ``op`` returns on the first successful attempt.

    Raises:
        Exception: The original exception from ``op`` when it is not
            retryable (per ``_is_retryable``) or the retry budget is spent.
    """
    # A negative budget would skip the loop entirely and leave nothing to
    # return or raise; clamp so at least one attempt is made.
    max_retries = max(0, max_retries)
    delay = DEFAULT_INITIAL_BACKOFF_S
    for attempt in range(max_retries + 1):
        try:
            return await op()
        except Exception as e:
            # Re-raise in place (keeps the traceback) rather than stashing
            # the error and relying on an `assert`, which `-O` strips.
            if not _is_retryable(e) or attempt >= max_retries:
                raise
            pause = delay + random.uniform(0, DEFAULT_JITTER_S)
            logger.warning(
                f"{label} failed (attempt {attempt + 1}/{max_retries + 1}): "
                f"{e} — retrying in {pause:.1f}s"
            )
            await asyncio.sleep(pause)
            delay *= DEFAULT_BACKOFF_MULTIPLIER
    # Unreachable: the final iteration always returns or raises.
    raise RuntimeError(f"{label}: retry loop exited without a result")
|
|
84
|
+
|
|
85
|
+
# Suppress litellm noise
|
|
86
|
+
litellm.suppress_debug_info = True
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class LLMResponse(BaseModel):
    # Normalized result of a single LLM call, returned by the provider's
    # chat paths regardless of which backend produced it.

    # Assistant text of the reply.
    content: str
    # Backend response as a plain dict (the LiteLLM model dump or the
    # Anthropic JSON body).
    raw: dict[str, Any] = {}
    # Model id reported by the backend, or the configured id as fallback.
    model: str = ""
    # Token counts keyed by prompt_tokens / completion_tokens / total_tokens.
    usage: dict[str, int] = {}
    # Backend finish/stop reason, when one was reported.
    stop_reason: str | None = None
    # Tool invocations requested by the model (native or prompt-parsed).
    tool_calls: list[ToolUseBlock] = []
    # Thinking-model output (gemma4, kimi, deepseek, etc). Most models
    # leave this empty; thinking models surface their internal reasoning
    # here. Cognos passes it through to Anthropic-compat clients as a
    # `thinking` content block so Claude Code can render it.
    thinking: str = ""
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Models known to lack native tool-calling — fall back to prompt-based
|
|
104
|
+
# protocol.
|
|
105
|
+
_NO_NATIVE_TOOLS = {
|
|
106
|
+
"gemma", "gemma2", "gemma3", "gemma4",
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _has_native_tool_support(model: str) -> bool:
|
|
111
|
+
"""Heuristic: does this model support native tool-calling via LiteLLM?"""
|
|
112
|
+
m = model.lower()
|
|
113
|
+
for prefix in _NO_NATIVE_TOOLS:
|
|
114
|
+
if prefix in m:
|
|
115
|
+
return False
|
|
116
|
+
return True
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
_KNOWN_NO_JSON_MODE = {"gemma", "gemma2"} # older gemma rejects response_format
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _has_json_mode_support(model: str) -> bool:
|
|
123
|
+
m = model.lower()
|
|
124
|
+
return not any(prefix in m for prefix in _KNOWN_NO_JSON_MODE)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _supports_prompt_caching(model: str) -> bool:
|
|
128
|
+
"""Only Anthropic models honor cache_control today. Other providers ignore it."""
|
|
129
|
+
return "claude" in model.lower() or model.lower().startswith("anthropic/")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _apply_cache_control(
    messages: list[dict[str, Any]],
    model: str,
) -> list[dict[str, Any]]:
    """Mark a long leading system prompt as an ephemeral cache breakpoint.

    The input list is returned untouched unless all of the following hold:
    ``PROMPT_CACHING`` is enabled, the model passes
    ``_supports_prompt_caching``, the first message has role ``system``,
    and its content is a plain string of at least 4000 characters (the
    original author's approximation of Anthropic's 1024-token minimum).

    When they do hold, a new list is returned whose first message carries
    the same text as a single content block tagged with
    ``cache_control: {"type": "ephemeral"}``; the input list itself is not
    mutated.
    """
    caching_enabled = PROMPT_CACHING and _supports_prompt_caching(model)
    if not caching_enabled:
        return messages
    if not messages:
        return messages

    head = messages[0]
    if head.get("role") != "system":
        return messages

    system_text = head.get("content")
    if not isinstance(system_text, str):
        return messages
    if len(system_text) < 4000:
        return messages

    cached_block = {
        "type": "text",
        "text": system_text,
        "cache_control": {"type": "ephemeral"},
    }
    return [{"role": "system", "content": [cached_block]}, *messages[1:]]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class LLMProvider:
|
|
168
|
+
"""Unified LLM interface supporting local (Ollama) and cloud (Claude, OpenAI) models."""
|
|
169
|
+
|
|
170
|
+
def __init__(
    self,
    model: str = LLM_MODEL,
    temperature: float = LLM_TEMPERATURE,
    max_tokens: int = LLM_MAX_TOKENS,
):
    """Remember the model id and the default sampling settings.

    Per-call ``temperature`` / ``max_tokens`` arguments on the request
    methods take precedence over the values stored here.
    """
    self.model, self.temperature, self.max_tokens = (
        model,
        temperature,
        max_tokens,
    )
|
|
179
|
+
|
|
180
|
+
# ------------------------------------------------------------------
|
|
181
|
+
# Basic completion / chat
|
|
182
|
+
# ------------------------------------------------------------------
|
|
183
|
+
|
|
184
|
+
async def complete(
    self,
    prompt: str,
    system: str | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    response_format: dict | None = None,
    caller: str | None = None,
) -> LLMResponse:
    """Send a single-turn completion request.

    Builds a one- or two-message conversation (optional system prompt
    followed by the user prompt) and delegates to ``chat``.

    Args:
        prompt: User message text.
        system: Optional system prompt; omitted when empty or None.
        temperature: Per-call override, passed through to ``chat``.
        max_tokens: Per-call override, passed through to ``chat``.
        response_format: Optional response-format dict, passed through.
        caller: Optional caller tag, passed through to ``chat``.

    Returns:
        The ``LLMResponse`` produced by ``chat``.
    """
    messages: list[dict[str, Any]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    return await self.chat(
        messages,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        # Previously dropped here even though chat() accepts it.
        caller=caller,
    )
|
|
204
|
+
|
|
205
|
+
def _route_model_for_litellm(self) -> str:
|
|
206
|
+
"""LiteLLM uses different prefixes for different Ollama endpoints:
|
|
207
|
+
|
|
208
|
+
- `ollama/<name>` → /v1/chat/completions (OpenAI-compat;
|
|
209
|
+
drops Ollama-specific fields like thinking)
|
|
210
|
+
- `ollama_chat/<name>` → /api/chat (native; preserves thinking
|
|
211
|
+
as `reasoning_content`)
|
|
212
|
+
|
|
213
|
+
We always want the native path so thinking-model output reaches
|
|
214
|
+
the agent and downstream Anthropic-compat clients.
|
|
215
|
+
"""
|
|
216
|
+
if self.model.startswith("ollama/") and not self.model.startswith("ollama_chat/"):
|
|
217
|
+
return "ollama_chat/" + self.model[len("ollama/"):]
|
|
218
|
+
return self.model
|
|
219
|
+
|
|
220
|
+
def _should_use_anthropic_subscription(self) -> bool:
|
|
221
|
+
"""True iff the request should bypass LiteLLM and call
|
|
222
|
+
api.anthropic.com directly with the Claude-Code subscription
|
|
223
|
+
OAuth token. We bypass LiteLLM because its anthropic provider
|
|
224
|
+
sends `x-api-key` (not Bearer); subscription tokens require
|
|
225
|
+
the Bearer header instead.
|
|
226
|
+
|
|
227
|
+
Activation requires *all* of:
|
|
228
|
+
- configured model is an anthropic/* id
|
|
229
|
+
- the calling code has entered `subscription_auth_scope()`
|
|
230
|
+
(only the web-UI /chat and /chat/stream endpoints do)
|
|
231
|
+
- the credentials file is readable
|
|
232
|
+
"""
|
|
233
|
+
if not self.model.startswith("anthropic/"):
|
|
234
|
+
return False
|
|
235
|
+
try:
|
|
236
|
+
from core.anthropic_auth import is_active, read_subscription_token
|
|
237
|
+
except Exception:
|
|
238
|
+
return False
|
|
239
|
+
if not is_active():
|
|
240
|
+
return False
|
|
241
|
+
return read_subscription_token() is not None
|
|
242
|
+
|
|
243
|
+
def _anthropic_subscription_headers(self) -> dict[str, str]:
    """Build the HTTP headers for a direct subscription call.

    The Bearer value comes from ``read_subscription_token()``; when that
    returns nothing the header is still emitted with an empty token.
    """
    from core.anthropic_auth import read_subscription_token

    bearer = read_subscription_token()
    if not bearer:
        bearer = ""
    headers: dict[str, str] = {"Authorization": f"Bearer {bearer}"}
    headers["anthropic-version"] = "2023-06-01"
    headers["anthropic-beta"] = "claude-code-20250219,oauth-2025-04-20"
    headers["content-type"] = "application/json"
    return headers
|
|
252
|
+
|
|
253
|
+
def _build_anthropic_body(
|
|
254
|
+
self,
|
|
255
|
+
messages: list[dict[str, Any]],
|
|
256
|
+
tools: list[dict] | None,
|
|
257
|
+
max_tokens: int | None,
|
|
258
|
+
temperature: float | None,
|
|
259
|
+
stream: bool,
|
|
260
|
+
) -> dict[str, Any]:
|
|
261
|
+
"""Translate Cognos's internal OpenAI-shape messages into
|
|
262
|
+
Anthropic's /v1/messages body."""
|
|
263
|
+
# Strip the "anthropic/" prefix to get the bare model id.
|
|
264
|
+
model = self.model.split("/", 1)[1]
|
|
265
|
+
body: dict[str, Any] = {
|
|
266
|
+
"model": model,
|
|
267
|
+
"max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
|
|
268
|
+
"stream": stream,
|
|
269
|
+
}
|
|
270
|
+
if temperature is not None:
|
|
271
|
+
body["temperature"] = temperature
|
|
272
|
+
|
|
273
|
+
system_parts: list[str] = []
|
|
274
|
+
out_msgs: list[dict[str, Any]] = []
|
|
275
|
+
|
|
276
|
+
for m in messages:
|
|
277
|
+
role = m.get("role")
|
|
278
|
+
content = m.get("content")
|
|
279
|
+
if role == "system":
|
|
280
|
+
if isinstance(content, str):
|
|
281
|
+
system_parts.append(content)
|
|
282
|
+
elif isinstance(content, list):
|
|
283
|
+
for b in content:
|
|
284
|
+
if isinstance(b, dict) and b.get("type") == "text":
|
|
285
|
+
system_parts.append(b.get("text", ""))
|
|
286
|
+
continue
|
|
287
|
+
if role == "tool":
|
|
288
|
+
# OpenAI tool-result → Anthropic user msg w/ tool_result block.
|
|
289
|
+
out_msgs.append({
|
|
290
|
+
"role": "user",
|
|
291
|
+
"content": [{
|
|
292
|
+
"type": "tool_result",
|
|
293
|
+
"tool_use_id": m.get("tool_call_id", ""),
|
|
294
|
+
"content": str(content or ""),
|
|
295
|
+
}],
|
|
296
|
+
})
|
|
297
|
+
continue
|
|
298
|
+
if role == "assistant":
|
|
299
|
+
blocks: list[dict[str, Any]] = []
|
|
300
|
+
if isinstance(content, str) and content:
|
|
301
|
+
blocks.append({"type": "text", "text": content})
|
|
302
|
+
elif isinstance(content, list):
|
|
303
|
+
for b in content:
|
|
304
|
+
if isinstance(b, dict) and b.get("type") == "text":
|
|
305
|
+
blocks.append({"type": "text", "text": b.get("text", "")})
|
|
306
|
+
for tc in m.get("tool_calls") or []:
|
|
307
|
+
fn = tc.get("function") or {}
|
|
308
|
+
raw_args = fn.get("arguments")
|
|
309
|
+
# Ollama/Kimi sometimes send arguments as a dict
|
|
310
|
+
# already (their tool-call format diverges from
|
|
311
|
+
# OpenAI's "arguments must be a JSON string" rule).
|
|
312
|
+
if isinstance(raw_args, dict):
|
|
313
|
+
args = raw_args
|
|
314
|
+
else:
|
|
315
|
+
try:
|
|
316
|
+
args = json.loads(raw_args or "{}")
|
|
317
|
+
except Exception:
|
|
318
|
+
args = {}
|
|
319
|
+
blocks.append({
|
|
320
|
+
"type": "tool_use",
|
|
321
|
+
"id": tc.get("id") or f"toolu_{uuid.uuid4().hex[:12]}",
|
|
322
|
+
"name": fn.get("name", ""),
|
|
323
|
+
"input": args,
|
|
324
|
+
})
|
|
325
|
+
if blocks:
|
|
326
|
+
out_msgs.append({"role": "assistant", "content": blocks})
|
|
327
|
+
continue
|
|
328
|
+
# role == "user"
|
|
329
|
+
if isinstance(content, str):
|
|
330
|
+
out_msgs.append({"role": "user", "content": content})
|
|
331
|
+
elif isinstance(content, list):
|
|
332
|
+
blocks = []
|
|
333
|
+
for b in content:
|
|
334
|
+
if not isinstance(b, dict):
|
|
335
|
+
continue
|
|
336
|
+
if b.get("type") == "text":
|
|
337
|
+
blocks.append({"type": "text", "text": b.get("text", "")})
|
|
338
|
+
elif b.get("type") == "image_url":
|
|
339
|
+
url = (b.get("image_url") or {}).get("url", "")
|
|
340
|
+
if url.startswith("data:"):
|
|
341
|
+
head, _, b64 = url.partition(",")
|
|
342
|
+
media = head.split(";")[0].split(":")[-1] or "image/png"
|
|
343
|
+
blocks.append({
|
|
344
|
+
"type": "image",
|
|
345
|
+
"source": {"type": "base64",
|
|
346
|
+
"media_type": media, "data": b64},
|
|
347
|
+
})
|
|
348
|
+
out_msgs.append({
|
|
349
|
+
"role": "user",
|
|
350
|
+
"content": blocks or [{"type": "text", "text": ""}],
|
|
351
|
+
})
|
|
352
|
+
|
|
353
|
+
if system_parts:
|
|
354
|
+
body["system"] = "\n\n".join(p for p in system_parts if p)
|
|
355
|
+
body["messages"] = out_msgs
|
|
356
|
+
|
|
357
|
+
if tools:
|
|
358
|
+
body["tools"] = []
|
|
359
|
+
for t in tools:
|
|
360
|
+
fn = t.get("function") or t
|
|
361
|
+
body["tools"].append({
|
|
362
|
+
"name": fn.get("name", ""),
|
|
363
|
+
"description": fn.get("description", ""),
|
|
364
|
+
"input_schema": fn.get("parameters") or {
|
|
365
|
+
"type": "object", "properties": {},
|
|
366
|
+
},
|
|
367
|
+
})
|
|
368
|
+
return body
|
|
369
|
+
|
|
370
|
+
async def _call_anthropic_subscription_chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict] | None,
    max_tokens: int | None,
    temperature: float | None,
) -> LLMResponse:
    """POST one non-streaming request to api.anthropic.com with the OAuth Bearer.

    The JSON reply is folded into an ``LLMResponse``: ``text`` blocks are
    concatenated into ``content``, ``thinking`` blocks into ``thinking``,
    and ``tool_use`` blocks become ``ToolUseBlock`` entries. Token usage
    is recorded on the global tracker under ``self.model``.

    Raises:
        RuntimeError: if the HTTP status is 400 or above; the message
            carries the status code and response text.
    """
    import httpx

    payload = self._build_anthropic_body(messages, tools, max_tokens, temperature, stream=False)
    auth_headers = self._anthropic_subscription_headers()
    limits = httpx.Timeout(300.0, connect=15.0)
    async with httpx.AsyncClient(timeout=limits) as client:
        resp = await client.post(
            "https://api.anthropic.com/v1/messages",
            headers=auth_headers, json=payload,
        )
        if resp.status_code >= 400:
            raise RuntimeError(
                f"Anthropic subscription call failed: {resp.status_code} {resp.text}"
            )
        data = resp.json()

    text_parts: list[str] = []
    thinking_parts: list[str] = []
    calls: list[ToolUseBlock] = []
    for block in data.get("content") or []:
        kind = block.get("type")
        if kind == "text":
            text_parts.append(block.get("text", ""))
        elif kind == "thinking":
            thinking_parts.append(block.get("thinking", ""))
        elif kind == "tool_use":
            calls.append(ToolUseBlock(
                id=block.get("id") or str(uuid.uuid4()),
                name=block.get("name", ""),
                input=block.get("input") or {},
            ))

    # Translate Anthropic's token counters to the OpenAI-style keys used
    # by the rest of the provider.
    reported = data.get("usage") or {}
    prompt_tokens = int(reported.get("input_tokens", 0))
    completion_tokens = int(reported.get("output_tokens", 0))
    usage_dict = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    get_global_tracker().record(self.model, usage_dict)

    return LLMResponse(
        content="".join(text_parts),
        raw=data,
        model=data.get("model") or self.model,
        usage=usage_dict,
        stop_reason=data.get("stop_reason"),
        tool_calls=calls,
        thinking="".join(thinking_parts),
    )
|
|
425
|
+
|
|
426
|
+
async def _call_anthropic_subscription_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict] | None,
    max_tokens: int | None,
    temperature: float | None,
) -> AsyncIterator[StreamEvent]:
    """Stream a direct call to api.anthropic.com as internal StreamEvents.

    Emits ``message_start`` first, then ``text_delta`` / ``thinking_delta``
    events as they arrive, one ``tool_use_end`` per completed tool_use
    block (with its accumulated JSON input parsed), and finally
    ``message_stop`` carrying the last reported stop reason.

    Events whose data payload is not valid JSON are skipped. A tool input
    that fails to parse is delivered as ``{"_raw": <text>}``.

    Raises:
        RuntimeError: if the HTTP status is 400 or above; the message
            carries the status code and decoded response body.
    """
    import codecs

    import httpx

    body = self._build_anthropic_body(messages, tools, max_tokens, temperature, stream=True)
    headers = self._anthropic_subscription_headers()

    yield StreamEvent(type="message_start")

    block_types: dict[int, str] = {}
    block_tool_ids: dict[int, str] = {}
    block_tool_names: dict[int, str] = {}
    block_tool_inputs: dict[int, str] = {}
    stop_reason: str | None = None
    pending = ""
    # Network chunks can end in the middle of a multi-byte UTF-8 sequence.
    # An incremental decoder carries the partial bytes over to the next
    # chunk; decoding each chunk independently would turn such characters
    # into U+FFFD replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=15.0)) as client:
        async with client.stream(
            "POST", "https://api.anthropic.com/v1/messages",
            headers=headers, json=body,
        ) as resp:
            if resp.status_code >= 400:
                body_bytes = await resp.aread()
                raise RuntimeError(
                    f"Anthropic subscription stream failed: "
                    f"{resp.status_code} {body_bytes.decode('utf-8', 'replace')}"
                )
            async for chunk in resp.aiter_bytes():
                if not chunk:
                    continue
                pending += decoder.decode(chunk)
                # SSE events are separated by a blank line; anything after
                # the last separator is an incomplete event and stays
                # buffered until more bytes arrive.
                while "\n\n" in pending:
                    raw_event, pending = pending.split("\n\n", 1)
                    if not raw_event.strip():
                        continue
                    evt_type = None
                    data_lines: list[str] = []
                    for line in raw_event.splitlines():
                        if line.startswith("event:"):
                            evt_type = line[6:].strip()
                        elif line.startswith("data:"):
                            data_lines.append(line[5:].strip())
                    if not data_lines:
                        continue
                    try:
                        data = json.loads("\n".join(data_lines))
                    except Exception:
                        continue
                    if evt_type == "content_block_start":
                        idx = int(data.get("index", -1))
                        cb = data.get("content_block") or {}
                        block_types[idx] = cb.get("type", "")
                        if cb.get("type") == "tool_use":
                            # Keep the server-issued id so the streaming
                            # path reports the same id the non-streaming
                            # path would.
                            block_tool_ids[idx] = cb.get("id") or ""
                            block_tool_names[idx] = cb.get("name", "")
                            block_tool_inputs[idx] = ""
                    elif evt_type == "content_block_delta":
                        delta = data.get("delta") or {}
                        dtype = delta.get("type")
                        if dtype == "text_delta":
                            yield StreamEvent(type="text_delta", delta=delta.get("text", ""))
                        elif dtype == "thinking_delta":
                            yield StreamEvent(type="thinking_delta", delta=delta.get("thinking", ""))
                        elif dtype == "input_json_delta":
                            # Tool input arrives as JSON fragments;
                            # accumulate until the block closes.
                            idx = int(data.get("index", -1))
                            block_tool_inputs[idx] = (
                                block_tool_inputs.get(idx, "")
                                + (delta.get("partial_json") or "")
                            )
                    elif evt_type == "content_block_stop":
                        idx = int(data.get("index", -1))
                        if block_types.get(idx) == "tool_use":
                            raw = block_tool_inputs.get(idx, "")
                            try:
                                args = json.loads(raw) if raw else {}
                            except Exception:
                                args = {"_raw": raw}
                            yield StreamEvent(
                                type="tool_use_end",
                                # Fall back to a generated id only when
                                # the server did not provide one.
                                tool_use_id=(
                                    block_tool_ids.get(idx)
                                    or f"toolu_{uuid.uuid4().hex[:12]}"
                                ),
                                tool_name=block_tool_names.get(idx, ""),
                                tool_input=args,
                                block_index=idx,
                            )
                    elif evt_type == "message_delta":
                        stop_reason = (data.get("delta") or {}).get("stop_reason") or stop_reason
    yield StreamEvent(type="message_stop", stop_reason=stop_reason)
|
|
517
|
+
|
|
518
|
+
async def chat(
    self,
    messages: list[dict[str, Any]],
    temperature: float | None = None,
    max_tokens: int | None = None,
    response_format: dict | None = None,
    tools: list[dict] | None = None,
    tool_choice: str | None = None,
    caller: str | None = None,
) -> LLMResponse:
    """Send a chat completion request.

    If `tools` is passed and the model supports native tool-calling,
    pass them through. Otherwise, fall back to a prompt-based protocol.

    Args:
        messages: Conversation in OpenAI message format.
        temperature: Per-call override; on the LiteLLM path it falls
            back to ``self.temperature`` when None.
        max_tokens: Per-call override; falls back to ``self.max_tokens``.
        response_format: Forwarded to LiteLLM when truthy. Not forwarded
            on the subscription path.
        tools: Tool definitions. Sent natively, or injected into the
            prompt when ``_has_native_tool_support`` is False.
        tool_choice: Forwarded to LiteLLM only together with native
            tools. Not forwarded on the subscription path.
        caller: Accepted but not referenced anywhere in this method.

    Returns:
        An ``LLMResponse`` with text, parsed tool calls, usage, stop
        reason and any thinking output.

    Raises:
        Exception: Whatever the underlying call raises once
            ``_with_retry`` gives up; LiteLLM failures are logged first.
    """
    # Subscription-OAuth path: web UI calls Anthropic via direct
    # httpx (LiteLLM can't send Bearer-only auth cleanly).
    if self._should_use_anthropic_subscription():
        return await _with_retry(
            lambda: self._call_anthropic_subscription_chat(
                messages, tools, max_tokens, temperature,
            ),
            label=f"LLM chat ({self.model}, subscription)",
        )

    # Exactly one of these is truthy when tools were supplied.
    use_native_tools = tools and _has_native_tool_support(self.model)
    use_prompt_tools = tools and not use_native_tools

    if use_prompt_tools:
        messages = self._inject_tool_prompt(messages, tools)

    messages = _apply_cache_control(messages, self.model)

    kwargs: dict[str, Any] = {
        "model": self._route_model_for_litellm(),
        "messages": messages,
        "temperature": temperature if temperature is not None else self.temperature,
        "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
        "timeout": 300,
    }
    if response_format:
        kwargs["response_format"] = response_format
    if use_native_tools:
        kwargs["tools"] = tools
        if tool_choice:
            kwargs["tool_choice"] = tool_choice

    try:
        response = await _with_retry(
            lambda: litellm.acompletion(**kwargs),
            label=f"LLM chat ({self.model})",
        )
    except Exception as e:
        logger.error(f"LLM call failed: {e}")
        raise

    # Only the first choice is consulted.
    message = response.choices[0].message
    content = message.content or ""
    stop_reason = response.choices[0].finish_reason

    tool_calls: list[ToolUseBlock] = []

    if use_native_tools and getattr(message, "tool_calls", None):
        for tc in message.tool_calls:
            raw_args = tc.function.arguments
            # Ollama/Kimi may send `arguments` as a dict instead
            # of a JSON string (OpenAI's spec says string, but
            # many local backends diverge). Accept both.
            if isinstance(raw_args, dict):
                args = raw_args
            elif raw_args:
                try:
                    args = json.loads(raw_args)
                except json.JSONDecodeError:
                    # Keep the unparseable text instead of dropping it.
                    args = {"_raw": raw_args}
            else:
                args = {}
            tool_calls.append(ToolUseBlock(
                id=tc.id or str(uuid.uuid4()),
                name=tc.function.name,
                input=args,
            ))
    elif use_prompt_tools:
        # Parse tool calls from the text content
        parsed_calls, stripped = self._parse_prompt_tool_calls(content)
        if parsed_calls:
            tool_calls = parsed_calls
            content = stripped

    # Salvage path: if we asked for native tools but got an empty
    # tool_calls field AND content looks like a JSON code block,
    # try the prompt parser. Some models (GLM-5.1, some Llamas)
    # emit tool calls as text even when given the function-calling
    # API. This makes them work without a per-model allowlist.
    if use_native_tools and not tool_calls and content:
        parsed_calls, stripped = self._parse_prompt_tool_calls(content)
        if parsed_calls:
            tool_calls = parsed_calls
            content = stripped

    # Some models (gemma4 in particular) leak their tokenizer's
    # special tokens into output text — `<tool_call|>`, `<thought`,
    # `<channel|>`, `<|im_start|>` etc. Strip them.
    if content:
        content = _strip_template_leaks(content).strip()

    # NOTE(review): assumes `response.usage` is always populated — confirm
    # for every backend routed through LiteLLM.
    usage_dict = {
        "prompt_tokens": response.usage.prompt_tokens,
        "completion_tokens": response.usage.completion_tokens,
        "total_tokens": response.usage.total_tokens,
    }
    get_global_tracker().record(self.model, usage_dict)
    # Thinking models (gemma4, kimi, etc.) put their reasoning in a
    # separate `thinking` field on the message. Pull it out so we can
    # forward it in the Anthropic-compat layer.
    thinking_text = ""
    try:
        thinking_text = (
            getattr(message, "thinking", None)
            or getattr(message, "reasoning_content", None)
            or ""
        )
    except Exception:
        pass
    return LLMResponse(
        content=content,
        raw=response.model_dump(),
        model=response.model or self.model,
        usage=usage_dict,
        stop_reason=stop_reason,
        tool_calls=tool_calls,
        thinking=str(thinking_text or ""),
    )
|
|
651
|
+
|
|
652
|
+
# ------------------------------------------------------------------
|
|
653
|
+
# Streaming
|
|
654
|
+
# ------------------------------------------------------------------
|
|
655
|
+
|
|
656
|
+
async def stream(
    self,
    messages: list[dict[str, Any]],
    temperature: float | None = None,
    max_tokens: int | None = None,
    tools: list[dict] | None = None,
    tool_choice: str | None = None,
    caller: str | None = None,
) -> AsyncIterator[StreamEvent]:
    """Stream a chat completion, yielding StreamEvent deltas.

    On the LiteLLM path the events are: one ``message_start``, then
    ``text_delta`` / ``thinking_delta`` events as chunks arrive, then one
    ``tool_use_end`` per tool call, and finally ``message_stop`` carrying
    the last ``finish_reason`` seen. On the subscription path the events
    produced by ``_call_anthropic_subscription_stream`` are forwarded
    unchanged.

    Args:
        messages: Chat messages to send.
        temperature: Overrides ``self.temperature`` when not None.
        max_tokens: Overrides ``self.max_tokens`` when not None.
        tools: Tool definitions. Forwarded to the API when the model has
            native tool support; otherwise described in an injected system
            prompt, in which case the whole response text is buffered and
            parsed for tool-call blocks once the stream ends (no
            incremental ``text_delta`` is emitted in that mode).
        tool_choice: Forwarded only together with native tools.
        caller: Not used by this method.

    Raises:
        Exception: Anything raised while opening or consuming the stream is
            logged and re-raised.
    """
    use_native_tools = tools and _has_native_tool_support(self.model)
    use_prompt_tools = tools and not use_native_tools

    if use_prompt_tools:
        messages = self._inject_tool_prompt(messages, tools)

    messages = _apply_cache_control(messages, self.model)

    # Subscription-OAuth streaming path: web UI calls Anthropic via
    # direct httpx (LiteLLM auth limitation, see chat()).
    if self._should_use_anthropic_subscription():
        async for event in self._call_anthropic_subscription_stream(
            messages, tools, max_tokens, temperature,
        ):
            yield event
        return

    kwargs: dict[str, Any] = {
        "model": self._route_model_for_litellm(),
        "messages": messages,
        "temperature": temperature if temperature is not None else self.temperature,
        "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
        "timeout": 300,
        "stream": True,
    }
    if use_native_tools:
        kwargs["tools"] = tools
        if tool_choice:
            kwargs["tool_choice"] = tool_choice

    yield StreamEvent(type="message_start")

    # Track partial tool_calls being streamed (native path)
    partial_tools: dict[int, dict[str, Any]] = {}
    buffered_text = ""  # raw text, for prompt-based tool parsing
    stop_reason: str | None = None

    try:
        stream = await _with_retry(
            lambda: litellm.acompletion(**kwargs),
            label=f"LLM stream ({self.model})",
        )
        async for chunk in stream:
            # Some backends emit chunks without any choice (e.g. a trailing
            # usage-only chunk); there is nothing to forward for those.
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            choice = choices[0]
            delta = choice.delta

            text = getattr(delta, "content", None)
            if text:
                if use_prompt_tools:
                    # Buffer the RAW text. Stripping template leaks here
                    # would delete the <tool_call>...</tool_call> tags the
                    # prompt parser relies on; leaks are stripped after
                    # parsing instead (same order as chat()).
                    buffered_text += text
                else:
                    # Text delta — strip tokenizer-leak delimiters (gemma4
                    # in particular leaks `<tool_call|>` etc. into text).
                    text = _strip_template_leaks(text)
                    if text:
                        yield StreamEvent(type="text_delta", delta=text)

            # Thinking delta — gemma4/kimi/deepseek-style models emit
            # reasoning incrementally in a separate `thinking` (or
            # `reasoning_content`) field on each chunk.
            thinking_delta = (
                getattr(delta, "thinking", None)
                or getattr(delta, "reasoning_content", None)
            )
            if thinking_delta:
                yield StreamEvent(type="thinking_delta", delta=thinking_delta)

            # Native tool_calls streaming: accumulate fragments per index.
            tool_deltas = getattr(delta, "tool_calls", None)
            if tool_deltas:
                for td in tool_deltas:
                    idx = getattr(td, "index", 0) or 0
                    slot = partial_tools.setdefault(
                        idx, {"id": None, "name": None, "args": ""}
                    )
                    if getattr(td, "id", None):
                        slot["id"] = td.id
                    fn = getattr(td, "function", None)
                    if fn:
                        if getattr(fn, "name", None):
                            slot["name"] = fn.name
                        if getattr(fn, "arguments", None):
                            fn_args = fn.arguments
                            # Ollama/Kimi sometimes deliver the arguments as
                            # a complete dict in one shot rather than
                            # streaming string deltas. Accept both shapes.
                            if isinstance(fn_args, dict):
                                slot["args"] = fn_args
                            else:
                                slot["args"] = (slot["args"] or "") + fn_args

            if choice.finish_reason:
                stop_reason = choice.finish_reason
    except Exception as e:
        logger.error(f"Stream failed: {e}")
        raise

    # Prompt-based tool parsing from buffered text
    if use_prompt_tools and buffered_text:
        parsed_calls, stripped = self._parse_prompt_tool_calls(buffered_text)
        # Only now remove leaked template tokens from the user-visible text.
        stripped = _strip_template_leaks(stripped).strip()
        if stripped:
            yield StreamEvent(type="text_delta", delta=stripped)
        for call in parsed_calls:
            yield StreamEvent(
                type="tool_use_end",
                tool_use_id=call.id,
                tool_name=call.name,
                tool_input=call.input,
            )

    # Emit native tool calls
    for idx, slot in partial_tools.items():
        raw_args = slot["args"]
        if isinstance(raw_args, dict):
            args = raw_args
        elif raw_args:
            try:
                args = json.loads(raw_args)
            except json.JSONDecodeError:
                # Keep the unparseable payload so the caller can inspect it.
                args = {"_raw": raw_args}
        else:
            args = {}
        yield StreamEvent(
            type="tool_use_end",
            tool_use_id=slot["id"] or str(uuid.uuid4()),
            tool_name=slot["name"] or "",
            tool_input=args,
            block_index=idx,
        )

    yield StreamEvent(type="message_stop", stop_reason=stop_reason)
|
|
795
|
+
|
|
796
|
+
# ------------------------------------------------------------------
|
|
797
|
+
# Structured output (Pydantic)
|
|
798
|
+
# ------------------------------------------------------------------
|
|
799
|
+
|
|
800
|
+
async def structured_output(
    self,
    prompt: str,
    system: str | None = None,
    schema_hint: str = "",
    response_model: type[BaseModel] | None = None,
    caller: str | None = None,
    max_tokens: int | None = None,
) -> Any:
    """Get JSON-structured output from the LLM.

    If ``response_model`` is provided, returns a validated instance.
    Otherwise returns whatever ``json.loads`` yields for the reply
    (legacy behavior), using ``schema_hint`` as a free-form description of
    the expected structure.

    ``max_tokens`` overrides the provider's default — long-form
    structured generation (e.g. a 15-feature backlog) needs more
    than the 4 k default or the response gets truncated mid-string
    and the validator fails. Pass ``8192`` or higher for those. The
    override is honored on both the ``response_model`` and the legacy path.

    ``caller`` is not used by this method.

    Raises:
        Exception: On the ``response_model`` path, the last validation
            error once every salvage attempt has failed.
        json.JSONDecodeError: On the legacy path, when the reply is not
            valid JSON.
    """
    if response_model is not None:
        schema = response_model.model_json_schema()
        json_prompt = (
            f"{prompt}\n\n"
            f"Respond with valid JSON matching this schema:\n"
            f"{json.dumps(schema, indent=2)}\n\n"
            f"Respond ONLY with the JSON object, no markdown or explanation."
        )
        response = await self.complete(
            prompt=json_prompt, system=system, max_tokens=max_tokens,
        )
        text = _strip_code_fence(response.content.strip())
        try:
            return response_model.model_validate_json(text)
        except Exception as first_err:
            # Salvage path. Tries three repairs in order:
            #   1. Cheap salvage (prose framing, trailing commas,
            #      smart quotes) — handles most well-formed-but-noisy
            #      responses.
            #   2. Truncation repair — if the response was cut off
            #      mid-string (max_tokens), close the open string,
            #      drop the partial element, close any open arrays
            #      and objects so the front of the structure parses.
            #   3. Both combined.
            # The first candidate that validates short-circuits the
            # chain; only if every repair fails do we raise.
            attempts: list[tuple[str, str]] = []
            cheap = _salvage_json(text)
            if cheap:
                attempts.append(("cheap", cheap))
            truncated = _repair_truncated_json(text)
            if truncated:
                attempts.append(("truncated", truncated))
                # also try cheap on top of truncation repair
                cheap_on_trunc = _salvage_json(truncated)
                if cheap_on_trunc and cheap_on_trunc != truncated:
                    attempts.append(("trunc+cheap", cheap_on_trunc))

            last_err = first_err
            for label, candidate in attempts:
                try:
                    result = response_model.model_validate_json(candidate)
                    logger.info(
                        f"structured_output: salvaged via {label} "
                        f"({len(text)} → {len(candidate)} chars)"
                    )
                    return result
                except Exception as e:
                    last_err = e
                    logger.debug(
                        f"structured_output: {label} salvage failed: {e}"
                    )
            logger.warning(
                f"structured_output: every salvage attempt failed. "
                f"first_err={first_err}; last_err={last_err}; "
                f"raw[:200]={text[:200]!r}; raw[-200:]={text[-200:]!r}"
            )
            raise last_err

    # Legacy dict-based path
    json_prompt = prompt
    if schema_hint:
        json_prompt += f"\n\nRespond with valid JSON matching this structure:\n{schema_hint}"
    json_prompt += "\n\nRespond ONLY with valid JSON, no markdown or explanation."

    # Forward max_tokens here too: previously the documented override was
    # silently ignored on this path, so long replies were still truncated.
    response = await self.complete(
        prompt=json_prompt, system=system, max_tokens=max_tokens,
    )
    text = _strip_code_fence(response.content.strip())
    return json.loads(text)
|
|
888
|
+
|
|
889
|
+
def switch_model(self, model: str) -> None:
    """Replace ``self.model`` with ``model``, logging the transition."""
    previous = self.model
    logger.info(f"Switching model: {previous} -> {model}")
    self.model = model
|
|
893
|
+
|
|
894
|
+
# ------------------------------------------------------------------
|
|
895
|
+
# Fill-in-the-middle (FIM) — code completion for editor-style gap fills.
|
|
896
|
+
# ------------------------------------------------------------------
|
|
897
|
+
|
|
898
|
+
# FIM token templates per model family. Match against the (case-folded)
|
|
899
|
+
# model id; first hit wins. Add new families inline as needed.
|
|
900
|
+
_FIM_TEMPLATES: tuple[tuple[str, dict[str, Any]], ...] = (
    # Each entry is (substring of the lower-cased model id, template).
    # A template gives the three sentinel strings wrapped around the
    # prefix/suffix, plus the stop sequences that end the completion.
    ("qwen", {  # qwen2.5-coder, qwen3-coder, qwen3-coder-next
        "prefix": "<|fim_prefix|>",
        "suffix": "<|fim_suffix|>",
        "middle": "<|fim_middle|>",
        "stop": ["<|endoftext|>", "<|fim_pad|>", "<|im_end|>",
                 "<|repo_name|>", "<|file_sep|>"],
    }),
    ("deepseek", {
        # NOTE: DeepSeek's special tokens are spelled with the FULL-WIDTH
        # vertical bar "｜" (U+FF5C) and "▁" (U+2581), not the ASCII "|".
        # With ASCII pipes the strings are tokenized as plain text and the
        # model never sees its FIM sentinels.
        "prefix": "<｜fim▁begin｜>",
        "suffix": "<｜fim▁hole｜>",
        "middle": "<｜fim▁end｜>",
        "stop": ["<｜end▁of▁sentence｜>"],
    }),
    ("codellama", {
        "prefix": "<PRE> ",
        "suffix": " <SUF>",
        "middle": " <MID>",
        "stop": ["<EOT>"],
    }),
    ("starcoder", {
        "prefix": "<fim_prefix>",
        "suffix": "<fim_suffix>",
        "middle": "<fim_middle>",
        "stop": ["<|endoftext|>"],
    }),
    ("codegemma", {
        "prefix": "<|fim_prefix|>",
        "suffix": "<|fim_suffix|>",
        "middle": "<|fim_middle|>",
        "stop": ["<|file_separator|>", "<|endoftext|>"],
    }),
)
|
|
933
|
+
|
|
934
|
+
@classmethod
|
|
935
|
+
def fim_template_for(cls, model: str) -> dict[str, Any] | None:
|
|
936
|
+
"""Return the FIM token template for a model, or None if unknown."""
|
|
937
|
+
m = model.lower()
|
|
938
|
+
for key, tpl in cls._FIM_TEMPLATES:
|
|
939
|
+
if key in m:
|
|
940
|
+
return tpl
|
|
941
|
+
return None
|
|
942
|
+
|
|
943
|
+
async def fim_complete(
    self,
    prefix: str,
    suffix: str = "",
    model: str | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    stop: list[str] | None = None,
    ollama_host: str = "http://localhost:11434",
) -> str:
    """Fill-in-the-middle completion.

    Generates the text that should appear between `prefix` and
    `suffix` by posting a raw (template-free) prompt to Ollama's
    ``/api/generate`` endpoint. Requires a FIM-trained code model on Ollama
    (qwen-coder, deepseek-coder, codellama, starcoder, codegemma).

    Args:
        prefix: Text before the gap.
        suffix: Text after the gap.
        model: Model id; defaults to ``self.model``. A ``provider/`` prefix
            is removed before the id is sent to Ollama.
        temperature: Sampling temperature; 0.1 when None.
        max_tokens: Sent as ``num_predict``; 128 when None.
        stop: Stop sequences; the model family's defaults when empty/None.
        ollama_host: Base URL of the Ollama server (a trailing slash is
            tolerated).

    Returns:
        The completion, cut at the earliest stop sequence that leaked into
        the output, if any.

    Raises:
        ValueError: If the model is not a known FIM family, or its id names
            a non-Ollama backend.
        httpx.HTTPStatusError: If Ollama answers with an error status.
    """
    use_model = model or self.model
    bare = use_model.split("/", 1)[-1] if "/" in use_model else use_model
    tpl = self.fim_template_for(bare)
    if tpl is None:
        raise ValueError(
            f"Model '{use_model}' is not a known FIM-capable model. "
            "Use a qwen-coder / deepseek-coder / codellama / "
            "starcoder / codegemma variant on Ollama."
        )
    # Ids without any "/" are treated as local Ollama model names.
    if "/" in use_model and not use_model.startswith(("ollama/", "ollama_chat/")):
        raise ValueError(
            f"FIM only supported via Ollama backend (got '{use_model}')."
        )

    stop_sequences = list(stop) if stop else list(tpl["stop"])
    prompt = f"{tpl['prefix']}{prefix}{tpl['suffix']}{suffix}{tpl['middle']}"
    body: dict[str, Any] = {
        "model": bare,
        "prompt": prompt,
        "raw": True,  # do not wrap in chat template
        "stream": False,
        "options": {
            "temperature": temperature if temperature is not None else 0.1,
            "num_predict": max_tokens if max_tokens is not None else 128,
            "stop": stop_sequences,
        },
    }

    import httpx

    # Avoid "//api/generate" when the host is configured with a trailing "/".
    url = f"{ollama_host.rstrip('/')}/api/generate"
    async with httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0)) as client:
        resp = await client.post(url, json=body)
        resp.raise_for_status()
        data = resp.json()

    completion = data.get("response", "")
    # Trim any leaked stop tokens (Ollama usually strips them, but be safe).
    # Cut at the EARLIEST occurrence across all stop sequences; taking the
    # first sequence in list order could leave an earlier leaked token (and
    # everything after it) in the result.
    positions = [
        pos
        for pos in (completion.find(token) for token in stop_sequences if token)
        if pos >= 0
    ]
    if positions:
        completion = completion[: min(positions)]
    return completion
|
|
1003
|
+
|
|
1004
|
+
# ------------------------------------------------------------------
|
|
1005
|
+
# Prompt-based tool-calling fallback (for models w/o native support)
|
|
1006
|
+
# ------------------------------------------------------------------
|
|
1007
|
+
|
|
1008
|
+
# Tool calls in raw text show up under several tag names depending
|
|
1009
|
+
# on which model you talk to:
|
|
1010
|
+
# <tool_call>...</tool_call> — Cognos's prompt-fallback
|
|
1011
|
+
# <function_call>...</function_call> — GLM-5.1 with tools=
|
|
1012
|
+
# <function>...</function> — older Mistral / Llama
|
|
1013
|
+
# <action>...</action> — some experimental models
|
|
1014
|
+
# We accept any of these for robustness.
|
|
1015
|
+
# Captures the JSON object between an opening and a closing tool tag. The
# opening and closing tag names are matched independently of each other, and
# DOTALL lets the (lazily matched) object span several lines.
_TOOL_CALL_RE = re.compile(
    r"<(?:tool_call|function_call|function|action)>"
    r"\s*(\{.*?\})\s*"
    r"</(?:tool_call|function_call|function|action)>",
    flags=re.DOTALL,
)
# Captures a JSON object inside a Markdown fence, with or without the "json"
# language tag. The object may not contain a backtick.
_JSON_FENCE_RE = re.compile(
    r"```(?:json)?"
    r"\s*(\{[^`]*?\})\s*"
    r"```",
    flags=re.DOTALL,
)
|
|
1024
|
+
|
|
1025
|
+
def _inject_tool_prompt(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict],
) -> list[dict[str, Any]]:
    """Inject a system prompt describing available tools for models
    without native tool-calling support.

    Each tool is described from its ``function`` entry (``name``,
    ``description``, ``parameters``). The instruction is appended to the
    first message when that message has role ``system``; otherwise a new
    system message is inserted at the front.

    Returns a new list; neither ``messages`` nor its dicts are mutated.
    """
    tool_descriptions = []
    for t in tools:
        fn = t.get("function", {})
        tool_descriptions.append(
            # A missing description renders as empty text rather than the
            # literal string "None".
            f"- {fn.get('name')}: {fn.get('description') or ''}\n"
            f" input_schema: {json.dumps(fn.get('parameters', {}))}"
        )

    instruction = (
        "You have access to these tools:\n"
        + "\n".join(tool_descriptions)
        + "\n\nTo call a tool, emit a block like this EXACTLY:\n"
        '<tool_call>{"name": "ToolName", "input": {"arg": "value"}}</tool_call>\n'
        "You may emit multiple tool_call blocks. Any text outside tool_call "
        "blocks is shown to the user as your response. If no tools are needed, "
        "just respond normally."
    )

    # No leading system message: insert a new one.
    if not (messages and messages[0].get("role") == "system"):
        return [{"role": "system", "content": instruction}, *messages]

    # Append to the existing system message.
    system_message = messages[0]
    existing = system_message.get("content", "")
    if isinstance(existing, list):
        # Content given as a list of parts: add a text part instead of
        # formatting the list's repr into the prompt.
        merged: Any = [*existing, {"type": "text", "text": instruction}]
    else:
        if existing is None:
            existing = ""
        merged = f"{existing}\n\n{instruction}"
    return [{**system_message, "content": merged}, *messages[1:]]
|
|
1060
|
+
|
|
1061
|
+
def _parse_prompt_tool_calls(
    self,
    text: str,
) -> tuple[list[ToolUseBlock], str]:
    """Parse tool calls from raw text. Handles two formats:

    - Explicit `<tool_call>{...}</tool_call>` blocks (the format we ask
      the model to use, plus the other tag names ``_TOOL_CALL_RE``
      accepts). Every such block is removed from the text, even when its
      JSON fails to parse.
    - Bare fenced JSON code blocks (the format some models — GLM
      especially — naturally emit). These are consumed only when the
      object has a non-empty ``name`` AND an ``input`` or ``arguments``
      key, so ordinary JSON snippets that merely contain a ``name`` field
      stay in the text.

    In both formats the arguments are read from ``input``, falling back to
    ``arguments``.

    Returns (tool_calls, text_with_blocks_removed).
    """

    def _arguments_of(payload: dict) -> Any:
        """Return the call arguments stored under 'input' or 'arguments'."""
        args = payload.get("input")
        if args is None:
            args = payload.get("arguments", {})
        return args

    def _is_fenced_call(payload: Any) -> bool:
        """True when a fenced JSON object is shaped like a tool call."""
        return (
            isinstance(payload, dict)
            and bool(payload.get("name"))
            and ("input" in payload or "arguments" in payload)
        )

    calls: list[ToolUseBlock] = []
    stripped = text

    # 1. Explicit <tool_call> blocks (preferred format).
    for match in self._TOOL_CALL_RE.finditer(text):
        try:
            data = json.loads(match.group(1))
            calls.append(ToolUseBlock(
                name=data.get("name", ""),
                input=_arguments_of(data),
            ))
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Failed to parse tool_call block: {e}")
    stripped = self._TOOL_CALL_RE.sub("", stripped)

    # 2. Fenced JSON blocks — only consume them if they actually look
    #    like a tool call (a 'name' key plus 'arguments' or 'input').
    for match in self._JSON_FENCE_RE.finditer(stripped):
        try:
            data = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue
        if not _is_fenced_call(data):
            continue
        calls.append(ToolUseBlock(name=data["name"], input=_arguments_of(data) or {}))

    # Only remove fenced blocks that look like calls — leave normal
    # code-block content alone.
    if calls:
        def _rep(m: re.Match) -> str:
            try:
                if _is_fenced_call(json.loads(m.group(1))):
                    return ""
            except Exception:
                pass
            return m.group(0)
        stripped = self._JSON_FENCE_RE.sub(_rep, stripped)

    return calls, stripped.strip()
|
|
1118
|
+
|
|
1119
|
+
|
|
1120
|
+
# Two flavors of leak observed from gemma4 / qwen / llama chat templates:
#   1. Generic `<|name|>` special tokens — `<|tool_call|>`, `<|channel|>`,
#      `<|im_start|>`, `<|/tool_call|>`, etc.
#   2. Bare tag names that occasionally appear with or without closing —
#      `<tool_call>`, `</tool_call>`, `<tool_call|>`, `<thought`,
#      `<channel|>`, `<thinking>`, `<action>`. The `\b` after the tag
#      name avoids eating real words like "<thoughts" or "<channels".
_TEMPLATE_LEAK_RE = re.compile(
    "|".join(
        (
            # flavor 1
            r"<\|[a-z_/]+\|>",
            # flavor 2
            r"</?"
            r"(?:tool_call|function_call|function|action|thought|thinking|channel|"
            r"im_start|im_end|user_token|assistant_token|system_token)"
            r"\b\|?>?",
        )
    ),
    flags=re.IGNORECASE,
)


def _strip_template_leaks(text: str) -> str:
    """Delete every tokenizer-template artifact matched by _TEMPLATE_LEAK_RE.

    Falsy input is returned unchanged.
    """
    return _TEMPLATE_LEAK_RE.sub("", text) if text else text
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
def _strip_code_fence(text: str) -> str:
    """Strip a surrounding Markdown code fence (```json ... ```) if present.

    Text that does not start with a fence is returned whitespace-stripped.
    Otherwise the opening fence is removed — together with the rest of its
    line when the payload starts on a later line, or together with a
    leading language tag when everything is on one line — and a closing
    fence at the very end is removed as well, whether it sits on its own
    line or directly after the payload. A missing closing fence is
    tolerated.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped

    opener_rest, newline, remainder = stripped[3:].partition("\n")
    if newline:
        # Multi-line: the opener line only carries the info string.
        body = remainder
    else:
        # Single-line fence: drop an optional language tag, keep the payload
        # (previously the whole line — payload included — was discarded).
        body = re.sub(r"^[A-Za-z0-9_+-]+(?=[\s{\[`])", "", opener_rest)

    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def _salvage_json(text: str) -> str | None:
|
|
1154
|
+
"""Best-effort extraction of a JSON object from a noisy LLM response.
|
|
1155
|
+
|
|
1156
|
+
Handles three common failure modes:
|
|
1157
|
+
1. Prose framing — find the outermost {...} pair and slice it out.
|
|
1158
|
+
2. Trailing commas before } or ] (some LLMs love these).
|
|
1159
|
+
3. Smart quotes / curly apostrophes that break json.
|
|
1160
|
+
|
|
1161
|
+
Returns the cleaned string, or None if no balanced object was found.
|
|
1162
|
+
Used by `structured_output` as a salvage path on first parse failure.
|
|
1163
|
+
"""
|
|
1164
|
+
if not text:
|
|
1165
|
+
return None
|
|
1166
|
+
|
|
1167
|
+
# 1. Slice to the outermost {...}
|
|
1168
|
+
start = text.find("{")
|
|
1169
|
+
end = text.rfind("}")
|
|
1170
|
+
if start == -1 or end == -1 or end < start:
|
|
1171
|
+
return None
|
|
1172
|
+
candidate = text[start : end + 1]
|
|
1173
|
+
|
|
1174
|
+
# 2. Strip trailing commas: ",}" -> "}", ",]" -> "]"
|
|
1175
|
+
import re
|
|
1176
|
+
candidate = re.sub(r",(\s*[}\]])", r"\1", candidate)
|
|
1177
|
+
|
|
1178
|
+
# 3. Replace smart quotes
|
|
1179
|
+
candidate = (candidate
|
|
1180
|
+
.replace("“", '"').replace("”", '"')
|
|
1181
|
+
.replace("‘", "'").replace("’", "'"))
|
|
1182
|
+
|
|
1183
|
+
return candidate
|
|
1184
|
+
|
|
1185
|
+
|
|
1186
|
+
def _repair_truncated_json(text: str) -> str | None:
|
|
1187
|
+
"""Repair a response that was cut off mid-generation.
|
|
1188
|
+
|
|
1189
|
+
Common when ``max_tokens`` runs out partway through a long array of
|
|
1190
|
+
objects (e.g. a 15-feature backlog). The tail of ``text`` is:
|
|
1191
|
+
|
|
1192
|
+
- an unterminated string (parser reports "EOF while parsing a
|
|
1193
|
+
string"), OR
|
|
1194
|
+
- a key with no value (``"foo": ``), OR
|
|
1195
|
+
- a half-written object (``{"title": "X", "descr``)
|
|
1196
|
+
|
|
1197
|
+
Strategy: walk ``text`` with a stack-based state machine that tracks
|
|
1198
|
+
string-vs-not and brace/bracket depth. Remember the **last index
|
|
1199
|
+
where every nesting level was complete and the next position was
|
|
1200
|
+
inside a top-level array or object**. Truncate to that index, drop
|
|
1201
|
+
any trailing partial element, close the open structures with the
|
|
1202
|
+
right sequence of ``]`` and ``}``.
|
|
1203
|
+
|
|
1204
|
+
Returns the repaired JSON string, or None if no recovery point
|
|
1205
|
+
existed (e.g. the truncation hit before the first balanced element).
|
|
1206
|
+
"""
|
|
1207
|
+
if not text:
|
|
1208
|
+
return None
|
|
1209
|
+
|
|
1210
|
+
# Find the outermost {
|
|
1211
|
+
obj_start = text.find("{")
|
|
1212
|
+
if obj_start == -1:
|
|
1213
|
+
return None
|
|
1214
|
+
body = text[obj_start:]
|
|
1215
|
+
|
|
1216
|
+
# Walk the body tracking state. `stack` holds the OPEN container
|
|
1217
|
+
# chars in order (e.g. ['{', '"key"', '[', '{', '...']). When we
|
|
1218
|
+
# close one, we pop. We also record the index of the last comma
|
|
1219
|
+
# encountered at each depth — that's our safe re-truncation point
|
|
1220
|
+
# if the trailing element is incomplete.
|
|
1221
|
+
in_string = False
|
|
1222
|
+
escape = False
|
|
1223
|
+
stack: list[str] = [] # holds '{' and '[' chars
|
|
1224
|
+
# last_complete_idx[d] = last index inside container at depth d
|
|
1225
|
+
# where the structure was clean (just after a balanced child + ',')
|
|
1226
|
+
last_clean_at_depth: dict[int, int] = {}
|
|
1227
|
+
|
|
1228
|
+
for i, ch in enumerate(body):
|
|
1229
|
+
if in_string:
|
|
1230
|
+
if escape:
|
|
1231
|
+
escape = False
|
|
1232
|
+
elif ch == "\\":
|
|
1233
|
+
escape = True
|
|
1234
|
+
elif ch == '"':
|
|
1235
|
+
in_string = False
|
|
1236
|
+
continue
|
|
1237
|
+
if ch == '"':
|
|
1238
|
+
in_string = True
|
|
1239
|
+
continue
|
|
1240
|
+
if ch == "{" or ch == "[":
|
|
1241
|
+
stack.append(ch)
|
|
1242
|
+
continue
|
|
1243
|
+
if ch == "}" or ch == "]":
|
|
1244
|
+
if not stack:
|
|
1245
|
+
# garbage tail before the body started — give up
|
|
1246
|
+
return None
|
|
1247
|
+
stack.pop()
|
|
1248
|
+
# After popping, the parent's depth gained a complete child.
|
|
1249
|
+
# Record the position so we can return here if needed.
|
|
1250
|
+
last_clean_at_depth[len(stack)] = i + 1
|
|
1251
|
+
continue
|
|
1252
|
+
if ch == "," and not stack:
|
|
1253
|
+
# Comma outside any container — malformed; bail.
|
|
1254
|
+
return None
|
|
1255
|
+
if ch == "," and stack:
|
|
1256
|
+
last_clean_at_depth[len(stack) - 1] = i + 1
|
|
1257
|
+
continue
|
|
1258
|
+
|
|
1259
|
+
if not stack:
|
|
1260
|
+
# Walk succeeded without truncation — nothing to repair here.
|
|
1261
|
+
return None
|
|
1262
|
+
|
|
1263
|
+
# Truncate to the deepest clean point we know about, then close the
|
|
1264
|
+
# remaining open containers in reverse order.
|
|
1265
|
+
# Pick the **shallowest** clean point that was after the outermost
|
|
1266
|
+
# `{`, because closing back to that level discards the least content
|
|
1267
|
+
# but keeps everything above it intact.
|
|
1268
|
+
if not last_clean_at_depth:
|
|
1269
|
+
# No clean child landed before the truncation — nothing to keep.
|
|
1270
|
+
return None
|
|
1271
|
+
# Use the deepest depth's clean index (the most recent valid boundary)
|
|
1272
|
+
deepest = max(last_clean_at_depth.keys())
|
|
1273
|
+
cut = last_clean_at_depth[deepest]
|
|
1274
|
+
head = body[:cut]
|
|
1275
|
+
|
|
1276
|
+
# Trim a trailing comma before closing so we don't emit ",]" or ",}"
|
|
1277
|
+
head = head.rstrip()
|
|
1278
|
+
if head.endswith(","):
|
|
1279
|
+
head = head[:-1]
|
|
1280
|
+
|
|
1281
|
+
# Compute the closer string from the surviving stack. The stack
|
|
1282
|
+
# holds the OPEN chars up to `cut` — but `cut` was placed AFTER
|
|
1283
|
+
# closing a child, so the remaining stack at that point is
|
|
1284
|
+
# `stack[: len(stack) - (original_depth - deepest)]`. Easier: walk
|
|
1285
|
+
# the surviving head again to compute the remaining stack.
|
|
1286
|
+
survivor: list[str] = []
|
|
1287
|
+
in_s = False
|
|
1288
|
+
esc = False
|
|
1289
|
+
for ch in head:
|
|
1290
|
+
if in_s:
|
|
1291
|
+
if esc: esc = False
|
|
1292
|
+
elif ch == "\\": esc = True
|
|
1293
|
+
elif ch == '"': in_s = False
|
|
1294
|
+
continue
|
|
1295
|
+
if ch == '"': in_s = True
|
|
1296
|
+
elif ch in "{[": survivor.append(ch)
|
|
1297
|
+
elif ch in "}]":
|
|
1298
|
+
if survivor:
|
|
1299
|
+
survivor.pop()
|
|
1300
|
+
closers = "".join("}" if c == "{" else "]" for c in reversed(survivor))
|
|
1301
|
+
|
|
1302
|
+
repaired = text[:obj_start] + head + closers
|
|
1303
|
+
# Cheap salvage on top (trailing commas inside repaired text).
|
|
1304
|
+
import re
|
|
1305
|
+
repaired = re.sub(r",(\s*[}\]])", r"\1", repaired)
|
|
1306
|
+
return repaired
|
|
1307
|
+
|
|
1308
|
+
|
|
1309
|
+
# ---------------------------------------------------------------------------
|
|
1310
|
+
# Module-level FIM convenience function
|
|
1311
|
+
#
|
|
1312
|
+
# Lets any caller — Caudate's dispatch hook, an editor route, the Edit
|
|
1313
|
+
# tool, an external script — invoke FIM with one import and no provider
|
|
1314
|
+
# instance. Spawns a transient LLMProvider per call (cheap: just config,
|
|
1315
|
+
# the httpx client is per-call inside `fim_complete`).
|
|
1316
|
+
#
|
|
1317
|
+
# Default model is qwen2.5-coder:1.5b because it's the smallest FIM
|
|
1318
|
+
# model on the local Ollama and gives ~50-100ms latency suitable for
|
|
1319
|
+
# in-editor autocomplete. Override with `model=` for heavier gap-fills
|
|
1320
|
+
# (qwen3-coder-next, deepseek-coder, etc.).
|
|
1321
|
+
# ---------------------------------------------------------------------------
|
|
1322
|
+
|
|
1323
|
+
|
|
1324
|
+
DEFAULT_FIM_MODEL = "ollama/qwen2.5-coder:1.5b"


async def fim_complete(
    prefix: str,
    suffix: str = "",
    *,
    model: str = DEFAULT_FIM_MODEL,
    temperature: float | None = None,
    max_tokens: int | None = None,
    stop: list[str] | None = None,
    ollama_host: str = "http://localhost:11434",
) -> str:
    """Fill-in-the-middle: generate the text between `prefix` and `suffix`.

    Module-level entry point that needs no existing LLMProvider: it builds
    a throwaway ``LLMProvider(model=model)`` and forwards every argument to
    its ``fim_complete`` method. See that method for the full semantics,
    supported model families, and error cases.
    """
    forwarded = {
        "prefix": prefix,
        "suffix": suffix,
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "ollama_host": ollama_host,
    }
    transient = LLMProvider(model=model)
    return await transient.fim_complete(**forwarded)
|