vtx-coding-agent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vtx/__init__.py +63 -0
- vtx/async_utils.py +40 -0
- vtx/builtin_skills/github/SKILL.md +139 -0
- vtx/builtin_skills/init/SKILL.md +74 -0
- vtx/builtin_skills/review/SKILL.md +73 -0
- vtx/builtin_skills/skill-builder/SKILL.md +133 -0
- vtx/cli.py +90 -0
- vtx/config.py +741 -0
- vtx/context/__init__.py +15 -0
- vtx/context/_xml.py +8 -0
- vtx/context/agent_mds.py +128 -0
- vtx/context/git.py +64 -0
- vtx/context/loader.py +41 -0
- vtx/context/skills.py +423 -0
- vtx/core/__init__.py +47 -0
- vtx/core/compaction.py +89 -0
- vtx/core/errors.py +17 -0
- vtx/core/handoff.py +51 -0
- vtx/core/scratchpad.py +54 -0
- vtx/core/types.py +197 -0
- vtx/defaults/__init__.py +0 -0
- vtx/defaults/config.yml +53 -0
- vtx/diff_display.py +12 -0
- vtx/events.py +224 -0
- vtx/gh_cli.py +82 -0
- vtx/git_branch.py +90 -0
- vtx/headless.py +127 -0
- vtx/llm/__init__.py +93 -0
- vtx/llm/base.py +217 -0
- vtx/llm/context_length.py +150 -0
- vtx/llm/dynamic_models.py +735 -0
- vtx/llm/model_fetcher.py +279 -0
- vtx/llm/models.py +78 -0
- vtx/llm/oauth/__init__.py +59 -0
- vtx/llm/oauth/copilot.py +358 -0
- vtx/llm/oauth/dynamic.py +236 -0
- vtx/llm/oauth/openai.py +400 -0
- vtx/llm/phase_parser.py +270 -0
- vtx/llm/provider.yaml +280 -0
- vtx/llm/provider_catalog.py +230 -0
- vtx/llm/providers/__init__.py +45 -0
- vtx/llm/providers/anthropic_sdk.py +256 -0
- vtx/llm/providers/mock.py +249 -0
- vtx/llm/providers/openai_sdk.py +246 -0
- vtx/llm/providers/sanitize.py +14 -0
- vtx/llm/sdk/__init__.py +13 -0
- vtx/llm/sdk/anthropic.py +382 -0
- vtx/llm/sdk/base.py +82 -0
- vtx/llm/sdk/openai.py +344 -0
- vtx/llm/tool_parser.py +161 -0
- vtx/loop.py +272 -0
- vtx/notify.py +109 -0
- vtx/permissions.py +114 -0
- vtx/prompts/__init__.py +45 -0
- vtx/prompts/builder.py +86 -0
- vtx/prompts/env.py +58 -0
- vtx/prompts/identity.py +166 -0
- vtx/prompts/tooling.py +36 -0
- vtx/py.typed +0 -0
- vtx/runtime.py +580 -0
- vtx/session.py +868 -0
- vtx/sounds/completion.wav +0 -0
- vtx/sounds/error.wav +0 -0
- vtx/sounds/permission.wav +0 -0
- vtx/themes.py +1104 -0
- vtx/tools/__init__.py +68 -0
- vtx/tools/_read_image.py +106 -0
- vtx/tools/_tool_utils.py +90 -0
- vtx/tools/base.py +36 -0
- vtx/tools/bash.py +371 -0
- vtx/tools/edit.py +261 -0
- vtx/tools/find.py +132 -0
- vtx/tools/read.py +238 -0
- vtx/tools/skill.py +278 -0
- vtx/tools/web.py +238 -0
- vtx/tools/write.py +88 -0
- vtx/tools_manager.py +216 -0
- vtx/turn.py +789 -0
- vtx/ui/__init__.py +0 -0
- vtx/ui/agent_runner.py +417 -0
- vtx/ui/app.py +665 -0
- vtx/ui/app_protocol.py +29 -0
- vtx/ui/autocomplete.py +440 -0
- vtx/ui/blocks.py +735 -0
- vtx/ui/chat.py +613 -0
- vtx/ui/clipboard.py +59 -0
- vtx/ui/commands/__init__.py +100 -0
- vtx/ui/commands/auth.py +306 -0
- vtx/ui/commands/base.py +122 -0
- vtx/ui/commands/models.py +144 -0
- vtx/ui/commands/sessions.py +388 -0
- vtx/ui/commands/settings.py +286 -0
- vtx/ui/completion_ui.py +313 -0
- vtx/ui/export.py +703 -0
- vtx/ui/floating_list.py +370 -0
- vtx/ui/formatting.py +287 -0
- vtx/ui/input.py +760 -0
- vtx/ui/latex.py +349 -0
- vtx/ui/launch.py +108 -0
- vtx/ui/path_complete.py +228 -0
- vtx/ui/prompt_history.py +102 -0
- vtx/ui/queue_ui.py +141 -0
- vtx/ui/selection_mode.py +18 -0
- vtx/ui/session_ui.py +235 -0
- vtx/ui/startup.py +124 -0
- vtx/ui/styles.py +327 -0
- vtx/ui/tool_output.py +34 -0
- vtx/ui/tree.py +437 -0
- vtx/ui/welcome.py +51 -0
- vtx/ui/widgets.py +558 -0
- vtx/update_check.py +49 -0
- vtx/version.py +22 -0
- vtx_coding_agent-0.1.1.dist-info/METADATA +259 -0
- vtx_coding_agent-0.1.1.dist-info/RECORD +117 -0
- vtx_coding_agent-0.1.1.dist-info/WHEEL +4 -0
- vtx_coding_agent-0.1.1.dist-info/entry_points.txt +2 -0
- vtx_coding_agent-0.1.1.dist-info/licenses/LICENSE +201 -0
vtx/llm/sdk/base.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Base SDK class for LLM providers."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Message:
    """One chat message handed to an SDK wrapper.

    ``BaseLLMSDK.convert_messages_to_dict`` turns instances into plain dicts.
    """

    # Chat role string; presumably "system"/"user"/"assistant"/"tool" — not validated here.
    role: str
    # Message text. When image_parts is set this becomes the leading "text" part.
    content: str
    # Extra keys merged into the converted dict after "role"/"content" (so they can override them).
    metadata: dict[str, Any] | None = None
    # Image URLs; when non-empty the converted content is a list of text + "image_url" parts.
    image_parts: list[str] | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class GenerationConfig:
    """Request options for a single generation call.

    ``None`` means "not set"; each concrete SDK decides which fields it forwards.
    """

    # Model identifier; a blank value may be replaced by an SDK-specific fallback.
    model: str
    temperature: float = 0.7
    # Upper bound on generated tokens; None leaves it to the provider.
    max_tokens: int | None = None
    top_p: float | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    # Strings that end generation when produced.
    stop_sequences: list[str] | None = None
    # Tool-selection directive; accepted values depend on the provider (type allows str/dict/bool).
    tool_choice: str | dict | bool | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ToolCall:
    """A tool invocation requested by the model."""

    # Provider-assigned call id; may be "" if a stream never carried one.
    id: str
    # Name of the tool to run.
    name: str
    # Raw argument string from the model (presumably JSON text); not parsed or validated here.
    arguments: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class GenerationResponse:
    """Result of a non-streaming generation call."""

    # Assistant text ("" when the provider returned none).
    content: str
    # Model name as reported by the provider.
    model: str
    # Provider's stop reason, if reported.
    finish_reason: str | None = None
    # Requested tool invocations; None when there are none.
    tool_calls: list[ToolCall] | None = None
    # Token accounting, e.g. {"input_tokens", "output_tokens", "total_tokens"}; None if unreported.
    usage: dict[str, int] | None = None
    # Separate reasoning text returned by the provider, if any.
    reasoning_content: str = ""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class BaseLLMSDK(ABC):
    """Abstract interface implemented by every provider SDK wrapper.

    Subclasses supply the underlying client, plain and tool-enabled generation,
    and a model list. Conversion of ``Message`` objects to OpenAI-style dicts is
    shared here and may be overridden.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        # Kept verbatim; subclasses decide how the credentials/endpoint are used.
        self.api_key, self.base_url = api_key, base_url

    @property
    @abstractmethod
    def client(self):
        """Return the provider's underlying client object."""

    @abstractmethod
    async def generate(
        self, messages: list[Message], config: GenerationConfig, stream: bool = False
    ) -> GenerationResponse | AsyncGenerator:
        """Generate a reply: a GenerationResponse, or an async generator when streaming."""

    @abstractmethod
    async def generate_with_tools(
        self,
        messages: list[Message],
        tools: list[dict],
        config: GenerationConfig,
        stream: bool = False,
    ) -> GenerationResponse | AsyncGenerator:
        """Like ``generate`` but with tool definitions made available to the model."""

    @abstractmethod
    def get_available_models(self) -> list[str]:
        """Return the model identifiers this SDK advertises."""

    def convert_messages_to_dict(self, messages: list[Message]) -> list[dict]:
        """Convert messages to dicts with "role", "content" and any metadata keys.

        Messages carrying images get a list-valued content: one "text" part
        followed by one "image_url" part per URL. Metadata keys are merged last,
        so they may override "role"/"content".
        """
        converted: list[dict] = []
        for message in messages:
            extra = message.metadata or {}
            if not message.image_parts:
                converted.append({"role": message.role, "content": message.content, **extra})
                continue
            parts: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
            parts.extend(
                {"type": "image_url", "image_url": {"url": url}} for url in message.image_parts
            )
            converted.append({"role": message.role, "content": parts, **extra})
        return converted
|
vtx/llm/sdk/openai.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""OpenAI GPT SDK using the official openai package."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from openai import AsyncOpenAI
|
|
10
|
+
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
|
|
11
|
+
|
|
12
|
+
from .base import BaseLLMSDK, GenerationConfig, GenerationResponse, Message, ToolCall
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)

# Model used when GenerationConfig.model is blank and $VTX_MODEL is unset or blank.
_DEFAULT_MODEL = "gpt-4o"
# Default number of attempts (the first call included) made by _retry_on_transient.
_MAX_RETRIES = 3
# Initial back-off in seconds; doubled after each failed transient attempt.
_RETRY_BASE_DELAY = 1.0
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _is_transient_error(e: Exception) -> bool:
|
|
22
|
+
msg = str(e).lower()
|
|
23
|
+
return any(
|
|
24
|
+
s in msg
|
|
25
|
+
for s in [
|
|
26
|
+
"connection",
|
|
27
|
+
"connect",
|
|
28
|
+
"timeout",
|
|
29
|
+
"timed out",
|
|
30
|
+
"reset",
|
|
31
|
+
"broken pipe",
|
|
32
|
+
"eof",
|
|
33
|
+
"network",
|
|
34
|
+
"unavailable",
|
|
35
|
+
"bad gateway",
|
|
36
|
+
"gateway timeout",
|
|
37
|
+
"service unavailable",
|
|
38
|
+
]
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
async def _retry_on_transient(coro_factory, max_retries: int = _MAX_RETRIES):
    """Await ``coro_factory()`` and retry it with exponential back-off on transient errors.

    Args:
        coro_factory: Zero-argument callable returning a fresh awaitable on every
            call (an awaitable can only be awaited once, hence the factory).
        max_retries: Total number of attempts, the first one included. Values
            below 1 are treated as 1 so the call is always made at least once
            (previously 0 silently returned None without calling anything).

    Returns:
        Whatever the awaited call returns.

    Raises:
        The exception from the failing attempt: immediately if it is not
        transient (see ``_is_transient_error``), otherwise after the final attempt.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return await coro_factory()
        except Exception as e:
            if not _is_transient_error(e) or attempt == attempts - 1:
                raise
            # 1x, 2x, 4x ... the base delay between consecutive attempts.
            delay = _RETRY_BASE_DELAY * (2**attempt)
            logger.warning(
                "Transient error (attempt %d/%d), retrying in %.1fs: %s",
                attempt + 1,
                attempts,
                delay,
                str(e)[:200],
            )
            await asyncio.sleep(delay)
    # Every iteration either returns or raises on its last attempt.
    raise RuntimeError("unreachable: retry loop exited without a result")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def _openai_stream_chunks(
    stream: AsyncIterator[ChatCompletionChunk],
) -> AsyncGenerator[dict[str, Any], None]:
    """Translate an OpenAI chat-completion chunk stream into normalized event dicts.

    Yields dicts of these shapes:
      - {"type": "usage", "usage": ...}  -- ``chunk.usage.model_dump()`` as-is
      - {"type": "reasoning", "content": str, "signature": str}
      - {"type": "text", "content": str}
      - {"type": "tool_calls", "tool_calls": [ToolCall, ...]}  -- once, at the end

    Reasoning arrives either in a dedicated ``reasoning_content``/``reasoning``
    delta field (signature "reasoning_content"), or inline in ``delta.content``,
    where ``ThinkingPhaseParser`` separates it (signature INLINE_THINK_SIGNATURE).

    Tool-call fragments are accumulated per stream ``index`` and emitted in
    index order. The underlying stream is closed (best effort) when the
    generator finishes or is closed.
    """
    from ..phase_parser import (
        INLINE_THINK_SIGNATURE,
        ResponseDelta,
        ThinkDelta,
        ThinkEnd,
        ThinkingPhaseParser,
    )

    tool_calls_acc: dict[int, dict[str, Any]] = {}
    phase_parser = ThinkingPhaseParser()
    # Length of inline thinking already yielded via ThinkDelta in the current
    # think block, so ThinkEnd only yields the part that was never streamed.
    think_emitted_len = 0

    def _inline_reasoning(text: str) -> dict[str, Any]:
        """Build a reasoning chunk for text taken from inline think tags."""
        return {"type": "reasoning", "content": text, "signature": INLINE_THINK_SIGNATURE}

    def _translate(phase_event: Any) -> list[dict[str, Any]]:
        """Map one phase-parser event to the output chunks it produces (possibly none)."""
        nonlocal think_emitted_len
        if isinstance(phase_event, ThinkDelta):
            think_emitted_len += len(phase_event.text)
            return [_inline_reasoning(phase_event.text)]
        if isinstance(phase_event, ThinkEnd):
            remaining = phase_event.full_thinking[think_emitted_len:]
            think_emitted_len = 0
            return [_inline_reasoning(remaining)] if remaining else []
        if isinstance(phase_event, ResponseDelta):
            return [{"type": "text", "content": phase_event.text}]
        # Start/end markers (ThinkStart, ResponseStart, ResponseEnd) carry no payload.
        return []

    try:
        async for chunk in stream:
            if chunk.usage:
                yield {"type": "usage", "usage": chunk.usage.model_dump()}
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            reasoning_delta = getattr(delta, "reasoning_content", None) or getattr(
                delta, "reasoning", None
            )
            if reasoning_delta:
                yield {
                    "type": "reasoning",
                    "content": reasoning_delta,
                    "signature": "reasoning_content",
                }
            # Deliberately not ``elif``: some servers put the tail of the reasoning
            # and the start of the answer in the same delta; skipping content here
            # would silently drop answer text.
            if delta.content:
                for phase_event in phase_parser.feed(delta.content):
                    for out in _translate(phase_event):
                        yield out
            if delta.tool_calls:
                for tc_delta in delta.tool_calls:
                    entry = tool_calls_acc.setdefault(
                        tc_delta.index, {"id": "", "name": "", "arguments": ""}
                    )
                    if tc_delta.id:
                        entry["id"] = tc_delta.id
                    if tc_delta.function:
                        if tc_delta.function.name:
                            entry["name"] = tc_delta.function.name
                        if tc_delta.function.arguments:
                            entry["arguments"] += tc_delta.function.arguments
        for phase_event in phase_parser.flush():
            for out in _translate(phase_event):
                yield out
        if tool_calls_acc:
            # Order by stream index (the order the model emitted the calls);
            # ids are opaque random strings and say nothing about order.
            yield {
                "type": "tool_calls",
                "tool_calls": [
                    ToolCall(id=v["id"], name=v["name"], arguments=v["arguments"])
                    for _, v in sorted(tool_calls_acc.items())
                ],
            }
    finally:
        close = getattr(stream, "close", None)
        if close is not None:
            try:
                await close()
            except Exception:
                # Best-effort cleanup: never let a close failure mask the real outcome.
                logger.debug("Ignoring error while closing OpenAI stream", exc_info=True)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class OpenAISDK(BaseLLMSDK):
    """Chat Completions wrapper around ``openai.AsyncOpenAI``.

    Targets api.openai.com by default, or any OpenAI-compatible endpoint given
    as ``base_url``. Message conversion is inherited from ``BaseLLMSDK``.
    """

    def __init__(self, api_key: str, base_url: str | None = None, rate_limit_hook=None):
        """Create the wrapper.

        Args:
            api_key: API key passed to ``AsyncOpenAI``.
            base_url: Endpoint root; defaults to ``https://api.openai.com/v1``.
                ``http://`` URLs are upgraded to ``https://`` unless the host is
                loopback (localhost / 127.0.0.0/8 / ::1).
            rate_limit_hook: Stored as ``self._rate_limit_hook``; it is not called
                within this class (NOTE(review): presumably consumed elsewhere).
        """
        resolved_url = base_url or "https://api.openai.com/v1"
        # Upgrade plain HTTP so the key is not sent in cleartext -- except for
        # loopback: local OpenAI-compatible servers rarely serve TLS (the upgrade
        # would make them unreachable) and loopback traffic never leaves the host.
        if resolved_url.startswith("http://") and not self._is_loopback_url(resolved_url):
            resolved_url = "https://" + resolved_url[7:]
        super().__init__(api_key, resolved_url)
        self._async_client: AsyncOpenAI | None = None
        self._rate_limit_hook = rate_limit_hook

    @staticmethod
    def _is_loopback_url(url: str) -> bool:
        """Return True when ``url``'s host is ``localhost`` or a loopback IP literal."""
        import ipaddress
        from urllib.parse import urlsplit

        host = urlsplit(url).hostname or ""
        if host == "localhost":
            return True
        try:
            return ipaddress.ip_address(host).is_loopback
        except ValueError:
            return False

    @property
    def client(self) -> AsyncOpenAI:
        """Lazily create and cache the ``AsyncOpenAI`` client."""
        if self._async_client is None:
            # NOTE(review): timeout=None presumably disables the SDK request timeout.
            # The SDK also retries on its own (max_retries=3), in addition to
            # _retry_on_transient around each call below.
            self._async_client = AsyncOpenAI(
                api_key=self.api_key, base_url=self.base_url, timeout=None, max_retries=3
            )
        return self._async_client

    def _build_kwargs(
        self, messages: list[Message], config: GenerationConfig, tools: list[dict] | None = None
    ) -> dict[str, Any]:
        """Build keyword arguments for ``chat.completions.create``.

        The model is ``config.model`` (stripped), else ``$VTX_MODEL``, else
        ``_DEFAULT_MODEL``. Optional sampling parameters are only sent when set;
        temperature is omitted when it equals GenerationConfig's 0.7 default, and
        zero penalties are omitted. ``tool_choice`` is only sent alongside tools.
        """
        openai_messages = self.convert_messages_to_dict(messages)
        model = (
            config.model.strip()
            if config.model and config.model.strip()
            else os.getenv("VTX_MODEL", "").strip() or _DEFAULT_MODEL
        )
        kwargs: dict[str, Any] = {"model": model, "messages": openai_messages}
        if config.temperature is not None and config.temperature != 0.7:
            kwargs["temperature"] = config.temperature
        if config.max_tokens is not None:
            kwargs["max_tokens"] = config.max_tokens
        if config.top_p is not None:
            kwargs["top_p"] = config.top_p
        if config.frequency_penalty is not None and config.frequency_penalty != 0.0:
            kwargs["frequency_penalty"] = config.frequency_penalty
        if config.presence_penalty is not None and config.presence_penalty != 0.0:
            kwargs["presence_penalty"] = config.presence_penalty
        if config.stop_sequences:
            kwargs["stop"] = config.stop_sequences
        if tools:
            kwargs["tools"] = [
                ChatCompletionToolParam(
                    type="function",
                    function={
                        "name": t["function"]["name"],
                        "description": t["function"].get("description", ""),
                        "parameters": t["function"]["parameters"],
                    },
                )
                for t in tools
            ]
            if config.tool_choice is not None:
                kwargs["tool_choice"] = config.tool_choice
        return kwargs

    @staticmethod
    def _usage_dict(usage: Any) -> dict[str, int] | None:
        """Convert an OpenAI usage object into this package's usage dict (None if absent)."""
        if not usage:
            return None
        return {
            "input_tokens": usage.prompt_tokens,
            "output_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    @staticmethod
    def _is_rate_limit_error(e: Exception) -> bool:
        """Return True when ``e`` represents an HTTP 429 / rate-limit failure.

        Uses the SDK's ``status_code`` when available. The textual fallback only
        matches "error code: 429" (the SDK's message prefix), not a bare "429",
        which also appears inside token counts such as "134290 tokens".
        """
        if getattr(e, "status_code", None) == 429:
            return True
        msg = str(e).lower()
        return "rate limit" in msg or "too many requests" in msg or "error code: 429" in msg

    async def _complete(
        self,
        messages: list[Message],
        config: GenerationConfig,
        tools: list[dict] | None,
        stream: bool,
        failure_label: str,
    ) -> GenerationResponse | AsyncGenerator:
        """Shared implementation of ``generate`` and ``generate_with_tools``.

        Streaming returns the ``_openai_stream_chunks`` generator; otherwise a
        GenerationResponse is built from the first choice.

        Raises:
            RuntimeError: "Rate limit exceeded: ..." for rate limits, otherwise
                "<failure_label> failed: ...", chained to the original error.
        """
        try:
            kwargs = self._build_kwargs(messages, config, tools)
            kwargs["stream"] = stream
            if stream:
                # Request a final usage chunk so streamed calls report token counts.
                kwargs["stream_options"] = {"include_usage": True}
                raw_stream = await _retry_on_transient(
                    lambda: self.client.chat.completions.create(**kwargs)
                )
                return _openai_stream_chunks(raw_stream)

            completion = await _retry_on_transient(
                lambda: self.client.chat.completions.create(**kwargs)
            )
            choice = completion.choices[0]
            msg = choice.message
            reasoning = (
                getattr(msg, "reasoning_content", None) or getattr(msg, "reasoning", "") or ""
            )
            tool_calls = [
                ToolCall(id=tc.id, name=tc.function.name, arguments=tc.function.arguments)
                for tc in msg.tool_calls or []
            ]
            return GenerationResponse(
                content=msg.content or "",
                model=completion.model,
                finish_reason=choice.finish_reason,
                tool_calls=tool_calls or None,
                usage=self._usage_dict(completion.usage),
                reasoning_content=reasoning,
            )
        except Exception as e:
            if self._is_rate_limit_error(e):
                raise RuntimeError(f"Rate limit exceeded: {e!s}") from e
            raise RuntimeError(f"{failure_label} failed: {e!s}") from e

    async def generate(
        self, messages: list[Message], config: GenerationConfig, stream: bool = False
    ) -> GenerationResponse | AsyncGenerator:
        """Generate a reply without tools (see ``_complete`` for return/raise details)."""
        return await self._complete(messages, config, None, stream, "OpenAI generation")

    async def generate_with_tools(
        self,
        messages: list[Message],
        tools: list[dict],
        config: GenerationConfig,
        stream: bool = False,
    ) -> GenerationResponse | AsyncGenerator:
        """Generate a reply with ``tools`` offered to the model (see ``_complete``)."""
        return await self._complete(messages, config, tools, stream, "OpenAI tool generation")

    def get_available_models(self) -> list[str]:
        """Return a fixed list of common OpenAI chat model names."""
        return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
|
vtx/llm/tool_parser.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Tool call parser for handling text-embedded tool calls.
|
|
2
|
+
|
|
3
|
+
This module parses tool calls embedded in text content, supporting formats like:
|
|
4
|
+
- <function=name>...</function>
|
|
5
|
+
- <function name="name">...</function>
|
|
6
|
+
- With nested <parameter name="x">value</parameter> tags
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import re
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def extract_tool_calls_from_text(content: str) -> list[dict[str, Any]]:
    """Extract tool calls embedded in text content.

    Supports formats:
    - <function=name>...</function>
    - <function name="name">...</function>
    - With nested <parameter name="x">value</parameter> or <parameter=x>value</parameter> tags
    - Self-closing: <function=name/> or <function name="name"/>

    Extra attributes on the function tag (other than ``name``/``function``)
    become arguments. ``<parameter>`` values are kept verbatim (no whitespace
    stripping or type conversion) and override same-named attributes.
    Tags without a matching ``</function>`` are skipped.

    Args:
        content: Text content that may contain embedded tool calls

    Returns:
        List of dicts: [{"name": "tool_name", "arguments": {...}}]
    """
    tool_calls: list[dict[str, Any]] = []

    # Opening tag: <function name="x" ...> (group 1) or <function=x ...> (group 2);
    # group 3 is the rest of the tag up to, but not including, ">".
    function_start = re.compile(
        r"<function(?:\s+name=[\"\']?([^\"\'\s/>]+)[\"\']?|=[\"\']?([^\"\'\s/>]+)[\"\']?)([^>]*)"
    )
    # <parameter name="x">value</parameter> or <parameter=x>value</parameter>.
    parameter = re.compile(
        r"<parameter(?:\s+name=|=)[\"\']?([^\"\'\s/>]+)[\"\']?[^>]*>(.*?)</parameter>",
        re.DOTALL,
    )
    end_tag = "</function>"

    for match in function_start.finditer(content):
        tool_name = match.group(1) or match.group(2)
        if not tool_name:
            continue
        attrs = match.group(3)
        start_pos = match.start()
        arguments = _parse_function_attributes(attrs)

        # Self-closing only if this tag's own text ends in "/" (i.e. "/>").
        # Scanning ahead in the content would misclassify calls whose body
        # contains "/>" (e.g. "<br/>") or that are followed by another
        # self-closing call, dropping all of their parameters.
        if attrs.rstrip().endswith("/"):
            tool_calls.append({"name": tool_name, "arguments": arguments})
            continue

        end_pos = content.find(end_tag, start_pos)
        if end_pos == -1:
            continue
        tag_end = content.find(">", start_pos)
        if tag_end == -1:
            continue

        body = content[tag_end + 1 : end_pos]
        for param_match in parameter.finditer(body):
            arguments[param_match.group(1)] = param_match.group(2)

        tool_calls.append({"name": tool_name, "arguments": arguments})

    return tool_calls


def _parse_function_attributes(attrs: str) -> dict[str, Any]:
    """Parse ``key=value`` attributes from the tail of a function tag.

    Values may be double-quoted, single-quoted, or bare (bare values stop at
    whitespace, "/" or ">"). The ``name`` and ``function`` keys are ignored
    because they identify the tool rather than supply arguments.
    """
    arguments = {}
    if not attrs:
        return arguments

    attr_pattern = r'(\w+)=(?:"([^"]*)"|\'([^\']*)\'|([^\s/>]+))'
    for attr_match in re.finditer(attr_pattern, attrs):
        key = attr_match.group(1)
        value = attr_match.group(2) or attr_match.group(3) or attr_match.group(4) or ""
        if key not in ("name", "function"):
            arguments[key] = value

    return arguments
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def has_text_tool_calls(content: str) -> bool:
    """Report whether ``content`` contains a ``<function`` tag followed by whitespace or ``=``.

    Args:
        content: Text to inspect; empty/falsy input yields False.

    Returns:
        True if a ``<function=...>`` or ``<function name=...>`` style tag is present.
    """
    opening_tag = re.compile(r"<function[\s=]")
    return bool(content) and opening_tag.search(content) is not None
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def normalize_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Serialize each call's arguments to a JSON string for the tool pipeline.

    Missing names become "" and missing arguments become ``{}`` (serialized "{}").

    Args:
        tool_calls: Dicts with optional "name" and "arguments" keys

    Returns:
        New list of {"name": str, "arguments": <JSON string>} dicts, in input order
    """
    return [
        {"name": call.get("name", ""), "arguments": json.dumps(call.get("arguments", {}))}
        for call in tool_calls
    ]
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def extract_text_and_tool_calls(content: str) -> tuple[str, list[dict[str, Any]]]:
    """Extract both text content and tool calls from a mixed response.

    Args:
        content: Text that may contain embedded tool calls

    Returns:
        Tuple of (cleaned_text, tool_calls)
        - cleaned_text: ``content`` with complete function tags removed
          (self-closing, or opening tag through the next ``</function>``),
          blank-line runs collapsed to a single blank line, and surrounding
          whitespace stripped. Returned unchanged if no tool-call tag is found.
        - tool_calls: List of extracted tool calls
    """
    if not has_text_tool_calls(content):
        return content, []

    tool_calls = extract_tool_calls_from_text(content)

    # Require whitespace or "=" right after "<function" (the same test
    # has_text_tool_calls uses) so look-alike tags such as "<functions>" in
    # ordinary prose are not deleted along with everything up to "</function>".
    cleaned = re.sub(
        r"<function[\s=][^>]*(?:/>|>.*?</function>)", "", content, flags=re.DOTALL
    )
    cleaned = re.sub(r"\n\s*\n", "\n\n", cleaned).strip()  # Clean up extra newlines

    return cleaned, tool_calls
|