ripperdoc 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +1 -1
- ripperdoc/cli/cli.py +74 -9
- ripperdoc/cli/commands/__init__.py +4 -0
- ripperdoc/cli/commands/agents_cmd.py +30 -4
- ripperdoc/cli/commands/context_cmd.py +11 -1
- ripperdoc/cli/commands/cost_cmd.py +5 -0
- ripperdoc/cli/commands/doctor_cmd.py +208 -0
- ripperdoc/cli/commands/memory_cmd.py +202 -0
- ripperdoc/cli/commands/models_cmd.py +61 -6
- ripperdoc/cli/commands/resume_cmd.py +4 -2
- ripperdoc/cli/commands/status_cmd.py +1 -1
- ripperdoc/cli/commands/tasks_cmd.py +27 -0
- ripperdoc/cli/ui/rich_ui.py +258 -11
- ripperdoc/cli/ui/thinking_spinner.py +128 -0
- ripperdoc/core/agents.py +14 -4
- ripperdoc/core/config.py +56 -3
- ripperdoc/core/default_tools.py +16 -2
- ripperdoc/core/permissions.py +19 -0
- ripperdoc/core/providers/__init__.py +31 -0
- ripperdoc/core/providers/anthropic.py +136 -0
- ripperdoc/core/providers/base.py +187 -0
- ripperdoc/core/providers/gemini.py +172 -0
- ripperdoc/core/providers/openai.py +142 -0
- ripperdoc/core/query.py +510 -386
- ripperdoc/core/query_utils.py +578 -0
- ripperdoc/core/system_prompt.py +2 -1
- ripperdoc/core/tool.py +16 -1
- ripperdoc/sdk/client.py +12 -1
- ripperdoc/tools/background_shell.py +63 -21
- ripperdoc/tools/bash_tool.py +48 -13
- ripperdoc/tools/file_edit_tool.py +20 -0
- ripperdoc/tools/file_read_tool.py +23 -0
- ripperdoc/tools/file_write_tool.py +20 -0
- ripperdoc/tools/glob_tool.py +59 -15
- ripperdoc/tools/grep_tool.py +7 -0
- ripperdoc/tools/ls_tool.py +246 -73
- ripperdoc/tools/mcp_tools.py +32 -10
- ripperdoc/tools/multi_edit_tool.py +23 -0
- ripperdoc/tools/notebook_edit_tool.py +18 -3
- ripperdoc/tools/task_tool.py +7 -0
- ripperdoc/tools/todo_tool.py +157 -25
- ripperdoc/tools/tool_search_tool.py +17 -4
- ripperdoc/utils/file_watch.py +134 -0
- ripperdoc/utils/git_utils.py +274 -0
- ripperdoc/utils/json_utils.py +27 -0
- ripperdoc/utils/log.py +129 -29
- ripperdoc/utils/mcp.py +71 -6
- ripperdoc/utils/memory.py +12 -1
- ripperdoc/utils/message_compaction.py +22 -5
- ripperdoc/utils/messages.py +72 -17
- ripperdoc/utils/output_utils.py +34 -9
- ripperdoc/utils/permissions/path_validation_utils.py +6 -0
- ripperdoc/utils/prompt.py +17 -0
- ripperdoc/utils/safe_get_cwd.py +4 -0
- ripperdoc/utils/session_history.py +27 -9
- ripperdoc/utils/session_usage.py +7 -0
- ripperdoc/utils/shell_utils.py +159 -0
- ripperdoc/utils/todo.py +2 -2
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/METADATA +4 -2
- ripperdoc-0.2.3.dist-info/RECORD +95 -0
- ripperdoc-0.2.0.dist-info/RECORD +0 -81
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/WHEEL +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/top_level.txt +0 -0
ripperdoc/core/query.py
CHANGED
```diff
@@ -6,86 +6,273 @@ the query-response loop including tool execution.
 
 import asyncio
 import inspect
-… [old lines 9–19 not captured in the rendered diff]
+import os
+import time
+from asyncio import CancelledError
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
+
+from pydantic import ValidationError
+
+from ripperdoc.core.config import provider_protocol
+from ripperdoc.core.providers import ProviderClient, get_provider_client
+from ripperdoc.core.permissions import PermissionResult
+from ripperdoc.core.query_utils import (
+    build_full_system_prompt,
+    determine_tool_mode,
+    extract_tool_use_blocks,
+    format_pydantic_errors,
+    log_openai_messages,
+    resolve_model_profile,
+    text_mode_history,
+    tool_result_message,
 )
+from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
+from ripperdoc.utils.file_watch import ChangedFileNotice, FileSnapshot, detect_changed_files
 from ripperdoc.utils.log import get_logger
 from ripperdoc.utils.messages import (
-    MessageContent,
-    UserMessage,
     AssistantMessage,
+    MessageContent,
     ProgressMessage,
-… [old line 27 not captured]
+    UserMessage,
     create_assistant_message,
+    create_user_message,
     create_progress_message,
     normalize_messages_for_api,
     INTERRUPT_MESSAGE,
     INTERRUPT_MESSAGE_FOR_TOOL_USE,
 )
-from ripperdoc.core.permissions import PermissionResult
-from ripperdoc.core.config import get_global_config, ProviderType, provider_protocol
-from ripperdoc.utils.session_usage import record_usage
-
-import time
 
 
 logger = get_logger()
 
+DEFAULT_REQUEST_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_API_TIMEOUT", "120"))
+MAX_LLM_RETRIES = 1
+
 
-def … [old lines 44–45 truncated in the rendered diff]
+def _resolve_tool(
+    tool_registry: "ToolRegistry", tool_name: str, tool_use_id: str
+) -> tuple[Optional[Tool[Any, Any]], Optional[UserMessage]]:
+    """Find a tool by name and return an error message if missing."""
+    tool = tool_registry.get(tool_name)
+    if tool:
+        tool_registry.activate_tools([tool_name])
+        return tool, None
+    return None, tool_result_message(
+        tool_use_id, f"Error: Tool '{tool_name}' not found", is_error=True
+    )
+
+
+ToolPermissionCallable = Callable[
+    [Tool[Any, Any], Any],
+    Union[
+        PermissionResult,
+        Dict[str, Any],
+        Tuple[bool, Optional[str]],
+        bool,
+        Awaitable[Union[PermissionResult, Dict[str, Any], Tuple[bool, Optional[str]], bool]],
+    ],
+]
+
+
+async def _check_tool_permissions(
+    tool: Tool[Any, Any],
+    parsed_input: Any,
+    query_context: "QueryContext",
+    can_use_tool_fn: Optional[ToolPermissionCallable],
+) -> tuple[bool, Optional[str]]:
+    """Evaluate whether a tool call is allowed."""
     try:
-        if … [old lines 47–79 not captured in the rendered diff]
+        if can_use_tool_fn is not None:
+            decision = can_use_tool_fn(tool, parsed_input)
+            if inspect.isawaitable(decision):
+                decision = await decision
+            if isinstance(decision, PermissionResult):
+                return decision.result, decision.message
+            if isinstance(decision, dict) and "result" in decision:
+                return bool(decision.get("result")), decision.get("message")
+            if isinstance(decision, tuple) and len(decision) == 2:
+                return bool(decision[0]), decision[1]
+            return bool(decision), None
+
+        if query_context.safe_mode and tool.needs_permissions(parsed_input):
+            loop = asyncio.get_running_loop()
+            input_preview = (
+                parsed_input.model_dump()
+                if hasattr(parsed_input, "model_dump")
+                else str(parsed_input)
+            )
+            prompt = f"Allow tool '{tool.name}' with input {input_preview}? [y/N]: "
+            response = await loop.run_in_executor(None, lambda: input(prompt))
+            return response.strip().lower() in ("y", "yes"), None
+
+        return True, None
+    except Exception:
+        logger.exception(
+            f"Error checking permissions for tool '{tool.name}'",
+            extra={"tool": getattr(tool, "name", None)},
+        )
+        return False, None
+
+
+def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
+    """Render a system notice about files that changed on disk."""
+    lines: List[str] = [
+        "System notice: Files you previously read have changed on disk.",
+        "Please re-read the affected files before making further edits.",
+        "",
+    ]
+    for notice in notices:
+        lines.append(f"- {notice.file_path}")
+        summary = (notice.summary or "").rstrip()
+        if summary:
+            indented = "\n".join(f"  {line}" for line in summary.splitlines())
+            lines.append(indented)
+    return "\n".join(lines)
+
+
+async def _run_tool_use_generator(
+    tool: Tool[Any, Any],
+    tool_use_id: str,
+    tool_name: str,
+    parsed_input: Any,
+    sibling_ids: set[str],
+    tool_context: ToolUseContext,
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Execute a single tool_use and yield progress/results."""
+    try:
+        async for output in tool.call(parsed_input, tool_context):
+            if isinstance(output, ToolProgress):
+                yield create_progress_message(
+                    tool_use_id=tool_use_id,
+                    sibling_tool_use_ids=sibling_ids,
+                    content=output.content,
+                )
+                logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
+            elif isinstance(output, ToolResult):
+                result_content = output.result_for_assistant or str(output.data)
+                result_msg = tool_result_message(
+                    tool_use_id, result_content, tool_use_result=output.data
+                )
+                yield result_msg
+                logger.debug(
+                    f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
+                    f"result_len={len(result_content)}"
+                )
+    except Exception as exc:
+        logger.exception(
+            f"Error executing tool '{tool_name}'",
+            extra={"tool": tool_name, "tool_use_id": tool_use_id},
+        )
+        yield tool_result_message(tool_use_id, f"Error executing tool: {str(exc)}", is_error=True)
+
+
+def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Group consecutive tool calls by their concurrency safety."""
+    groups: List[Dict[str, Any]] = []
+    for call in prepared_calls:
+        is_safe = bool(call.get("is_concurrency_safe"))
+        if groups and groups[-1]["is_concurrency_safe"] == is_safe:
+            groups[-1]["items"].append(call)
+        else:
+            groups.append({"is_concurrency_safe": is_safe, "items": [call]})
+    return groups
 
-    cache_read_tokens = _get_usage_field(prompt_details, "cached_tokens") if prompt_details else 0
 
-… [old lines 83–88 not captured in the rendered diff]
+async def _execute_tools_sequentially(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators one by one."""
+    for item in items:
+        gen = item.get("generator")
+        if not gen:
+            continue
+        async for message in gen:
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+
+
+async def _execute_tools_in_parallel(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators concurrently."""
+    generators = [call["generator"] for call in items if call.get("generator")]
+    async for message in _run_concurrent_tool_uses(generators, tool_results):
+        yield message
+
+
+async def _run_tools_concurrently(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tools grouped by concurrency safety (parallel for safe groups)."""
+    for group in _group_tool_calls_by_concurrency(prepared_calls):
+        if group["is_concurrency_safe"]:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} concurrency-safe tool(s) in parallel"
+            )
+            async for message in _execute_tools_in_parallel(group["items"], tool_results):
+                yield message
+        else:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} tool(s) sequentially (not concurrency safe)"
+            )
+            async for message in _run_tools_serially(group["items"], tool_results):
+                yield message
+
+
+async def _run_tools_serially(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run all tools sequentially (helper for clarity)."""
+    async for message in _execute_tools_sequentially(prepared_calls, tool_results):
+        yield message
+
+
+async def _run_concurrent_tool_uses(
+    generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
+    tool_results: List[UserMessage],
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Drain multiple tool generators concurrently and stream outputs."""
+    if not generators:
+        return
+
+    queue: asyncio.Queue[Optional[Union[UserMessage, ProgressMessage]]] = asyncio.Queue()
+
+    async def _consume(gen: AsyncGenerator[Union[UserMessage, ProgressMessage], None]) -> None:
+        try:
+            async for message in gen:
+                await queue.put(message)
+        except Exception:
+            logger.exception("[query] Unexpected error while consuming tool generator")
+        finally:
+            await queue.put(None)
+
+    tasks = [asyncio.create_task(_consume(gen)) for gen in generators]
+    active = len(tasks)
+
+    try:
+        while active:
+            message = await queue.get()
+            if message is None:
+                active -= 1
+                continue
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+    finally:
+        await asyncio.gather(*tasks, return_exceptions=True)
 
 
 class ToolRegistry:
```
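The new `_group_tool_calls_by_concurrency` helper batches consecutive tool calls so that runs of concurrency-safe calls can execute in parallel while unsafe calls stay serialized. The logic is small enough to exercise standalone; a minimal sketch with illustrative call data (the `name` field is only for the demo):

```python
from typing import Any, Dict, List


def group_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Group consecutive calls that share the same concurrency-safety flag."""
    groups: List[Dict[str, Any]] = []
    for call in prepared_calls:
        is_safe = bool(call.get("is_concurrency_safe"))
        if groups and groups[-1]["is_concurrency_safe"] == is_safe:
            groups[-1]["items"].append(call)  # extend the current run
        else:
            groups.append({"is_concurrency_safe": is_safe, "items": [call]})
    return groups


# Illustrative input: two safe reads, one unsafe write, then another safe read.
calls = [
    {"name": "read_a", "is_concurrency_safe": True},
    {"name": "read_b", "is_concurrency_safe": True},
    {"name": "write_c", "is_concurrency_safe": False},
    {"name": "read_d", "is_concurrency_safe": True},
]

for group in group_by_concurrency(calls):
    mode = "parallel" if group["is_concurrency_safe"] else "sequential"
    print(mode, [c["name"] for c in group["items"]])
# parallel ['read_a', 'read_b']
# sequential ['write_c']
# parallel ['read_d']
```

Note that ordering is preserved across groups: a safe call that follows an unsafe one starts a new group rather than being merged backward, so side effects stay in request order.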
```diff
@@ -118,6 +305,10 @@ class ToolRegistry:
         try:
             deferred = tool.defer_loading()
         except Exception:
+            logger.exception(
+                "[tool_registry] Tool.defer_loading failed",
+                extra={"tool": getattr(tool, "name", None)},
+            )
             deferred = False
         if deferred:
             self._deferred.add(name)
```
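This hunk replaces a silently swallowed exception with a logged one. For reference, `logging.Logger.exception` records the active traceback automatically when called inside an `except` block, and the `extra` mapping becomes attributes on the `LogRecord`. A minimal illustration (the failing tool class is made up for the demo):

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(tool)s: %(message)s")
logger = logging.getLogger("tool_registry")


class BrokenTool:
    name = "broken"

    def defer_loading(self) -> bool:
        raise RuntimeError("schema not available")


tool = BrokenTool()
try:
    deferred = tool.defer_loading()
except Exception:
    # Logs the message, the extra "tool" field, and the full traceback.
    logger.exception(
        "[tool_registry] Tool.defer_loading failed",
        extra={"tool": getattr(tool, "name", None)},
    )
    deferred = False
```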
```diff
@@ -193,6 +384,7 @@ class QueryContext:
         self.model = model
         self.verbose = verbose
         self.abort_controller = asyncio.Event()
+        self.file_state_cache: Dict[str, FileSnapshot] = {}
 
     @property
     def tools(self) -> List[Tool[Any, Any]]:
```
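The new `file_state_cache` feeds `detect_changed_files` from the new `ripperdoc/utils/file_watch.py` (+134 lines, not shown in this diff). A plausible minimal sketch of such a snapshot cache, assuming mtime/size-based detection; the real `FileSnapshot` fields are not visible here:

```python
import os
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FileSnapshot:
    """State recorded when a file was last read (fields are assumptions)."""
    path: str
    mtime: float
    size: int


def detect_changed(cache: Dict[str, FileSnapshot]) -> List[str]:
    """Return paths whose on-disk state no longer matches the snapshot."""
    changed: List[str] = []
    for path, snap in cache.items():
        try:
            stat = os.stat(path)
        except FileNotFoundError:
            changed.append(path)  # a deleted file counts as changed
            continue
        if stat.st_mtime != snap.mtime or stat.st_size != snap.size:
            changed.append(path)
    return changed
```

In the query loop (final hunk below), any detected change is turned into a system-notice user message via `_format_changed_file_notice`, nudging the model to re-read files before editing them.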
```diff
@@ -220,6 +412,11 @@ async def query_llm(
     max_thinking_tokens: int = 0,
     model: str = "main",
     abort_signal: Optional[asyncio.Event] = None,
+    *,
+    progress_callback: Optional[Callable[[str], Awaitable[None]]] = None,
+    request_timeout: Optional[float] = None,
+    max_retries: int = MAX_LLM_RETRIES,
+    stream: bool = True,
 ) -> AssistantMessage:
     """Query the AI model and return the response.
 
```
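`query_llm` grows keyword-only streaming and resilience parameters, with `request_timeout` defaulting to `RIPPERDOC_API_TIMEOUT` (120 s) and `MAX_LLM_RETRIES = 1`. The retry loop itself lives in the new provider layer (`ripperdoc/core/providers/base.py`), which this diff does not show; below is a generic sketch of that pattern with `asyncio.wait_for`, under the assumption of simple retry-on-timeout semantics:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def call_with_retries(
    make_request: Callable[[], Awaitable[T]],
    request_timeout: float,
    max_retries: int,
) -> T:
    """Attempt a request up to max_retries + 1 times, bounding each attempt by a timeout."""
    last_exc: Exception = RuntimeError("no attempts made")
    for _attempt in range(max_retries + 1):
        try:
            return await asyncio.wait_for(make_request(), timeout=request_timeout)
        except (asyncio.TimeoutError, OSError) as exc:
            last_exc = exc  # transient failure: try again until attempts run out
    raise last_exc


async def main() -> None:
    async def fake_request() -> str:
        await asyncio.sleep(0.01)
        return "ok"

    print(await call_with_retries(fake_request, request_timeout=1.0, max_retries=1))


asyncio.run(main())
```

This matches the docstring added in the next hunk: total attempts = retries + 1.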
```diff
@@ -230,47 +427,48 @@ async def query_llm(
         max_thinking_tokens: Maximum tokens for thinking (0 = disabled)
         model: Model pointer to use
         abort_signal: Event to signal abortion
+        progress_callback: Optional async callback invoked with streamed text chunks
+        request_timeout: Max seconds to wait for a provider response before retrying
+        max_retries: Number of retries on timeout/errors (total attempts = retries + 1)
+        stream: Enable streaming for providers that support it (text-only mode)
 
     Returns:
         AssistantMessage with the model's response
     """
-… [old lines 237–239 not captured in the rendered diff]
-    profile_name = getattr(config.model_pointers, model, None)
-    if profile_name is None:
-        profile_name = model
-
-    model_profile = config.model_profiles.get(profile_name)
-    if model_profile is None:
-        fallback_profile = getattr(config.model_pointers, "main", "default")
-        model_profile = config.model_profiles.get(fallback_profile)
-
-    if not model_profile:
-        raise ValueError(f"No model profile found for pointer: {model}")
+    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
+    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
+    model_profile = resolve_model_profile(model)
 
     # Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
     protocol = provider_protocol(model_profile.provider)
-… [old lines 254–256 not captured in the rendered diff]
+    tool_mode = determine_tool_mode(model_profile)
+    messages_for_model: List[Union[UserMessage, AssistantMessage, ProgressMessage]]
+    if tool_mode == "text":
+        messages_for_model = cast(
+            List[Union[UserMessage, AssistantMessage, ProgressMessage]],
+            text_mode_history(messages),
+        )
+    else:
+        messages_for_model = messages
+
+    normalized_messages: List[Dict[str, Any]] = normalize_messages_for_api(
+        messages_for_model, protocol=protocol, tool_mode=tool_mode
+    )
+    logger.info(
+        "[query_llm] Preparing model request",
+        extra={
+            "model_pointer": model,
+            "provider": getattr(model_profile.provider, "value", str(model_profile.provider)),
+            "model": model_profile.model,
+            "normalized_messages": len(normalized_messages),
+            "tool_count": len(tools),
+            "max_thinking_tokens": max_thinking_tokens,
+            "tool_mode": tool_mode,
+        },
     )
 
     if protocol == "openai":
-… [old line 260 not captured]
-        for idx, m in enumerate(normalized_messages):
-            role = m.get("role")
-            tool_calls = m.get("tool_calls")
-            tc_ids = []
-            if tool_calls:
-                tc_ids = [tc.get("id") for tc in tool_calls]
-            tool_call_id = m.get("tool_call_id")
-            summary_parts.append(
-                f"{idx}:{role}"
-                + (f" tool_calls={tc_ids}" if tc_ids else "")
-                + (f" tool_call_id={tool_call_id}" if tool_call_id else "")
-            )
-        logger.debug(f"[query_llm] OpenAI normalized messages: {' | '.join(summary_parts)}")
+        log_openai_messages(normalized_messages)
 
     logger.debug(
         f"[query_llm] Sending {len(normalized_messages)} messages to model pointer "
```
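For profiles whose `tool_mode` is `"text"`, history is flattened through `text_mode_history` before `normalize_messages_for_api`. That helper lives in the new `ripperdoc/core/query_utils.py` and is not shown in this diff; the sketch below is only an assumption about what flattening structured tool blocks into plain text turns could look like:

```python
from typing import Any, Dict, List


def flatten_to_text(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Rewrite structured tool_use/tool_result blocks as plain text (assumed behavior)."""
    flat: List[Dict[str, str]] = []
    for msg in messages:
        content = msg.get("content")
        if isinstance(content, str):
            flat.append({"role": msg["role"], "content": content})
            continue
        parts: List[str] = []
        for block in content or []:
            kind = block.get("type")
            if kind == "text":
                parts.append(block["text"])
            elif kind == "tool_use":
                parts.append(f"[tool call: {block['name']}({block['input']})]")
            elif kind == "tool_result":
                parts.append(f"[tool result: {block.get('text', '')}]")
        flat.append({"role": msg["role"], "content": "\n".join(parts)})
    return flat
```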
```diff
@@ -282,136 +480,49 @@ async def query_llm(
     start_time = time.time()
 
     try:
-        if … [old lines 285–308 not captured in the rendered diff]
-                system=system_prompt,
-                messages=normalized_messages,  # type: ignore[arg-type]
-                tools=tool_schemas if tool_schemas else None,  # type: ignore
-                temperature=model_profile.temperature,
-            )
-
-            duration_ms = (time.time() - start_time) * 1000
-
-            usage_tokens = _anthropic_usage_tokens(getattr(response, "usage", None))
-            record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-
-            # Calculate cost (simplified, should use actual pricing)
-            cost_usd = 0.0  # TODO: Implement cost calculation
-
-            # Convert response to our format
-            content_blocks = []
-            for block in response.content:
-                if block.type == "text":
-                    content_blocks.append({"type": "text", "text": block.text})
-                elif block.type == "tool_use":
-                    content_blocks.append(
-                        {
-                            "type": "tool_use",
-                            "tool_use_id": block.id,
-                            "name": block.name,
-                            "input": block.input,  # type: ignore[dict-item]
-                        }
-                    )
-
-            return create_assistant_message(
-                content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
-            )
-
-        elif model_profile.provider == ProviderType.OPENAI_COMPATIBLE:
-            # OpenAI-compatible APIs (OpenAI, DeepSeek, Mistral, etc.)
-            async with AsyncOpenAI(
-                api_key=model_profile.api_key, base_url=model_profile.api_base
-            ) as client:
-                # Build tool schemas for OpenAI format
-                openai_tools = []
-                for tool in tools:
-                    description = await build_tool_description(
-                        tool, include_examples=True, max_examples=2
-                    )
-                    openai_tools.append(
-                        {
-                            "type": "function",
-                            "function": {
-                                "name": tool.name,
-                                "description": description,
-                                "parameters": tool.input_schema.model_json_schema(),
-                            },
-                        }
-                    )
-
-                # Prepare messages for OpenAI format
-                openai_messages = [
-                    {"role": "system", "content": system_prompt}
-                ] + normalized_messages
-
-                # Make the API call
-                openai_response: Any = await client.chat.completions.create(
-                    model=model_profile.model,
-                    messages=openai_messages,
-                    tools=openai_tools if openai_tools else None,  # type: ignore[arg-type]
-                    temperature=model_profile.temperature,
-                    max_tokens=model_profile.max_tokens,
-                )
-
-                duration_ms = (time.time() - start_time) * 1000
-                usage_tokens = _openai_usage_tokens(getattr(openai_response, "usage", None))
-                record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-                cost_usd = 0.0  # TODO: Implement cost calculation
-
-                # Convert OpenAI response to our format
-                content_blocks = []
-                choice = openai_response.choices[0]
-
-                if choice.message.content:
-                    content_blocks.append({"type": "text", "text": choice.message.content})
-
-                if choice.message.tool_calls:
-                    for tool_call in choice.message.tool_calls:
-                        import json
-
-                        content_blocks.append(
-                            {
-                                "type": "tool_use",
-                                "tool_use_id": tool_call.id,
-                                "name": tool_call.function.name,
-                                "input": json.loads(tool_call.function.arguments),
-                            }
-                        )
-
-                return create_assistant_message(
-                    content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
-                )
+        client: Optional[ProviderClient] = get_provider_client(model_profile.provider)
+        if client is None:
+            duration_ms = (time.time() - start_time) * 1000
+            error_msg = create_assistant_message(
+                content=(
+                    "Gemini protocol is not supported yet in Ripperdoc. "
+                    "Please configure an Anthropic or OpenAI-compatible model."
+                ),
+                duration_ms=duration_ms,
+            )
+            error_msg.is_api_error_message = True
+            return error_msg
+
+        provider_response = await client.call(
+            model_profile=model_profile,
+            system_prompt=system_prompt,
+            normalized_messages=normalized_messages,
+            tools=tools,
+            tool_mode=tool_mode,
+            stream=stream,
+            progress_callback=progress_callback,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+        )
 
-… [old lines 407–410 not captured in the rendered diff]
+        return create_assistant_message(
+            content=provider_response.content_blocks,
+            cost_usd=provider_response.cost_usd,
+            duration_ms=provider_response.duration_ms,
+        )
 
     except Exception as e:
         # Return error message
-        logger.… [old line 414 truncated]
+        logger.exception(
+            "Error querying AI model",
+            extra={
+                "model": getattr(model_profile, "model", None),
+                "model_pointer": model,
+                "provider": (
+                    getattr(model_profile.provider, "value", None) if model_profile else None
+                ),
+            },
+        )
         duration_ms = (time.time() - start_time) * 1000
         error_msg = create_assistant_message(
             content=f"Error querying AI model: {str(e)}", duration_ms=duration_ms
```
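This is the core of the refactor: the inline Anthropic and OpenAI-compatible branches move into per-provider clients under the new `ripperdoc/core/providers/` package, and `query_llm` dispatches through `get_provider_client`, treating `None` as "protocol not supported" (hence the Gemini error message, even though a `gemini.py` client now exists in the package). The real registry is in `ripperdoc/core/providers/__init__.py`, not shown here; a minimal sketch of this dispatch shape with hypothetical client classes:

```python
from typing import Callable, Dict, Optional, Protocol


class ProviderClient(Protocol):
    async def call(self, **kwargs: object) -> object: ...


class AnthropicClient:
    async def call(self, **kwargs: object) -> object:
        return {"provider": "anthropic", **kwargs}


class OpenAIClient:
    async def call(self, **kwargs: object) -> object:
        return {"provider": "openai", **kwargs}


# Factories keyed by provider identifier; unsupported protocols are simply absent.
_REGISTRY: Dict[str, Callable[[], ProviderClient]] = {
    "anthropic": AnthropicClient,
    "openai_compatible": OpenAIClient,
}


def get_provider_client(provider: str) -> Optional[ProviderClient]:
    factory = _REGISTRY.get(provider)
    return factory() if factory else None  # None -> caller emits an error message


print(get_provider_client("gemini"))  # None, mirroring the hunk's fallback path
```

The payoff is visible in the hunk's line counts: usage recording, cost accounting, streaming, retries, and response conversion all collapse out of `query_llm` into `providers/base.py` (+187) and the concrete clients.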
```diff
@@ -425,7 +536,7 @@ async def query(
     system_prompt: str,
     context: Dict[str, str],
     query_context: QueryContext,
-    can_use_tool_fn: Optional[… [old line 428 truncated]
+    can_use_tool_fn: Optional[ToolPermissionCallable] = None,
 ) -> AsyncGenerator[Union[UserMessage, AssistantMessage, ProgressMessage], None]:
     """Execute a query with tool support.
 
```
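`can_use_tool_fn` is now typed as `ToolPermissionCallable`, and `_check_tool_permissions` (first hunk) normalizes four return shapes: a `PermissionResult`, a `{"result": ..., "message": ...}` dict, a `(bool, message)` tuple, or a bare bool, any of which may also be awaitable. A standalone demonstration of that normalization logic, with a stand-in `PermissionResult` dataclass:

```python
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class PermissionResult:
    result: bool
    message: Optional[str] = None


async def normalize(decision: Any) -> Tuple[bool, Optional[str]]:
    """Reduce any supported callback return shape to (allowed, message)."""
    if inspect.isawaitable(decision):
        decision = await decision  # async callbacks are awaited first
    if isinstance(decision, PermissionResult):
        return decision.result, decision.message
    if isinstance(decision, dict) and "result" in decision:
        return bool(decision.get("result")), decision.get("message")
    if isinstance(decision, tuple) and len(decision) == 2:
        return bool(decision[0]), decision[1]
    return bool(decision), None


async def main() -> None:
    async def async_yes() -> bool:
        return True

    for decision in (
        PermissionResult(False, "blocked by policy"),
        {"result": True, "message": None},
        (False, "denied"),
        async_yes(),
    ):
        print(await normalize(decision))


asyncio.run(main())
```

Accepting several shapes keeps the SDK surface loose: UI layers can return a rich `PermissionResult`, while quick scripts can return a plain bool.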
```diff
@@ -445,59 +556,105 @@ async def query(
     Yields:
         Messages (user, assistant, progress) as they are generated
     """
+    logger.info(
+        "[query] Starting query loop",
+        extra={
+            "message_count": len(messages),
+            "tool_count": len(query_context.tools),
+            "safe_mode": query_context.safe_mode,
+            "model_pointer": query_context.model,
+        },
+    )
     # Work on a copy so external mutations (e.g., UI appending messages while consuming)
     # do not interfere with recursion or normalization.
     messages = list(messages)
+    change_notices = detect_changed_files(query_context.file_state_cache)
+    if change_notices:
+        messages.append(create_user_message(_format_changed_file_notice(change_notices)))
+    model_profile = resolve_model_profile(query_context.model)
+    tool_mode = determine_tool_mode(model_profile)
+    tools_for_model: List[Tool[Any, Any]] = [] if tool_mode == "text" else query_context.all_tools()
+
+    full_system_prompt = build_full_system_prompt(
+        system_prompt, context, tool_mode, query_context.all_tools()
+    )
+    logger.debug(
+        "[query] Built system prompt",
+        extra={
+            "prompt_chars": len(full_system_prompt),
+            "context_entries": len(context),
+            "tool_count": len(tools_for_model),
+        },
+    )
+
+    progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue()
 
-    async def … [old lines 452–454 truncated in the rendered diff]
-        """Check permissions for tool execution."""
+    async def _stream_progress(chunk: str) -> None:
+        if not chunk:
+            return
         try:
-… [old lines 457–461 not captured in the rendered diff]
-                return decision.result, decision.message
-            if isinstance(decision, dict) and "result" in decision:
-                return bool(decision.get("result")), decision.get("message")
-            if isinstance(decision, tuple) and len(decision) == 2:
-                return bool(decision[0]), decision[1]
-            return bool(decision), None
-
-        if query_context.safe_mode and tool.needs_permissions(parsed_input):
-            loop = asyncio.get_running_loop()
-            input_preview = (
-                parsed_input.model_dump()
-                if hasattr(parsed_input, "model_dump")
-                else str(parsed_input)
+            await progress_queue.put(
+                create_progress_message(
+                    tool_use_id="stream",
+                    sibling_tool_use_ids=set(),
+                    content=chunk,
                 )
-… [old lines 476–492 not captured in the rendered diff]
-        messages,
-        full_system_prompt,
-        query_context.all_tools(),
-        query_context.max_thinking_tokens,
-        query_context.model,
-        query_context.abort_controller,
+            )
+        except Exception:
+            logger.exception("[query] Failed to enqueue stream progress chunk")
+
+    assistant_task = asyncio.create_task(
+        query_llm(
+            messages,
+            full_system_prompt,
+            tools_for_model,
+            query_context.max_thinking_tokens,
+            query_context.model,
+            query_context.abort_controller,
+            progress_callback=_stream_progress,
+            request_timeout=DEFAULT_REQUEST_TIMEOUT_SEC,
+            max_retries=MAX_LLM_RETRIES,
+            stream=True,
+        )
     )
 
+    assistant_message: Optional[AssistantMessage] = None
+
+    while True:
+        if query_context.abort_controller.is_set():
+            assistant_task.cancel()
+            try:
+                await assistant_task
+            except CancelledError:
+                pass
+            yield create_assistant_message(INTERRUPT_MESSAGE)
+            return
+        if assistant_task.done():
+            assistant_message = await assistant_task
+            break
+        try:
+            progress = progress_queue.get_nowait()
+        except asyncio.QueueEmpty:
+            waiter = asyncio.create_task(progress_queue.get())
+            done, pending = await asyncio.wait(
+                {assistant_task, waiter}, return_when=asyncio.FIRST_COMPLETED
+            )
+            if assistant_task in done:
+                for task in pending:
+                    task.cancel()
+                assistant_message = await assistant_task
+                break
+            progress = waiter.result()
+        if progress:
+            yield progress
+
+    while not progress_queue.empty():
+        residual = progress_queue.get_nowait()
+        if residual:
+            yield residual
+
+    assert assistant_message is not None
+
     # Check for abort
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE)
```
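The new streaming loop races the `query_llm` task against a progress queue with `asyncio.wait(..., return_when=FIRST_COMPLETED)`, so streamed text chunks surface as progress messages while the final reply is still pending, and abort requests are honored between chunks. The same multiplexing pattern in a self-contained form, where a dummy producer stands in for the model call:

```python
import asyncio
from typing import Optional


async def main() -> None:
    progress_queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

    async def model_call() -> str:
        for chunk in ("Hel", "lo ", "world"):
            await progress_queue.put(chunk)   # streamed chunk -> progress message
            await asyncio.sleep(0.01)
        return "Hello world"                  # final assistant message

    task = asyncio.create_task(model_call())
    result: Optional[str] = None

    while True:
        if task.done():
            result = await task
            break
        waiter = asyncio.create_task(progress_queue.get())
        done, pending = await asyncio.wait({task, waiter}, return_when=asyncio.FIRST_COMPLETED)
        if task in done:
            for t in pending:
                t.cancel()            # stop waiting on the queue
            result = await task
            break
        print("progress:", waiter.result())

    # Drain chunks that arrived between the last wait and completion.
    while not progress_queue.empty():
        print("progress:", progress_queue.get_nowait())
    print("final:", result)


asyncio.run(main())
```

The residual drain matters: the producer may enqueue a last chunk in the same scheduling step in which it finishes, so the queue can be non-empty after the task completes.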
```diff
@@ -505,175 +662,142 @@ async def query(
 
     yield assistant_message
 
-… [old lines 508–513 not captured in the rendered diff]
-    )
+    tool_use_blocks: List[MessageContent] = extract_tool_use_blocks(assistant_message)
+    text_blocks = (
+        len(assistant_message.message.content)
+        if isinstance(assistant_message.message.content, list)
+        else 1
+    )
     logger.debug(
-        f"[query] Assistant message received: "
-        f"… [old line 517 truncated]
-        f"tool_use_blocks={tool_block_count}"
+        f"[query] Assistant message received: text_blocks={text_blocks}, "
+        f"tool_use_blocks={len(tool_use_blocks)}"
     )
 
-    # Check for tool use
-    tool_use_blocks = []
-    if isinstance(assistant_message.message.content, list):
-        for block in assistant_message.message.content:
-            normalized_block = MessageContent(**block) if isinstance(block, dict) else block
-            if hasattr(normalized_block, "type") and normalized_block.type == "tool_use":
-                tool_use_blocks.append(normalized_block)
-
-    # If no tool use, we're done
     if not tool_use_blocks:
         logger.debug("[query] No tool_use blocks; returning response to user.")
         return
 
-    # Execute tools
-    tool_results: List[UserMessage] = []
-
     logger.debug(f"[query] Executing {len(tool_use_blocks)} tool_use block(s).")
+    tool_results: List[UserMessage] = []
+    permission_denied = False
+    sibling_ids = set(
+        getattr(t, "tool_use_id", None) or getattr(t, "id", None) or "" for t in tool_use_blocks
+    )
+    prepared_calls: List[Dict[str, Any]] = []
 
     for tool_use in tool_use_blocks:
         tool_name = tool_use.name
         if not tool_name:
             continue
-
+        tool_use_id = getattr(tool_use, "tool_use_id", None) or getattr(tool_use, "id", None) or ""
         tool_input = getattr(tool_use, "input", {}) or {}
 
-… [old lines 546–551 not captured in the rendered diff]
-        if not tool:
-            # Tool not found
-            logger.warning(f"[query] Tool '{tool_name}' not found for tool_use_id={tool_id}")
-            result_msg = create_user_message(
-                [
-                    {
-                        "type": "tool_result",
-                        "tool_use_id": tool_id,
-                        "text": f"Error: Tool '{tool_name}' not found",
-                        "is_error": True,
-                    }
-                ]
-            )
-            tool_results.append(result_msg)
-            yield result_msg
+        tool, missing_msg = _resolve_tool(query_context.tool_registry, tool_name, tool_use_id)
+        if missing_msg:
+            logger.warning(f"[query] Tool '{tool_name}' not found for tool_use_id={tool_use_id}")
+            tool_results.append(missing_msg)
+            yield missing_msg
             continue
-
-        # Execute the tool
-        tool_context = ToolUseContext(
-            safe_mode=query_context.safe_mode,
-            verbose=query_context.verbose,
-            permission_checker=can_use_tool_fn,
-            tool_registry=query_context.tool_registry,
-        )
+        assert tool is not None
 
         try:
-            # Parse input using tool's schema
            parsed_input = tool.input_schema(**tool_input)
            logger.debug(
-                f"[query] tool_use_id={… [old line 581 truncated]
+                f"[query] tool_use_id={tool_use_id} name={tool_name} parsed_input="
                 f"{str(parsed_input)[:500]}"
             )
 
-
+            tool_context = ToolUseContext(
+                safe_mode=query_context.safe_mode,
+                verbose=query_context.verbose,
+                permission_checker=can_use_tool_fn,
+                tool_registry=query_context.tool_registry,
+                file_state_cache=query_context.file_state_cache,
+                abort_signal=query_context.abort_controller,
+            )
+
             validation = await tool.validate_input(parsed_input, tool_context)
             if not validation.result:
                 logger.debug(
-                    f"[query] Validation failed for tool_use_id={… [old line 589 truncated]
+                    f"[query] Validation failed for tool_use_id={tool_use_id}: {validation.message}"
                 )
-                result_msg = … [old lines 591–594 truncated in the rendered diff]
-                        "tool_use_id": tool_id,
-                        "text": validation.message or "Tool input validation failed.",
-                        "is_error": True,
-                    }
-                ]
+                result_msg = tool_result_message(
+                    tool_use_id,
+                    validation.message or "Tool input validation failed.",
+                    is_error=True,
                 )
                 tool_results.append(result_msg)
                 yield result_msg
                 continue
 
-            # Permission check (safe mode or custom checker)
             if query_context.safe_mode or can_use_tool_fn is not None:
-                allowed, denial_message = await … [old line 607 truncated]
+                allowed, denial_message = await _check_tool_permissions(
+                    tool, parsed_input, query_context, can_use_tool_fn
+                )
                 if not allowed:
                     logger.debug(
-                        f"[query] Permission denied for tool_use_id={… [old line 610 truncated]
-                    )
-                    denial_text = denial_message or f"Permission denied for tool '{tool_name}'."
-                    result_msg = create_user_message(
-                        [
-                            {
-                                "type": "tool_result",
-                                "tool_use_id": tool_id,
-                                "text": denial_text,
-                                "is_error": True,
-                            }
-                        ]
-                    )
-                    tool_results.append(result_msg)
-                    yield result_msg
-                    continue
-
-            # Execute tool
-            async for output in tool.call(parsed_input, tool_context):
-                if isinstance(output, ToolProgress):
-                    # Yield progress
-                    progress = create_progress_message(
-                        tool_use_id=tool_id,
-                        sibling_tool_use_ids=set(
-                            getattr(t, "tool_use_id", None) or getattr(t, "id", None) or ""
-                            for t in tool_use_blocks
-                        ),
-                        content=output.content,
-                    )
-                    yield progress
-                    logger.debug(f"[query] Progress from tool_use_id={tool_id}: {output.content}")
-                elif isinstance(output, ToolResult):
-                    # Tool completed
-                    result_content = output.result_for_assistant or str(output.data)
-                    result_msg = create_user_message(
-                        [{"type": "tool_result", "tool_use_id": tool_id, "text": result_content}],
-                        tool_use_result=output.data,
-                    )
-                    tool_results.append(result_msg)
-                    yield result_msg
-                    logger.debug(
-                        f"[query] Tool completed tool_use_id={tool_id} name={tool_name} "
-                        f"result_len={len(result_content)}"
+                        f"[query] Permission denied for tool_use_id={tool_use_id}: {denial_message}"
                     )
+                    denial_text = denial_message or f"User aborted the tool invocation: {tool_name}"
+                    denial_msg = tool_result_message(tool_use_id, denial_text, is_error=True)
+                    tool_results.append(denial_msg)
+                    yield denial_msg
+                    permission_denied = True
+                    break
+
+            prepared_calls.append(
+                {
+                    "is_concurrency_safe": tool.is_concurrency_safe(),
+                    "generator": _run_tool_use_generator(
+                        tool,
+                        tool_use_id,
+                        tool_name,
+                        parsed_input,
+                        sibling_ids,
+                        tool_context,
+                    ),
+                }
+            )
 
+        except ValidationError as ve:
+            detail_text = format_pydantic_errors(ve)
+            error_msg = tool_result_message(
+                tool_use_id,
+                f"Invalid input for tool '{tool_name}': {detail_text}",
+                is_error=True,
+            )
+            tool_results.append(error_msg)
+            yield error_msg
+            continue
         except Exception as e:
-… [old lines 656–661 not captured in the rendered diff]
-                        "tool_use_id": tool_id,
-                        "text": f"Error executing tool: {str(e)}",
-                        "is_error": True,
-                    }
-                ]
+            logger.exception(
+                f"Error executing tool '{tool_name}'",
+                extra={"tool": tool_name, "tool_use_id": tool_use_id},
+            )
+            error_msg = tool_result_message(
+                tool_use_id, f"Error executing tool: {str(e)}", is_error=True
             )
             tool_results.append(error_msg)
             yield error_msg
 
+        if permission_denied:
+            break
+
+    if permission_denied:
+        return
+
+    if prepared_calls:
+        async for message in _run_tools_concurrently(prepared_calls, tool_results):
+            yield message
+
     # Check for abort after tools
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
         return
 
-
+    if permission_denied:
+        return
+
     new_messages = messages + [assistant_message] + tool_results
     logger.debug(
         f"[query] Recursing with {len(new_messages)} messages after tools; "
```