krnl-code 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- krnl_agent/__init__.py +9 -0
- krnl_agent/__main__.py +7 -0
- krnl_agent/agent_registry.py +95 -0
- krnl_agent/agent_selector.py +69 -0
- krnl_agent/audit_log.py +155 -0
- krnl_agent/background.py +94 -0
- krnl_agent/checkpoints.py +67 -0
- krnl_agent/ci.py +73 -0
- krnl_agent/cli.py +1458 -0
- krnl_agent/commands.py +42 -0
- krnl_agent/config.py +425 -0
- krnl_agent/context.py +352 -0
- krnl_agent/depaudit.py +63 -0
- krnl_agent/deploy.py +245 -0
- krnl_agent/doctor.py +106 -0
- krnl_agent/events.py +141 -0
- krnl_agent/gitignore.py +47 -0
- krnl_agent/graph.py +928 -0
- krnl_agent/guardrails.py +70 -0
- krnl_agent/headless.py +60 -0
- krnl_agent/history.py +49 -0
- krnl_agent/hooks.py +72 -0
- krnl_agent/ingest.py +129 -0
- krnl_agent/llm.py +456 -0
- krnl_agent/loop.py +779 -0
- krnl_agent/mcp_client.py +128 -0
- krnl_agent/memory.py +61 -0
- krnl_agent/modelrouter.py +151 -0
- krnl_agent/monitor.py +112 -0
- krnl_agent/notify.py +119 -0
- krnl_agent/parallel_executor.py +139 -0
- krnl_agent/permissions.py +128 -0
- krnl_agent/plugins.py +105 -0
- krnl_agent/pricing.py +85 -0
- krnl_agent/prompts.py +60 -0
- krnl_agent/repomap.py +133 -0
- krnl_agent/sandbox.py +69 -0
- krnl_agent/scaffold.py +167 -0
- krnl_agent/schedules.py +137 -0
- krnl_agent/secrets.py +100 -0
- krnl_agent/selfheal.py +87 -0
- krnl_agent/server.py +302 -0
- krnl_agent/sessions.py +258 -0
- krnl_agent/settings.py +59 -0
- krnl_agent/skills.py +73 -0
- krnl_agent/teams.py +38 -0
- krnl_agent/tool_schemas.py +431 -0
- krnl_agent/tools.py +694 -0
- krnl_agent/webtools.py +139 -0
- krnl_code-1.0.4.dist-info/METADATA +214 -0
- krnl_code-1.0.4.dist-info/RECORD +56 -0
- krnl_code-1.0.4.dist-info/WHEEL +5 -0
- krnl_code-1.0.4.dist-info/entry_points.txt +2 -0
- krnl_code-1.0.4.dist-info/licenses/LICENSE +147 -0
- krnl_code-1.0.4.dist-info/licenses/NOTICE +4 -0
- krnl_code-1.0.4.dist-info/top_level.txt +1 -0
krnl_agent/llm.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
"""Provider-agnostic LLM client.
|
|
2
|
+
|
|
3
|
+
The agent loop keeps its conversation in **OpenAI message format** regardless of
|
|
4
|
+
provider. Each client converts that to whatever the underlying API needs and
|
|
5
|
+
returns a normalized `LLMResponse` (assistant text + tool calls + token usage).
|
|
6
|
+
|
|
7
|
+
Adds resilience (retry with backoff on transient errors) and best-effort token
|
|
8
|
+
usage accounting.
|
|
9
|
+
|
|
10
|
+
Two client types:
|
|
11
|
+
* OpenAICompatClient — any OpenAI Chat Completions API (Krnl, OpenAI,
|
|
12
|
+
OpenRouter, Ollama, vLLM, LM Studio, …).
|
|
13
|
+
* AnthropicClient — native Anthropic Messages API.
|
|
14
|
+
"""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import time
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from typing import Any, Callable, Optional
|
|
21
|
+
|
|
22
|
+
from .config import AgentConfig, ProviderConfig
|
|
23
|
+
|
|
24
|
+
StreamCallback = Optional[Callable[..., None]]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class ToolCall:
    """One tool invocation requested by the model, in provider-neutral form."""

    # Identifier of the call (clients in this module fall back to a synthesized
    # "call_<index>" id when the provider's stream did not supply one).
    id: str
    # Name of the tool/function the model wants to invoke.
    name: str
    # Call arguments, already decoded from the provider's payload into a dict.
    arguments: dict
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class LLMResponse:
    """Normalized result of one chat call: assistant text, tool calls, token usage."""

    # Assistant text; empty string when the model produced none.
    content: str = ""
    # Tool calls requested by the model; each instance gets its own fresh list.
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Input-side token count; 0 means "not reported".
    prompt_tokens: int = 0
    # Output-side token count; 0 means "not reported".
    completion_tokens: int = 0
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class LLMError(RuntimeError):
    """Raised when a request to an LLM provider ultimately fails."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
import re
|
|
47
|
+
|
|
48
|
+
# Matches model names that look like reasoning models: "o1", "o3", "o4",
# "gpt-5" or "reason", either at the very start of the name or right after a
# "-", "/" or "_" separator (case-insensitive).
_REASONING_MODEL = re.compile(r"(^|[-/_])(o1|o3|o4|gpt-5|reason)", re.IGNORECASE)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _make_client(provider: ProviderConfig, agent: Optional[AgentConfig] = None):
    """Instantiate the concrete client that matches ``provider.type``.

    Retry and reasoning settings are read from ``agent`` when one is given;
    without it the defaults are 3 attempts, a backoff base of 1.5, thinking
    disabled and no reasoning effort.

    Returns an ``AnthropicClient`` for ``provider.type == "anthropic"`` and an
    ``OpenAICompatClient`` for every other provider type.
    """
    if agent:
        attempts = agent.retry_attempts
        backoff = agent.retry_backoff
        thinking = agent.thinking
        effort = agent.reasoning_effort
    else:
        attempts, backoff, thinking, effort = 3, 1.5, False, None

    if provider.type != "anthropic":
        return OpenAICompatClient(provider, attempts, backoff, reasoning_effort=effort)
    return AnthropicClient(provider, attempts, backoff, thinking=thinking)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def build_client(provider: ProviderConfig, agent: Optional[AgentConfig] = None):
    """Build the client for ``provider``, adding model fallbacks when configured.

    When ``agent.fallback_models`` is non-empty, one extra client is created per
    fallback model (same provider settings, different ``model``) and the whole
    chain is wrapped in a ``FallbackClient``. Otherwise the primary client is
    returned directly.
    """
    primary = _make_client(provider, agent)
    fallback_models = agent.fallback_models if agent else []
    if fallback_models:
        from dataclasses import replace

        chain = [primary]
        for model_name in fallback_models:
            chain.append(_make_client(replace(provider, model=model_name), agent))
        return FallbackClient(chain)
    return primary
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class FallbackClient:
    """Tries each underlying client in order; falls through on LLMError.

    Only ``LLMError`` triggers the next client; any other exception propagates
    immediately. ``provider`` mirrors the first client's provider (``None`` for
    an empty chain).
    """

    def __init__(self, clients: list):
        self.clients = clients
        # Tolerate an empty chain here so chat() can report it as an LLMError
        # ("no clients available") instead of failing with an IndexError.
        self.provider = clients[0].provider if clients else None

    def chat(self, messages, tools=None, stream_cb=None, stream=True) -> "LLMResponse":
        """Return the first successful response from the chain.

        Raises:
            LLMError: the last client's error when every client failed, or
                "no clients available" when the chain is empty.
        """
        last: Optional[Exception] = None
        for client in self.clients:
            try:
                return client.chat(messages, tools, stream_cb, stream)
            except LLMError as e:
                last = e
        if last is None:
            raise LLMError("no clients available")
        raise last
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _safe_json_loads(s: str) -> dict:
|
|
91
|
+
if not s:
|
|
92
|
+
return {}
|
|
93
|
+
s_clean = s.strip()
|
|
94
|
+
if s_clean.startswith("```json"):
|
|
95
|
+
s_clean = s_clean[7:]
|
|
96
|
+
elif s_clean.startswith("```"):
|
|
97
|
+
s_clean = s_clean[3:]
|
|
98
|
+
if s_clean.endswith("```"):
|
|
99
|
+
s_clean = s_clean[:-3]
|
|
100
|
+
s_clean = s_clean.strip()
|
|
101
|
+
try:
|
|
102
|
+
return json.loads(s_clean)
|
|
103
|
+
except Exception:
|
|
104
|
+
try:
|
|
105
|
+
return json.loads(s_clean[: s_clean.rfind("}") + 1])
|
|
106
|
+
except Exception:
|
|
107
|
+
return {"__raw__": s}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _retryable(e: Exception) -> bool:
|
|
112
|
+
"""Retry connection/timeout/5xx/429; never retry auth/bad-request (4xx)."""
|
|
113
|
+
code = getattr(e, "status_code", None)
|
|
114
|
+
if code is None:
|
|
115
|
+
code = getattr(getattr(e, "response", None), "status_code", None)
|
|
116
|
+
if code in (400, 401, 403, 404, 422):
|
|
117
|
+
return False
|
|
118
|
+
return True
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _detect_param_fix(errmsg: str, kwargs: dict) -> Optional[str]:
|
|
122
|
+
"""Inspect a 400 'unsupported parameter' error and decide how to adapt the
|
|
123
|
+
request so it works with newer models (gpt-5.x, o-series, etc.) - generically,
|
|
124
|
+
with no per-model configuration."""
|
|
125
|
+
m = errmsg.lower()
|
|
126
|
+
if "max_completion_tokens" in m and "max_tokens" in m and "max_tokens" in kwargs:
|
|
127
|
+
return "max_completion_tokens"
|
|
128
|
+
if "temperature" in m and "max_tokens" not in m and any(
|
|
129
|
+
s in m for s in ("unsupported", "not supported", "does not support",
|
|
130
|
+
"only the default", "only supports", "is not supported")
|
|
131
|
+
):
|
|
132
|
+
return "drop_temperature"
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _estimate(text: str) -> int:
|
|
137
|
+
return max(0, len(text) // 4)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class _RetryMixin:
|
|
141
|
+
retry_attempts: int
|
|
142
|
+
retry_backoff: float
|
|
143
|
+
|
|
144
|
+
def _with_retry(self, fn, stream_cb: StreamCallback):
|
|
145
|
+
last: Optional[Exception] = None
|
|
146
|
+
for attempt in range(1, self.retry_attempts + 1):
|
|
147
|
+
streamed = {"n": 0}
|
|
148
|
+
|
|
149
|
+
def cb(t: str):
|
|
150
|
+
streamed["n"] += 1
|
|
151
|
+
if stream_cb:
|
|
152
|
+
stream_cb(t)
|
|
153
|
+
|
|
154
|
+
try:
|
|
155
|
+
return fn(cb if stream_cb else None)
|
|
156
|
+
except Exception as e: # noqa: BLE001
|
|
157
|
+
last = e
|
|
158
|
+
# can't safely retry once tokens have been emitted to the user
|
|
159
|
+
if (
|
|
160
|
+
attempt >= self.retry_attempts
|
|
161
|
+
or streamed["n"] > 0
|
|
162
|
+
or not _retryable(e)
|
|
163
|
+
):
|
|
164
|
+
break
|
|
165
|
+
time.sleep(self.retry_backoff ** attempt)
|
|
166
|
+
raise LLMError(f"{type(last).__name__}: {last}")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# --------------------------------------------------------------------------- #
|
|
170
|
+
# OpenAI-compatible
|
|
171
|
+
# --------------------------------------------------------------------------- #
|
|
172
|
+
class OpenAICompatClient(_RetryMixin):
    """Client for any server that speaks the OpenAI Chat Completions API.

    Messages are sent as given. Parameters a model rejects (``max_tokens`` vs
    ``max_completion_tokens``, a custom ``temperature``) are adapted on the fly
    and the adaptation is remembered for the lifetime of this instance.
    """

    def __init__(
        self,
        provider: ProviderConfig,
        attempts: int = 3,
        backoff: float = 1.5,
        reasoning_effort: Optional[str] = None,
    ):
        from openai import OpenAI

        self.provider = provider
        self.retry_attempts = attempts
        self.retry_backoff = backoff
        self.reasoning_effort = reasoning_effort
        self._compat: set = set()  # learned param fixes (e.g. max_completion_tokens)
        self.client = OpenAI(
            # Placeholder key when the provider config has none.
            api_key=provider.api_key or "sk-no-key-required",
            base_url=provider.base_url,
            default_headers=provider.extra_headers or None,
        )

    def chat(self, messages, tools=None, stream_cb=None, stream=True) -> LLMResponse:
        """Send ``messages`` (OpenAI format) and return a normalized response.

        Args:
            messages: Conversation in OpenAI message format.
            tools: Optional OpenAI-style tool definitions; enables
                ``tool_choice="auto"`` when given.
            stream_cb: Optional callback receiving streamed text chunks.
            stream: Use the streaming endpoint when True.

        Raises:
            LLMError: when the request fails (see ``_with_retry``).
        """
        kwargs: dict[str, Any] = {
            "model": self.provider.model,
            "messages": messages,
            "temperature": self.provider.temperature,
            "max_tokens": self.provider.max_tokens,
        }
        # Reasoning models take reasoning_effort and reject custom temperature.
        if self.reasoning_effort and _REASONING_MODEL.search(self.provider.model):
            kwargs["reasoning_effort"] = self.reasoning_effort
            kwargs.pop("temperature", None)
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        def run(cb):
            return self._chat_stream(kwargs, cb) if stream else self._chat_once(kwargs)

        resp = self._with_retry(run, stream_cb)
        if not resp.prompt_tokens:  # estimate when the server didn't report usage
            resp.prompt_tokens = sum(
                _estimate(str(m.get("content") or "")) for m in messages
            )
        if not resp.completion_tokens:
            resp.completion_tokens = _estimate(resp.content)
        return resp

    def _apply_compat(self, kwargs: dict) -> dict:
        """Return a copy of ``kwargs`` with all learned parameter fixes applied."""
        k = dict(kwargs)
        if "max_completion_tokens" in self._compat and "max_tokens" in k:
            k["max_completion_tokens"] = k.pop("max_tokens")
        if "drop_temperature" in self._compat:
            k.pop("temperature", None)
        return k

    def _completions_create(self, kwargs: dict):
        """Call the API, auto-adapting unsupported parameters (and remembering the
        fix so later calls in this session skip the failed attempt)."""
        last: Optional[Exception] = None
        for _ in range(4):
            attempt = self._apply_compat(kwargs)
            try:
                return self.client.chat.completions.create(**attempt)
            except Exception as e:  # noqa: BLE001
                fix = _detect_param_fix(str(e), attempt)
                # Unknown error, or a fix that was already applied and did not
                # help: let the caller's retry logic deal with it.
                if not fix or fix in self._compat:
                    raise
                self._compat.add(fix)
                last = e
        if last:
            raise last

    def _chat_once(self, kwargs: dict) -> LLMResponse:
        """Non-streaming request; returns text, tool calls and reported usage."""
        resp = self._completions_create(kwargs)
        choices = getattr(resp, "choices", None) or []
        if not choices:
            # Report an empty reply clearly instead of an opaque IndexError.
            raise LLMError("provider returned a response with no choices")
        msg = choices[0].message
        calls = [
            ToolCall(tc.id, tc.function.name, _safe_json_loads(tc.function.arguments or "{}"))
            for tc in (msg.tool_calls or [])
        ]
        u = getattr(resp, "usage", None)
        return LLMResponse(
            content=msg.content or "",
            tool_calls=calls,
            prompt_tokens=getattr(u, "prompt_tokens", 0) or 0,
            completion_tokens=getattr(u, "completion_tokens", 0) or 0,
        )

    def _chat_stream(self, kwargs: dict, stream_cb: StreamCallback) -> LLMResponse:
        """Streaming request; assembles text and tool calls from the deltas.

        Text deltas go to ``stream_cb`` as they arrive. Reasoning deltas are
        sent with ``is_thinking=True``, falling back to a plain call for
        callbacks that do not accept that keyword. Token usage is taken from
        the stream when a chunk carries it; otherwise it stays 0.
        """
        content_parts: list[str] = []
        partial: dict[int, dict] = {}
        usage = None
        for chunk in self._completions_create({**kwargs, "stream": True}):
            # Read usage before skipping choice-less chunks: a chunk may carry
            # usage without any choices (presumably the final usage chunk).
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage = chunk_usage
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            # Extract reasoning/thinking tokens if present (for reasoning models)
            reasoning = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
            if reasoning:
                if stream_cb:
                    try:
                        stream_cb(reasoning, is_thinking=True)
                    except TypeError:
                        stream_cb(reasoning)

            if getattr(delta, "content", None):
                content_parts.append(delta.content)
                if stream_cb:
                    stream_cb(delta.content)
            # Tool calls arrive in fragments keyed by index; accumulate them.
            for tc in getattr(delta, "tool_calls", None) or []:
                slot = partial.setdefault(tc.index, {"id": None, "name": "", "args": ""})
                if tc.id:
                    slot["id"] = tc.id
                if tc.function and tc.function.name:
                    slot["name"] += tc.function.name
                if tc.function and tc.function.arguments:
                    slot["args"] += tc.function.arguments
        calls = [
            ToolCall(slot["id"] or f"call_{idx}", slot["name"], _safe_json_loads(slot["args"]))
            for idx, slot in sorted(partial.items())
            if slot["name"]  # drop fragments that never received a name
        ]
        return LLMResponse(
            content="".join(content_parts),
            tool_calls=calls,
            prompt_tokens=getattr(usage, "prompt_tokens", 0) or 0,
            completion_tokens=getattr(usage, "completion_tokens", 0) or 0,
        )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
# --------------------------------------------------------------------------- #
|
|
298
|
+
# Anthropic native
|
|
299
|
+
# --------------------------------------------------------------------------- #
|
|
300
|
+
class AnthropicClient(_RetryMixin):
    """Client for the native Anthropic Messages API.

    Accepts the conversation in OpenAI message format and converts it to
    Anthropic's system-prompt + content-block layout before each request.
    """

    def __init__(
        self,
        provider: ProviderConfig,
        attempts: int = 3,
        backoff: float = 1.5,
        thinking: bool = False,
    ):
        import anthropic

        self.provider = provider
        self.retry_attempts = attempts
        self.retry_backoff = backoff
        self.thinking = thinking
        self.client = anthropic.Anthropic(
            api_key=provider.api_key, base_url=provider.base_url or None
        )

    @staticmethod
    def _to_anthropic_tools(tools):
        """Convert OpenAI-style tool definitions to Anthropic tool dicts."""
        return [
            {
                "name": t["function"]["name"],
                "description": t["function"].get("description", ""),
                "input_schema": t["function"].get("parameters", {"type": "object"}),
            }
            for t in (tools or [])
        ]

    @staticmethod
    def _to_anthropic_messages(messages):
        """Convert OpenAI-format messages to ``(system_text, anthropic_messages)``.

        System messages are joined with newlines into one string. Consecutive
        blocks for the same role are merged into a single message; ``tool``
        results become ``tool_result`` blocks on the user side. Empty or
        whitespace-only text is skipped rather than sent as a text block
        (the Messages API is understood to reject such blocks - see the
        Anthropic API reference). Image URLs that are not ``data:`` URLs are
        not forwarded.
        """
        system_parts: list[str] = []
        out: list[dict] = []

        def push(role: str, block: dict):
            if out and out[-1]["role"] == role:
                out[-1]["content"].append(block)
            else:
                out.append({"role": role, "content": [block]})

        def blank(value) -> bool:
            # None / "" / whitespace-only text carries nothing worth sending.
            return not value or (isinstance(value, str) and not value.strip())

        for m in messages:
            role = m["role"]
            if role == "system":
                system_parts.append(m.get("content") or "")
            elif role == "user":
                content = m.get("content")
                if isinstance(content, list):
                    for block in content:
                        if block.get("type") == "text":
                            text = block.get("text", "")
                            if not blank(text):
                                push("user", {"type": "text", "text": text})
                        elif block.get("type") == "image_url":
                            url = block.get("image_url", {}).get("url", "")
                            if url.startswith("data:") and "," in url:
                                # "data:<media>;base64,<payload>"
                                header, data = url.split(",", 1)
                                media = header.split(";")[0].split(":")[-1]
                                push("user", {
                                    "type": "image",
                                    "source": {"type": "base64", "media_type": media, "data": data},
                                })
                elif not blank(content):
                    push("user", {"type": "text", "text": content})
            elif role == "assistant":
                if not blank(m.get("content")):
                    push("assistant", {"type": "text", "text": m["content"]})
                for tc in m.get("tool_calls") or []:
                    fn = tc["function"]
                    push("assistant", {
                        "type": "tool_use",
                        "id": tc["id"],
                        "name": fn["name"],
                        "input": _safe_json_loads(fn.get("arguments") or "{}"),
                    })
            elif role == "tool":
                push("user", {
                    "type": "tool_result",
                    "tool_use_id": m.get("tool_call_id"),
                    "content": m.get("content") or "",
                })
        return "\n".join(p for p in system_parts if p), out

    def chat(self, messages, tools=None, stream_cb=None, stream=True) -> LLMResponse:
        """Send ``messages`` (OpenAI format) and return a normalized response.

        Args:
            messages: Conversation in OpenAI message format.
            tools: Optional OpenAI-style tool definitions.
            stream_cb: Optional callback receiving streamed text chunks.
            stream: Use the streaming endpoint when True.

        Raises:
            LLMError: when the request fails (see ``_with_retry``).
        """
        system, conv = self._to_anthropic_messages(messages)
        # Prompt caching: the system prompt (repo map, AGENTS.md, tool docs) is large
        # and static across a turn — mark it ephemeral so Anthropic reuses it from
        # cache instead of re-billing the full prefix every step.
        system_field: Any = system
        if system and len(system) > 2000:
            system_field = [{
                "type": "text", "text": system,
                "cache_control": {"type": "ephemeral"},
            }]
        # Set cache_control breakpoint on the last message of the conversation to cache history
        if conv and len(messages) >= 4:
            last_msg = conv[-1]
            if last_msg.get("content") and isinstance(last_msg["content"], list):
                last_block = last_msg["content"][-1]
                if isinstance(last_block, dict) and last_block.get("type") in ("text", "tool_result"):
                    last_block["cache_control"] = {"type": "ephemeral"}
        kwargs: dict[str, Any] = {
            "model": self.provider.model,
            "system": system_field,
            "messages": conv,
            "max_tokens": self.provider.max_tokens,
            "temperature": self.provider.temperature,
        }
        if self.thinking:
            # Thinking budget: half of max_tokens, clamped to [1024, 8000];
            # max_tokens is raised so at least 1024 tokens remain for the answer.
            budget = min(8000, max(1024, self.provider.max_tokens // 2))
            kwargs["max_tokens"] = max(self.provider.max_tokens, budget + 1024)
            kwargs["temperature"] = 1  # required when thinking is enabled
            kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
        if tools:
            kwargs["tools"] = self._to_anthropic_tools(tools)

        def run(cb):
            return self._chat_stream(kwargs, cb) if stream else self._chat_once(kwargs)

        return self._with_retry(run, stream_cb)

    def _collect(self, message) -> LLMResponse:
        """Fold an Anthropic message into an ``LLMResponse``.

        Text blocks are concatenated, ``tool_use`` blocks become ``ToolCall``s;
        other block types are ignored.
        """
        text_parts: list[str] = []
        calls = []
        for block in message.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                calls.append(ToolCall(block.id, block.name, block.input or {}))
        u = getattr(message, "usage", None)
        return LLMResponse(
            content="".join(text_parts),
            tool_calls=calls,
            prompt_tokens=getattr(u, "input_tokens", 0) or 0,
            completion_tokens=getattr(u, "output_tokens", 0) or 0,
        )

    def _chat_once(self, kwargs: dict) -> LLMResponse:
        """Non-streaming request."""
        return self._collect(self.client.messages.create(**kwargs))

    def _chat_stream(self, kwargs: dict, stream_cb: StreamCallback) -> LLMResponse:
        """Streaming request; forwards deltas to ``stream_cb`` as they arrive.

        Thinking deltas are sent with ``is_thinking=True``, falling back to a
        plain call for callbacks that do not accept that keyword. The final
        message is collected once the stream has been consumed.
        """
        with self.client.messages.stream(**kwargs) as s:
            if stream_cb:
                try:
                    for event in s:
                        if event.type == "content_block_delta":
                            d = event.delta
                            if getattr(d, "type", None) == "thinking_delta" and getattr(d, "thinking", None):
                                try:
                                    stream_cb(d.thinking, is_thinking=True)
                                except TypeError:
                                    stream_cb(d.thinking)
                            elif getattr(d, "type", None) == "text_delta" and getattr(d, "text", None):
                                stream_cb(d.text)
                except Exception:
                    # Best-effort fallback: if event-level iteration fails, switch
                    # to the plain text stream for whatever remains.
                    for text in s.text_stream:
                        stream_cb(text)
            else:
                # No callback: still drain the stream before reading the final message.
                for _ in s.text_stream:
                    pass
            return self._collect(s.get_final_message())
|