ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/llm_client.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenAI-compatible async LLM client with tool/function calling support.
|
|
3
|
+
Uses httpx.AsyncClient (no openai SDK dependency) for maximum compatibility.
|
|
4
|
+
Supports any provider that implements the OpenAI chat completions API format.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import random
|
|
11
|
+
from typing import Any, AsyncIterator, Callable
|
|
12
|
+
|
|
13
|
+
import httpx
|
|
14
|
+
|
|
15
|
+
from .config import LLMConfig
|
|
16
|
+
from .types import BaseLLMClient, Message, ToolDef
|
|
17
|
+
from .utils import enhance_api_error
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ── System prompt for the coding agent ───────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
_SYSTEM_PROMPT_CACHE: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _load_system_prompt() -> str:
|
|
28
|
+
"""Load fallback system prompt from skills/codecraft/SKILL.md if available.
|
|
29
|
+
|
|
30
|
+
Cached after first call — no file I/O on repeated access.
|
|
31
|
+
"""
|
|
32
|
+
global _SYSTEM_PROMPT_CACHE
|
|
33
|
+
if _SYSTEM_PROMPT_CACHE is not None:
|
|
34
|
+
return _SYSTEM_PROMPT_CACHE
|
|
35
|
+
|
|
36
|
+
import re
|
|
37
|
+
from pathlib import Path
|
|
38
|
+
prompt_file = Path(__file__).parent / "skills" / "codecraft" / "SKILL.md"
|
|
39
|
+
if prompt_file.exists():
|
|
40
|
+
try:
|
|
41
|
+
raw = prompt_file.read_text(encoding="utf-8")
|
|
42
|
+
match = re.match(r"^---\s*\n(.*?)\n---\s*\n(.*)", raw, re.DOTALL)
|
|
43
|
+
if match:
|
|
44
|
+
_SYSTEM_PROMPT_CACHE = match.group(2).strip()
|
|
45
|
+
return _SYSTEM_PROMPT_CACHE
|
|
46
|
+
_SYSTEM_PROMPT_CACHE = raw
|
|
47
|
+
return raw
|
|
48
|
+
except Exception:
|
|
49
|
+
pass
|
|
50
|
+
_SYSTEM_PROMPT_CACHE = (
|
|
51
|
+
"You are an expert software engineer. "
|
|
52
|
+
"Write correct, secure, maintainable code."
|
|
53
|
+
)
|
|
54
|
+
return _SYSTEM_PROMPT_CACHE
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# ``SYSTEM_PROMPT`` is exposed lazily: the prompt is only loaded the first
# time the attribute is requested, rather than at import time.  A module-level
# ``__getattr__`` (Python 3.7+) is consulted only when the requested attribute
# is NOT found in the module dict, and only for attribute access on the module
# object (e.g. ``from <module> import SYSTEM_PROMPT``) — not for bare-name
# lookups made by code inside this module.
_SYSTEM_PROMPT_LAZY: str | None = None


def __getattr__(name: str):
    """Resolve the lazily computed module attribute ``SYSTEM_PROMPT``.

    The value comes from ``_load_system_prompt()`` and is remembered in
    ``_SYSTEM_PROMPT_LAZY`` after the first request.

    Raises:
        AttributeError: for any attribute name other than ``SYSTEM_PROMPT``.
    """
    global _SYSTEM_PROMPT_LAZY

    # Guard clause: every other missing name is a genuine attribute error.
    if name != "SYSTEM_PROMPT":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if _SYSTEM_PROMPT_LAZY is None:
        _SYSTEM_PROMPT_LAZY = _load_system_prompt()
    return _SYSTEM_PROMPT_LAZY
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class LLMClient(BaseLLMClient):
    """
    OpenAI-compatible async LLM client using httpx.AsyncClient.

    Supports:
      - Any OpenAI-compatible endpoint (OpenAI, Azure, Ollama, vLLM, etc.)
      - Function/tool calling
      - Streaming and non-streaming modes
      - Rate limit retry with exponential backoff
      - Usage tracking callback
    """

    def __init__(self, config: LLMConfig | None = None):
        """Create the client and its underlying ``httpx.AsyncClient``.

        Args:
            config: Connection/model settings. A default ``LLMConfig()`` is
                created when omitted.
        """
        self.config = config or LLMConfig()
        self._tools: list[ToolDef] = []

        # Build the HTTP client — URL normalization via shared module
        from .model_registry import build_api_url
        self._api_url = build_api_url(self.config.base_url, "chat/completions")

        self._headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(300.0, connect=30.0),
            headers=self._headers,
        )

        # Usage tracking
        self._usage_callback: Callable[[int, int], None] | None = None
        self._total_prompt_tokens = 0
        self._total_completion_tokens = 0

        # Retry config
        self._max_retries = 3
        self._retry_base_delay = 1.0  # seconds

    def on_usage(self, callback: Callable[[int, int], None]) -> None:
        """Register a callback for token usage: callback(prompt_tokens, completion_tokens)."""
        self._usage_callback = callback

    @property
    def total_prompt_tokens(self) -> int:
        """Prompt tokens accumulated over all requests made by this client."""
        return self._total_prompt_tokens

    @property
    def total_completion_tokens(self) -> int:
        """Completion tokens accumulated over all requests made by this client."""
        return self._total_completion_tokens

    @property
    def total_tokens(self) -> int:
        """Sum of the accumulated prompt and completion tokens."""
        return self._total_prompt_tokens + self._total_completion_tokens

    # ── Tool registration ──────────────────────────────────────────────────

    def register_tools(self, tools: list[ToolDef]) -> None:
        """Register tool definitions for subsequent requests."""
        self._tools = tools

    # ── Shared request helpers ─────────────────────────────────────────────

    def _build_request_body(
        self,
        messages: list[Message],
        tools: list[ToolDef] | None,
        system_prompt: str,
        *,
        stream: bool = False,
    ) -> tuple[dict[str, Any], list[Message]]:
        """Build the JSON body shared by chat() and chat_stream().

        Returns:
            ``(body, resolved_messages)`` where *resolved_messages* is a copy
            of *messages* with *system_prompt* prepended when no system
            message was present.
        """
        # An explicit ``tools`` argument (even an empty list) overrides the
        # registered tools; ``None`` means "use the registered ones".
        tool_defs = tools if tools is not None else self._tools

        # Honour system_prompt param for API parity with AnthropicClient
        resolved_messages = list(messages)
        if system_prompt and not any(m.get("role") == "system" for m in resolved_messages):
            resolved_messages.insert(0, {"role": "system", "content": system_prompt})

        body: dict[str, Any] = {
            "model": self.config.model,
            "messages": resolved_messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }
        if stream:
            body["stream"] = True

        if tool_defs:
            body["tools"] = tool_defs
            body["tool_choice"] = "auto"

        # Thinking mode. Normalise case once so "Off"/"OFF" disable thinking
        # instead of being forwarded as a reasoning effort.
        thinking_strength = (getattr(self.config, "thinking_strength", "") or "").lower()
        if thinking_strength == "off":
            # NOTE(review): "extra_body" is sent as a literal JSON key here.
            # Confirm the target providers read it; some may expect
            # "thinking" at the top level of the request instead.
            body["extra_body"] = {"thinking": {"type": "disabled"}}
        elif thinking_strength:
            body["reasoning_effort"] = thinking_strength
            # Temperature is not sent together with reasoning_effort
            # (presumably because reasoning models reject it — confirm).
            body.pop("temperature", None)

        logger.debug(
            "Calling %s with %d messages, %d tools, thinking=%s",
            self.config.model,
            len(messages),
            len(tool_defs) if tool_defs else 0,
            thinking_strength or "off",
        )

        # Sanitize surrogates before JSON encoding (prevent UTF-8 encode crash)
        from .utils import sanitize_surrogates
        return sanitize_surrogates(body), resolved_messages

    def _backoff_delay(self, attempt: int) -> float:
        """Return the jittered exponential backoff delay (seconds) for *attempt* (0-based)."""
        return self._retry_base_delay * (2 ** attempt) * (0.5 + random.random())

    def _record_usage(self, prompt_tokens: int, completion_tokens: int, label: str) -> None:
        """Add one request's usage to the session totals, log it and notify the callback."""
        self._total_prompt_tokens += prompt_tokens
        self._total_completion_tokens += completion_tokens
        logger.info(
            "%s: %d in, %d out (session: %d total)",
            label, prompt_tokens, completion_tokens, self.total_tokens,
        )
        if self._usage_callback:
            self._usage_callback(prompt_tokens, completion_tokens)

    # ── Chat completion (non-streaming) ────────────────────────────────────

    async def chat(
        self,
        messages: list[Message],
        tools: list[ToolDef] | None = None,
        system_prompt: str = "",
    ) -> Message:
        """
        Send messages and get a completion.
        Returns the assistant message (may include tool_calls).
        Automatically retries on rate limit (429) errors.

        *system_prompt* is prepended as a system message when the messages
        list does not already contain one. This provides API parity with
        AnthropicClient without requiring the caller to branch on provider.

        Raises:
            RuntimeError: on connection/timeout/HTTP failures (see
                ``_request_with_retry``) or when the response carries no
                ``choices[0].message``.
        """
        body, resolved_messages = self._build_request_body(messages, tools, system_prompt)

        data = await self._request_with_retry(body)

        try:
            msg = data["choices"][0]["message"]
        except (KeyError, IndexError, TypeError) as exc:
            # Report a malformed payload the same way as every other API
            # failure in this client instead of leaking a bare KeyError.
            raise RuntimeError(
                f"API response contained no completion message: {str(data)[:500]}"
            ) from exc

        # Build a clean message dict
        result: Message = {
            "role": "assistant",
            "content": msg.get("content") or "",
        }

        # Preserve reasoning_content for thinking/reasoning models (DeepSeek R1/v4, etc.)
        if msg.get("reasoning_content"):
            result["reasoning_content"] = msg["reasoning_content"]

        if msg.get("tool_calls"):
            result["tool_calls"] = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    },
                }
                for tc in msg["tool_calls"]
            ]

        # Track usage (with fallback estimation)
        usage = data.get("usage")
        if usage and usage.get("total_tokens"):
            prompt_tokens = usage.get("prompt_tokens") or 0
            completion_tokens = usage.get("completion_tokens") or 0
        else:
            # Fallback estimation — based on what was actually sent, i.e.
            # including an injected system prompt.
            prompt_tokens = self.count_tokens_approx(resolved_messages)
            completion_tokens = self.count_tokens_approx([result])
        self._record_usage(prompt_tokens, completion_tokens, "Tokens")

        return result

    # ── Streaming chat completion ──────────────────────────────────────────

    async def chat_stream(
        self,
        messages: list[Message],
        tools: list[ToolDef] | None = None,
        system_prompt: str = "",
    ) -> AsyncIterator[tuple[str, Any]]:
        """
        Stream a chat completion.
        Yields (delta_type, content) tuples.
        delta_type is one of: "text", "reasoning", "tool_call", "finish".

        "text" and "reasoning" are yielded incrementally as chunks arrive;
        assembled "tool_call" entries and the final "finish" reason are
        yielded after the stream ends.

        *system_prompt* is prepended as a system message when the messages
        list does not already contain one — matching the behaviour of chat().

        Raises:
            RuntimeError: when the connection cannot be established after
                all attempts.
            httpx.HTTPStatusError: when the server answers with a 4xx/5xx
                status that is not (or no longer) retried.
        """
        body, resolved_messages = self._build_request_body(
            messages, tools, system_prompt, stream=True
        )

        response = await self._open_stream(body)

        # Accumulators for streaming tool calls
        tool_call_buf: dict[int, dict[str, Any]] = {}
        finish_reason = None

        # Track usage from streaming (with fallback estimation)
        prompt_tokens = 0
        completion_tokens = 0
        total_text_chars = 0  # fallback estimation

        try:
            async for line in response.aiter_lines():
                # SSE allows both "data: x" and "data:x"; anything else
                # (blank keep-alives, comments, other fields) is skipped.
                if not line.startswith("data:"):
                    continue

                data_str = line[5:].strip()
                if data_str == "[DONE]":
                    break

                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning("Failed to parse SSE line: %s", data_str[:100])
                    continue

                # Track usage if present in chunk (ignore null counts)
                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = usage.get("prompt_tokens") or prompt_tokens
                    completion_tokens = usage.get("completion_tokens") or completion_tokens

                choices = chunk.get("choices") or []
                if not choices:
                    continue

                choice = choices[0]
                delta = choice.get("delta") or {}
                finish_reason = choice.get("finish_reason") or finish_reason

                # Text content
                if delta.get("content"):
                    text_chunk = delta["content"]
                    total_text_chars += len(text_chunk)
                    yield ("text", text_chunk)

                # Reasoning content (DeepSeek R1/v4, o1 models)
                if delta.get("reasoning_content"):
                    yield ("reasoning", delta["reasoning_content"])

                # Tool calls (streaming — incremental chunks). ``or []``
                # tolerates an explicit ``"tool_calls": null``.
                for tc_delta in delta.get("tool_calls") or []:
                    idx = tc_delta.get("index", 0)
                    buf = tool_call_buf.setdefault(
                        idx,
                        {"id": "", "function": {"name": "", "arguments": ""}},
                    )
                    if tc_delta.get("id"):
                        buf["id"] = tc_delta["id"]
                    fn = tc_delta.get("function") or {}
                    if fn.get("name"):
                        buf["function"]["name"] += fn["name"]
                    if fn.get("arguments"):
                        buf["function"]["arguments"] += fn["arguments"]
        finally:
            # Always release the connection: leaving the loop at "[DONE]",
            # an error, or the consumer abandoning the generator would
            # otherwise keep the streamed response open.
            await response.aclose()

        # Yield assembled tool calls
        for idx in sorted(tool_call_buf):
            buf = tool_call_buf[idx]
            yield ("tool_call", {
                "id": buf["id"],
                "type": "function",
                "function": buf["function"],
            })

        # Track streaming usage (with fallback estimation)
        if not prompt_tokens:
            # Fallback: the server reported no usage; estimate the prompt
            # from the messages actually sent and the completion from the
            # streamed character count (3 characters per token).
            prompt_tokens = self.count_tokens_approx(resolved_messages) if resolved_messages else 0
            completion_tokens = max(1, total_text_chars // 3)
        self._record_usage(prompt_tokens, completion_tokens, "Stream tokens")

        if finish_reason:
            yield ("finish", finish_reason)

    async def _open_stream(self, body: dict[str, Any]) -> httpx.Response:
        """Open the streaming request, retrying connect errors and 429/5xx.

        Makes at most ``self._max_retries`` attempts in total. Responses that
        are not returned to the caller are closed here.

        Returns:
            An open streamed response with a status code below 400. The
            caller is responsible for closing it.

        Raises:
            RuntimeError: connection could not be established.
            httpx.HTTPStatusError: final 4xx/5xx answer from the server.
        """
        for attempt in range(self._max_retries):
            can_retry = attempt < self._max_retries - 1
            try:
                response = await self._client.send(
                    self._client.build_request("POST", self._api_url, json=body),
                    stream=True,
                )
            except (httpx.ConnectError, httpx.ReadTimeout, httpx.RemoteProtocolError) as e:
                if can_retry:
                    delay = self._backoff_delay(attempt)
                    logger.warning("Stream connect error, retrying in %.1fs: %s", delay, e)
                    await asyncio.sleep(delay)
                    continue
                raise RuntimeError(f"Stream connection failed: {e}") from e

            status = response.status_code
            if status < 400:
                return response

            if can_retry and (status == 429 or status >= 500):
                # Release the failed response before waiting and retrying.
                await response.aclose()
                delay = self._backoff_delay(attempt)
                if status == 429:
                    logger.warning("Stream rate limited, retrying in %.1fs", delay)
                else:
                    logger.warning("Stream server error (%d), retrying in %.1fs", status, delay)
                await asyncio.sleep(delay)
                continue

            try:
                try:
                    error_body = (await response.aread()).decode("utf-8", errors="replace")[:500]
                except Exception:
                    # Best effort only: the body is used purely for logging.
                    error_body = "(could not read body)"
                logger.error("Stream request failed (%d): %s", status, error_body)
                response.raise_for_status()
            finally:
                await response.aclose()

        # Only reachable when _max_retries allows no attempt at all.
        raise RuntimeError("Stream request failed: no request attempt was made")

    # ── Retry logic ──────────────────────────────────────────────────────

    @staticmethod
    def _extract_error_message(response: httpx.Response, fallback: str) -> str:
        """Return the error text from an error response body, or *fallback*.

        Handles both ``{"error": {"message": ...}}`` and ``{"error": "..."}``.
        """
        try:
            payload = response.json()
        except Exception:
            return fallback
        error = payload.get("error") if isinstance(payload, dict) else None
        if isinstance(error, dict):
            return str(error.get("message") or fallback)
        if isinstance(error, str) and error:
            return error
        return fallback

    async def _request_with_retry(self, body: dict[str, Any]) -> dict[str, Any]:
        """Send the non-streaming request with exponential backoff.

        Retries remote protocol errors, HTTP 429 (honouring a numeric
        ``Retry-After`` header, capped at 60s) and HTTP 5xx, for up to
        ``self._max_retries`` retries after the first attempt.

        Returns:
            The decoded JSON response body.

        Raises:
            RuntimeError: on connection failure, timeout, exhausted retries,
                a non-retryable HTTP error, or a non-JSON success response.
        """
        last_error: str | None = None
        for attempt in range(self._max_retries + 1):
            can_retry = attempt < self._max_retries
            try:
                response = await self._client.post(self._api_url, json=body)
            except httpx.ConnectError as e:
                raise RuntimeError(
                    f"Cannot connect to {self._api_url}\n"
                    f"  Check: is the server running? Is the URL correct?\n"
                    f"  Current: {self.config.base_url}\n"
                    f"  Detail: {e}"
                ) from e
            except httpx.ReadTimeout as e:
                raise RuntimeError(
                    "Request timed out after 300s.\n"
                    "  The model may be overloaded or the prompt too large.\n"
                    "  Try again or reduce the task complexity."
                ) from e
            except httpx.RemoteProtocolError as e:
                last_error = str(e)
                if can_retry:
                    delay = self._backoff_delay(attempt)
                    logger.warning("Remote protocol error, retrying in %.1fs: %s", delay, e)
                    await asyncio.sleep(delay)
                    continue
                raise RuntimeError(
                    "Server disconnected unexpectedly.\n"
                    "  The model may have timed out internally.\n"
                    "  Try again with a smaller task."
                ) from e

            if response.status_code == 429:
                last_error = "HTTP 429 (rate limited)"
                # Rate limited — extract retry-after or use exponential backoff
                retry_after = response.headers.get("retry-after", "")
                try:
                    delay = float(retry_after) if retry_after else self._retry_base_delay * (2 ** attempt)
                except ValueError:
                    # Non-numeric header (e.g. an HTTP date): use jittered backoff.
                    delay = self._backoff_delay(attempt)
                delay = max(0.0, min(delay, 60.0))  # keep within 0..60s

                if can_retry:
                    logger.warning(
                        "Rate limited (429). Retrying in %.1fs (attempt %d/%d)...",
                        delay, attempt + 1, self._max_retries,
                    )
                    await asyncio.sleep(delay)
                    continue
                raise RuntimeError(
                    f"Rate limit exceeded after {self._max_retries} retries. "
                    f"Wait and try again."
                )

            if response.status_code >= 500:
                last_error = f"HTTP {response.status_code} (server error)"
                # Server error — retry; on the last attempt fall through so
                # the error is reported by the status check below.
                if can_retry:
                    delay = self._backoff_delay(attempt)
                    logger.warning(
                        "Server error (%d). Retrying in %.1fs...",
                        response.status_code, delay,
                    )
                    await asyncio.sleep(delay)
                    continue

            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                err_msg = self._extract_error_message(response, str(e))
                raise RuntimeError(
                    enhance_api_error(response.status_code, f"API error ({response.status_code}): {err_msg}", self.config.base_url)
                ) from e

            try:
                return response.json()
            except ValueError as e:
                raise RuntimeError(
                    f"API returned a non-JSON response (HTTP {response.status_code})"
                ) from e

        raise RuntimeError(f"Request failed after {self._max_retries} retries: {last_error}")

    # ── Convenience methods ────────────────────────────────────────────────

    async def simple_chat(self, user_message: str, system: str | None = None) -> str:
        """Single-turn chat without tools. Returns text response.

        When *system* is empty or None the fallback prompt from
        ``_load_system_prompt()`` is used.
        """
        # The loader is called directly: a bare ``SYSTEM_PROMPT`` name is not
        # resolved through the module-level ``__getattr__`` from inside this
        # module and would raise NameError.
        messages = [
            {"role": "system", "content": system or _load_system_prompt()},
            {"role": "user", "content": user_message},
        ]
        result = await self.chat(messages, tools=[])
        return result.get("content", "")

    def count_tokens_approx(self, messages: list[Message]) -> int:
        """
        Token count estimation — CJK-aware + tiktoken if available.

        Uses tiktoken's ``cl100k_base`` encoding when it can be loaded;
        otherwise falls back to a character heuristic (2/3 token per CJK
        character, 1/4 token per other character). Tool calls are counted
        via their JSON serialisation.

        Returns:
            The estimated token count. The heuristic path returns at least 1.
        """
        def text_of(message: Message) -> str:
            # Content may be absent/None, or structured (non-string) data,
            # which is measured through its JSON form.
            content = message.get("content") or ""
            if isinstance(content, str):
                return content
            return json.dumps(content, ensure_ascii=False)

        try:
            import tiktoken
            enc = tiktoken.get_encoding("cl100k_base")
        except Exception:
            # tiktoken missing, or its encoding could not be loaded: an
            # estimate must never fail the request, so use the heuristic.
            enc = None

        if enc is not None:
            total = 0
            for msg in messages:
                # disallowed_special=() makes encode() treat special-token
                # text such as "<|endoftext|>" as ordinary text instead of
                # raising ValueError.
                total += len(enc.encode(text_of(msg), disallowed_special=()))
                for tc in msg.get("tool_calls") or []:
                    total += len(enc.encode(json.dumps(tc), disallowed_special=()))
            return total

        # CJK-aware fallback
        import re
        # CJK unified ideographs, CJK symbols/punctuation, and
        # halfwidth/fullwidth forms — written as escapes so the ranges
        # survive any re-encoding of this file.
        cjk_pattern = re.compile(r"[\u4e00-\u9fff\u3000-\u303f\uff00-\uffef]")
        total = 0
        for msg in messages:
            content = text_of(msg)
            cjk = len(cjk_pattern.findall(content))
            other = len(content) - cjk
            total += (cjk * 2 // 3) + (other // 4)
            for tc in msg.get("tool_calls") or []:
                total += len(json.dumps(tc)) // 4
        return max(1, total)

    def set_model(self, model: str) -> None:
        """Change the model at runtime without recreating the client."""
        self.config.model = model

    async def close(self):
        """Close the HTTP client."""
        await self._client.aclose()
|