ripperdoc 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +3 -0
- ripperdoc/__main__.py +20 -0
- ripperdoc/cli/__init__.py +1 -0
- ripperdoc/cli/cli.py +405 -0
- ripperdoc/cli/commands/__init__.py +82 -0
- ripperdoc/cli/commands/agents_cmd.py +263 -0
- ripperdoc/cli/commands/base.py +19 -0
- ripperdoc/cli/commands/clear_cmd.py +18 -0
- ripperdoc/cli/commands/compact_cmd.py +23 -0
- ripperdoc/cli/commands/config_cmd.py +31 -0
- ripperdoc/cli/commands/context_cmd.py +144 -0
- ripperdoc/cli/commands/cost_cmd.py +82 -0
- ripperdoc/cli/commands/doctor_cmd.py +221 -0
- ripperdoc/cli/commands/exit_cmd.py +19 -0
- ripperdoc/cli/commands/help_cmd.py +20 -0
- ripperdoc/cli/commands/mcp_cmd.py +70 -0
- ripperdoc/cli/commands/memory_cmd.py +202 -0
- ripperdoc/cli/commands/models_cmd.py +413 -0
- ripperdoc/cli/commands/permissions_cmd.py +302 -0
- ripperdoc/cli/commands/resume_cmd.py +98 -0
- ripperdoc/cli/commands/status_cmd.py +167 -0
- ripperdoc/cli/commands/tasks_cmd.py +278 -0
- ripperdoc/cli/commands/todos_cmd.py +69 -0
- ripperdoc/cli/commands/tools_cmd.py +19 -0
- ripperdoc/cli/ui/__init__.py +1 -0
- ripperdoc/cli/ui/context_display.py +298 -0
- ripperdoc/cli/ui/helpers.py +22 -0
- ripperdoc/cli/ui/rich_ui.py +1557 -0
- ripperdoc/cli/ui/spinner.py +49 -0
- ripperdoc/cli/ui/thinking_spinner.py +128 -0
- ripperdoc/cli/ui/tool_renderers.py +298 -0
- ripperdoc/core/__init__.py +1 -0
- ripperdoc/core/agents.py +486 -0
- ripperdoc/core/commands.py +33 -0
- ripperdoc/core/config.py +559 -0
- ripperdoc/core/default_tools.py +88 -0
- ripperdoc/core/permissions.py +252 -0
- ripperdoc/core/providers/__init__.py +47 -0
- ripperdoc/core/providers/anthropic.py +250 -0
- ripperdoc/core/providers/base.py +265 -0
- ripperdoc/core/providers/gemini.py +615 -0
- ripperdoc/core/providers/openai.py +487 -0
- ripperdoc/core/query.py +1058 -0
- ripperdoc/core/query_utils.py +622 -0
- ripperdoc/core/skills.py +295 -0
- ripperdoc/core/system_prompt.py +431 -0
- ripperdoc/core/tool.py +240 -0
- ripperdoc/sdk/__init__.py +9 -0
- ripperdoc/sdk/client.py +333 -0
- ripperdoc/tools/__init__.py +1 -0
- ripperdoc/tools/ask_user_question_tool.py +431 -0
- ripperdoc/tools/background_shell.py +389 -0
- ripperdoc/tools/bash_output_tool.py +98 -0
- ripperdoc/tools/bash_tool.py +1016 -0
- ripperdoc/tools/dynamic_mcp_tool.py +428 -0
- ripperdoc/tools/enter_plan_mode_tool.py +226 -0
- ripperdoc/tools/exit_plan_mode_tool.py +153 -0
- ripperdoc/tools/file_edit_tool.py +346 -0
- ripperdoc/tools/file_read_tool.py +203 -0
- ripperdoc/tools/file_write_tool.py +205 -0
- ripperdoc/tools/glob_tool.py +179 -0
- ripperdoc/tools/grep_tool.py +370 -0
- ripperdoc/tools/kill_bash_tool.py +136 -0
- ripperdoc/tools/ls_tool.py +471 -0
- ripperdoc/tools/mcp_tools.py +591 -0
- ripperdoc/tools/multi_edit_tool.py +456 -0
- ripperdoc/tools/notebook_edit_tool.py +386 -0
- ripperdoc/tools/skill_tool.py +205 -0
- ripperdoc/tools/task_tool.py +379 -0
- ripperdoc/tools/todo_tool.py +494 -0
- ripperdoc/tools/tool_search_tool.py +380 -0
- ripperdoc/utils/__init__.py +1 -0
- ripperdoc/utils/bash_constants.py +51 -0
- ripperdoc/utils/bash_output_utils.py +43 -0
- ripperdoc/utils/coerce.py +34 -0
- ripperdoc/utils/context_length_errors.py +252 -0
- ripperdoc/utils/exit_code_handlers.py +241 -0
- ripperdoc/utils/file_watch.py +135 -0
- ripperdoc/utils/git_utils.py +274 -0
- ripperdoc/utils/json_utils.py +27 -0
- ripperdoc/utils/log.py +176 -0
- ripperdoc/utils/mcp.py +560 -0
- ripperdoc/utils/memory.py +253 -0
- ripperdoc/utils/message_compaction.py +676 -0
- ripperdoc/utils/messages.py +519 -0
- ripperdoc/utils/output_utils.py +258 -0
- ripperdoc/utils/path_ignore.py +677 -0
- ripperdoc/utils/path_utils.py +46 -0
- ripperdoc/utils/permissions/__init__.py +27 -0
- ripperdoc/utils/permissions/path_validation_utils.py +174 -0
- ripperdoc/utils/permissions/shell_command_validation.py +552 -0
- ripperdoc/utils/permissions/tool_permission_utils.py +279 -0
- ripperdoc/utils/prompt.py +17 -0
- ripperdoc/utils/safe_get_cwd.py +31 -0
- ripperdoc/utils/sandbox_utils.py +38 -0
- ripperdoc/utils/session_history.py +260 -0
- ripperdoc/utils/session_usage.py +117 -0
- ripperdoc/utils/shell_token_utils.py +95 -0
- ripperdoc/utils/shell_utils.py +159 -0
- ripperdoc/utils/todo.py +203 -0
- ripperdoc/utils/token_estimation.py +34 -0
- ripperdoc-0.2.6.dist-info/METADATA +193 -0
- ripperdoc-0.2.6.dist-info/RECORD +107 -0
- ripperdoc-0.2.6.dist-info/WHEEL +5 -0
- ripperdoc-0.2.6.dist-info/entry_points.txt +3 -0
- ripperdoc-0.2.6.dist-info/licenses/LICENSE +53 -0
- ripperdoc-0.2.6.dist-info/top_level.txt +1 -0
ripperdoc/core/query.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
"""AI query system for Ripperdoc.
|
|
2
|
+
|
|
3
|
+
This module handles communication with AI models and manages
|
|
4
|
+
the query-response loop including tool execution.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import inspect
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from asyncio import CancelledError
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import (
|
|
14
|
+
Any,
|
|
15
|
+
AsyncGenerator,
|
|
16
|
+
Awaitable,
|
|
17
|
+
Callable,
|
|
18
|
+
Dict,
|
|
19
|
+
Iterable,
|
|
20
|
+
List,
|
|
21
|
+
Optional,
|
|
22
|
+
Tuple,
|
|
23
|
+
Union,
|
|
24
|
+
cast,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from pydantic import ValidationError
|
|
28
|
+
|
|
29
|
+
from ripperdoc.core.config import provider_protocol
|
|
30
|
+
from ripperdoc.core.providers import ProviderClient, get_provider_client
|
|
31
|
+
from ripperdoc.core.permissions import PermissionResult
|
|
32
|
+
from ripperdoc.core.query_utils import (
|
|
33
|
+
build_full_system_prompt,
|
|
34
|
+
determine_tool_mode,
|
|
35
|
+
extract_tool_use_blocks,
|
|
36
|
+
format_pydantic_errors,
|
|
37
|
+
log_openai_messages,
|
|
38
|
+
resolve_model_profile,
|
|
39
|
+
text_mode_history,
|
|
40
|
+
tool_result_message,
|
|
41
|
+
)
|
|
42
|
+
from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
|
|
43
|
+
from ripperdoc.utils.coerce import parse_optional_int
|
|
44
|
+
from ripperdoc.utils.context_length_errors import detect_context_length_error
|
|
45
|
+
from ripperdoc.utils.file_watch import ChangedFileNotice, FileSnapshot, detect_changed_files
|
|
46
|
+
from ripperdoc.utils.log import get_logger
|
|
47
|
+
from ripperdoc.utils.messages import (
|
|
48
|
+
AssistantMessage,
|
|
49
|
+
MessageContent,
|
|
50
|
+
ProgressMessage,
|
|
51
|
+
UserMessage,
|
|
52
|
+
create_assistant_message,
|
|
53
|
+
create_user_message,
|
|
54
|
+
create_progress_message,
|
|
55
|
+
normalize_messages_for_api,
|
|
56
|
+
INTERRUPT_MESSAGE,
|
|
57
|
+
INTERRUPT_MESSAGE_FOR_TOOL_USE,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Module-wide logger shared by all helpers in this file.
logger = get_logger()

# Seconds to wait for a provider response before retrying
# (override with the RIPPERDOC_API_TIMEOUT environment variable).
DEFAULT_REQUEST_TIMEOUT_SEC: float = float(os.environ.get("RIPPERDOC_API_TIMEOUT", "120"))
# Default retry budget passed to provider clients by query_llm
# (override with the RIPPERDOC_MAX_RETRIES environment variable).
MAX_LLM_RETRIES: int = int(os.environ.get("RIPPERDOC_MAX_RETRIES", "10"))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _resolve_tool(
    tool_registry: "ToolRegistry", tool_name: str, tool_use_id: str
) -> tuple[Optional[Tool[Any, Any]], Optional[UserMessage]]:
    """Look up ``tool_name`` in ``tool_registry``.

    When the tool exists it is also activated, so a deferred tool joins the
    active set, and ``(tool, None)`` is returned. Otherwise ``(None, msg)`` is
    returned, where ``msg`` is an error tool_result addressed to
    ``tool_use_id``.
    """
    found = tool_registry.get(tool_name)
    if not found:
        missing_msg = tool_result_message(
            tool_use_id, f"Error: Tool '{tool_name}' not found", is_error=True
        )
        return None, missing_msg
    tool_registry.activate_tools([tool_name])
    return found, None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Shapes a caller-supplied permission checker may return: a PermissionResult,
# a dict carrying a "result" key (and optionally "message"), an
# (allowed, message) tuple, or a plain bool.
_PermissionDecision = Union[
    PermissionResult,
    Dict[str, Any],
    Tuple[bool, Optional[str]],
    bool,
]

# A permission checker receives the tool and its parsed input and returns a
# decision either directly or as an awaitable resolving to one.
ToolPermissionCallable = Callable[
    [Tool[Any, Any], Any],
    Union[_PermissionDecision, Awaitable[_PermissionDecision]],
]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def _check_tool_permissions(
    tool: Tool[Any, Any],
    parsed_input: Any,
    query_context: "QueryContext",
    can_use_tool_fn: Optional[ToolPermissionCallable],
) -> tuple[bool, Optional[str]]:
    """Decide whether ``tool`` may run with ``parsed_input``.

    Resolution order:

    1. If ``can_use_tool_fn`` is given, its (possibly awaitable) decision wins.
       It may return a ``PermissionResult``, a dict with a ``"result"`` key
       (plus an optional ``"message"``), an ``(allowed, message)`` tuple, or
       any other value, which is interpreted by truthiness.
    2. Otherwise, in safe mode, a tool whose ``needs_permissions`` is truthy
       for this input prompts the user on stdin; only "y"/"yes" allows it.
    3. Otherwise the call is allowed.

    Returns:
        ``(allowed, message)``. ``message`` is an optional explanation coming
        from the custom permission checker.

    Evaluation errors, including end-of-input while prompting, are logged and
    treated as a denial instead of propagating.
    """
    try:
        if can_use_tool_fn is not None:
            decision = can_use_tool_fn(tool, parsed_input)
            if inspect.isawaitable(decision):
                decision = await decision
            if isinstance(decision, PermissionResult):
                return decision.result, decision.message
            if isinstance(decision, dict) and "result" in decision:
                return bool(decision.get("result")), decision.get("message")
            if isinstance(decision, tuple) and len(decision) == 2:
                return bool(decision[0]), decision[1]
            return bool(decision), None

        if query_context.safe_mode and tool.needs_permissions(parsed_input):
            loop = asyncio.get_running_loop()
            input_preview = (
                parsed_input.model_dump()
                if hasattr(parsed_input, "model_dump")
                else str(parsed_input)
            )
            prompt = f"Allow tool '{tool.name}' with input {input_preview}? [y/N]: "
            # input() blocks, so run it in the default executor to keep the
            # event loop responsive while waiting for the user.
            response = await loop.run_in_executor(None, lambda: input(prompt))
            return response.strip().lower() in ("y", "yes"), None

        return True, None
    except (TypeError, AttributeError, ValueError, EOFError) as exc:
        # EOFError: stdin closed or non-interactive during the prompt. Deny the
        # call rather than letting the exception abort the whole query.
        logger.warning(
            f"Error checking permissions for tool '{tool.name}': {type(exc).__name__}: {exc}",
            extra={"tool": getattr(tool, "name", None), "error_type": type(exc).__name__},
        )
        return False, None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
    """Build the system-notice text listing files that changed on disk.

    The text starts with a fixed two-line header and a blank line. Each notice
    then adds a ``- <file_path>`` bullet, followed by its summary (trailing
    whitespace stripped) indented beneath it when the summary is non-empty.
    """
    header = [
        "System notice: Files you previously read have changed on disk.",
        "Please re-read the affected files before making further edits.",
        "",
    ]
    body: List[str] = []
    for item in notices:
        body.append(f"- {item.file_path}")
        details = (item.summary or "").rstrip()
        if not details:
            continue
        body.append("\n".join(f" {row}" for row in details.splitlines()))
    return "\n".join(header + body)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
async def _run_tool_use_generator(
    tool: Tool[Any, Any],
    tool_use_id: str,
    tool_name: str,
    parsed_input: Any,
    sibling_ids: set[str],
    tool_context: ToolUseContext,
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Execute a single tool_use and stream its progress and result messages.

    Outputs are translated as follows:

    * ``ToolProgress`` becomes a progress message tagged with
      ``sibling_ids``, the ids of the other tool_uses in the same batch.
    * ``ToolResult`` becomes a tool_result user message. Its text is
      ``result_for_assistant`` when set and ``str(data)`` otherwise.

    Cancellation always propagates. Other common runtime errors are logged
    and reported to the model as an error tool_result instead of raising.
    """
    try:
        async for output in tool.call(parsed_input, tool_context):
            if isinstance(output, ToolProgress):
                yield create_progress_message(
                    tool_use_id=tool_use_id,
                    sibling_tool_use_ids=sibling_ids,
                    content=output.content,
                )
                logger.debug(
                    "[query] Progress from tool_use_id=%s: %s", tool_use_id, output.content
                )
            elif isinstance(output, ToolResult):
                result_content = output.result_for_assistant or str(output.data)
                result_msg = tool_result_message(
                    tool_use_id, result_content, tool_use_result=output.data
                )
                yield result_msg
                logger.debug(
                    "[query] Tool completed tool_use_id=%s name=%s result_len=%s",
                    tool_use_id,
                    tool_name,
                    len(result_content),
                )
    except CancelledError:
        raise  # Don't suppress task cancellation
    # IOError is an alias of OSError, so listing OSError covers both.
    except (RuntimeError, ValueError, TypeError, OSError, AttributeError, KeyError) as exc:
        logger.warning(
            "Error executing tool '%s': %s: %s",
            tool_name, type(exc).__name__, exc,
            extra={"tool": tool_name, "tool_use_id": tool_use_id},
        )
        yield tool_result_message(tool_use_id, f"Error executing tool: {str(exc)}", is_error=True)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
188
|
+
"""Group consecutive tool calls by their concurrency safety."""
|
|
189
|
+
groups: List[Dict[str, Any]] = []
|
|
190
|
+
for call in prepared_calls:
|
|
191
|
+
is_safe = bool(call.get("is_concurrency_safe"))
|
|
192
|
+
if groups and groups[-1]["is_concurrency_safe"] == is_safe:
|
|
193
|
+
groups[-1]["items"].append(call)
|
|
194
|
+
else:
|
|
195
|
+
groups.append({"is_concurrency_safe": is_safe, "items": [call]})
|
|
196
|
+
return groups
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
async def _execute_tools_sequentially(
    items: List[Dict[str, Any]], tool_results: List[UserMessage]
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Drain each prepared call's generator in order, one after another.

    Entries without a truthy ``"generator"`` are skipped. Every message is
    re-yielded, and ``UserMessage`` instances (tool results) are also
    appended to ``tool_results``.
    """
    pending = (entry.get("generator") for entry in items)
    for generator in pending:
        if not generator:
            continue
        async for msg in generator:
            if isinstance(msg, UserMessage):
                tool_results.append(msg)
            yield msg
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
async def _execute_tools_in_parallel(
    items: List[Dict[str, Any]], tool_results: List[UserMessage]
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Start every prepared generator at once and relay their messages.

    Entries whose ``"generator"`` is missing or falsy are ignored. The actual
    fan-out and result collection is delegated to
    ``_run_concurrent_tool_uses``.
    """
    runnable: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]] = []
    for entry in items:
        gen = entry.get("generator")
        if gen:
            runnable.append(gen)
    stream = _run_concurrent_tool_uses(runnable, tool_results)
    async for msg in stream:
        yield msg
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
async def _run_tools_concurrently(
    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Execute prepared calls batch by batch, honouring concurrency safety.

    ``_group_tool_calls_by_concurrency`` forms batches of consecutive calls.
    Concurrency-safe batches run in parallel and the others run one call at a
    time. Batches are processed in order, each fully drained before the next
    one starts.
    """
    for batch in _group_tool_calls_by_concurrency(prepared_calls):
        batch_items = batch["items"]
        count = len(batch_items)
        if not batch["is_concurrency_safe"]:
            logger.debug(
                f"[query] Executing {count} tool(s) sequentially (not concurrency safe)"
            )
            runner = _run_tools_serially(batch_items, tool_results)
        else:
            logger.debug(f"[query] Executing {count} concurrency-safe tool(s) in parallel")
            runner = _execute_tools_in_parallel(batch_items, tool_results)
        async for msg in runner:
            yield msg
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
async def _run_tools_serially(
    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Run every prepared call strictly one after another.

    This is a thin wrapper around ``_execute_tools_sequentially`` that relays
    each message unchanged.
    """
    sequential = _execute_tools_sequentially(prepared_calls, tool_results)
    async for msg in sequential:
        yield msg
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
async def _run_concurrent_tool_uses(
    generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
    tool_results: List[UserMessage],
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
    """Drain multiple tool generators concurrently and stream their outputs.

    Each generator is consumed by its own task. The task pushes messages onto
    a shared queue, followed by a ``None`` sentinel once it stops, whether
    normally or not. Messages are re-yielded in arrival order, and
    ``UserMessage`` instances are also appended to ``tool_results``.

    If this generator is closed or cancelled before every producer has
    finished, the remaining producer tasks are cancelled rather than awaited
    to completion. An interrupt therefore does not have to wait for slow
    tools.
    """
    if not generators:
        return

    queue: asyncio.Queue[Optional[Union[UserMessage, ProgressMessage]]] = asyncio.Queue()

    async def _consume(gen: AsyncGenerator[Union[UserMessage, ProgressMessage], None]) -> None:
        # Forward every message from one generator, then signal completion.
        try:
            async for message in gen:
                await queue.put(message)
        except asyncio.CancelledError:
            raise  # Don't suppress cancellation
        except (StopAsyncIteration, GeneratorExit):
            pass  # Normal generator termination
        except (RuntimeError, ValueError, TypeError) as exc:
            logger.warning(
                "[query] Error while consuming tool generator: %s: %s",
                type(exc).__name__, exc,
            )
        finally:
            # The sentinel is always delivered so the consumer loop can count
            # finished producers.
            await queue.put(None)

    tasks = [asyncio.create_task(_consume(gen)) for gen in generators]
    active = len(tasks)

    try:
        while active:
            message = await queue.get()
            if message is None:
                active -= 1
                continue
            if isinstance(message, UserMessage):
                tool_results.append(message)
            yield message
    finally:
        # On normal completion every producer has already delivered its
        # sentinel and finished, so this only matters on early exit
        # (consumer closed or cancelled). In that case stop outstanding
        # producers instead of blocking until they complete on their own.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class ToolRegistry:
    """Registry of known tools, split into active and deferred sets.

    Tools are keyed by their ``name`` and kept in registration order. A tool
    whose ``defer_loading()`` returns a truthy value starts out deferred. It
    only shows up in ``active_tools`` after ``activate_tools`` is called with
    its name.
    """

    def __init__(self, tools: List[Tool[Any, Any]]) -> None:
        self._tool_map: Dict[str, Tool[Any, Any]] = {}
        self._order: List[str] = []
        self._deferred: set[str] = set()
        self._active: List[str] = []
        self._active_set: set[str] = set()
        self.replace_tools(tools)

    def replace_tools(self, tools: List[Tool[Any, Any]]) -> None:
        """Reset the registry so it contains exactly ``tools``.

        Tools without a name, and later duplicates of a name already seen, are
        ignored. If ``defer_loading()`` raises ``TypeError`` or
        ``AttributeError``, the failure is logged and the tool is treated as
        non-deferred.
        """
        for bucket in (
            self._tool_map,
            self._order,
            self._deferred,
            self._active,
            self._active_set,
        ):
            bucket.clear()

        for candidate in tools:
            tool_name = getattr(candidate, "name", None)
            if not tool_name or tool_name in self._tool_map:
                continue
            self._tool_map[tool_name] = candidate
            self._order.append(tool_name)

            try:
                is_deferred = candidate.defer_loading()
            except (TypeError, AttributeError) as exc:
                logger.warning(
                    "[tool_registry] Tool.defer_loading failed: %s: %s",
                    type(exc).__name__, exc,
                    extra={"tool": getattr(candidate, "name", None)},
                )
                is_deferred = False

            if is_deferred:
                self._deferred.add(tool_name)
                continue
            self._active.append(tool_name)
            self._active_set.add(tool_name)

    @property
    def active_tools(self) -> List[Tool[Any, Any]]:
        """Active (non-deferred) tools, in registration order."""
        active = self._active_set
        return [self._tool_map[n] for n in self._order if n in active]

    @property
    def all_tools(self) -> List[Tool[Any, Any]]:
        """Every registered tool, active or deferred, in registration order."""
        return list(map(self._tool_map.__getitem__, self._order))

    @property
    def deferred_names(self) -> set[str]:
        """A copy of the names of tools that are still deferred."""
        return self._deferred.copy()

    def get(self, name: str) -> Optional[Tool[Any, Any]]:
        """Return the tool registered under ``name``, or None."""
        return self._tool_map.get(name, None)

    def is_active(self, name: str) -> bool:
        """True when ``name`` is in the active set."""
        return name in self._active_set

    def activate_tools(self, names: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Move the named tools into the active set.

        Names are stripped first. Blank names and names that are already
        active are skipped.

        Returns:
            ``(activated, missing)``: the names newly activated by this call,
            and the names that did not resolve to a registered tool.
        """
        newly_active: List[str] = []
        unknown: List[str] = []
        for requested in names:
            cleaned = (requested or "").strip()
            if not cleaned or cleaned in self._active_set:
                continue
            if not self._tool_map.get(cleaned):
                unknown.append(cleaned)
                continue
            self._active.append(cleaned)
            self._active_set.add(cleaned)
            self._deferred.discard(cleaned)
            newly_active.append(cleaned)
        return newly_active, unknown

    def iter_named_tools(self) -> Iterable[tuple[str, Tool[Any, Any]]]:
        """Yield ``(name, tool)`` pairs for all known tools in registration order."""
        for tool_name in self._order:
            found = self._tool_map.get(tool_name)
            if not found:
                continue
            yield tool_name, found
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _apply_skill_context_updates(
    tool_results: List[UserMessage], query_context: "QueryContext"
) -> None:
    """Apply settings requested by Skill tool results to ``query_context``.

    Only results whose ``tool_use_result`` is a dict naming a skill are
    considered. The skill name is taken from the first truthy value among the
    keys ``skill``, ``command_name``, ``commandName`` and ``command``. For
    each such result:

    * string names in ``allowed_tools`` / ``allowedTools`` are activated in
      the tool registry (activation errors are logged, not raised);
    * a non-blank ``model`` string replaces ``query_context.model``;
    * ``max_thinking_tokens`` / ``maxThinkingTokens``, when it parses as an
      int, replaces ``query_context.max_thinking_tokens``.
    """
    for message in tool_results:
        payload = getattr(message, "tool_use_result", None)
        if not isinstance(payload, dict):
            continue

        skill_name = None
        for key in ("skill", "command_name", "commandName", "command"):
            skill_name = payload.get(key)
            if skill_name:
                break
        if not skill_name:
            continue

        requested_tools = payload.get("allowed_tools") or payload.get("allowedTools") or []
        if requested_tools and getattr(query_context, "tool_registry", None):
            try:
                # Filtering stays inside the try: a non-iterable value raises
                # TypeError, which is logged like an activation failure.
                names = [n for n in requested_tools if isinstance(n, str) and n.strip()]
                query_context.tool_registry.activate_tools(names)
            except (KeyError, ValueError, TypeError) as exc:
                logger.warning(
                    "[query] Failed to activate tools listed in skill output: %s: %s",
                    type(exc).__name__, exc,
                )

        model_hint = payload.get("model")
        if isinstance(model_hint, str):
            chosen_model = model_hint.strip()
            if chosen_model:
                logger.debug(
                    "[query] Applying model hint from skill",
                    extra={"skill": skill_name, "model": model_hint},
                )
                query_context.model = chosen_model

        raw_budget = payload.get("max_thinking_tokens")
        if raw_budget is None:
            raw_budget = payload.get("maxThinkingTokens")
        budget = parse_optional_int(raw_budget)
        if budget is None:
            continue
        logger.debug(
            "[query] Applying max thinking tokens from skill",
            extra={"skill": skill_name, "max_thinking_tokens": budget},
        )
        query_context.max_thinking_tokens = budget
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class QueryContext:
    """Per-session state consulted by the query loop.

    Bundles the tool registry with model and runtime settings (model pointer,
    thinking-token budget, safe mode, verbosity). It also holds an
    ``asyncio.Event`` stored as ``abort_controller``, a cache of file
    snapshots that is checked for on-disk changes, and optional UI
    pause/resume callbacks.
    """

    def __init__(
        self,
        tools: List[Tool[Any, Any]],
        max_thinking_tokens: int = 0,
        safe_mode: bool = False,
        model: str = "main",
        verbose: bool = False,
        pause_ui: Optional[Callable[[], None]] = None,
        resume_ui: Optional[Callable[[], None]] = None,
    ) -> None:
        # Every tool, active or deferred, is tracked by the registry.
        self.tool_registry = ToolRegistry(tools)

        # Model selection and thinking budget; Skill tool results may
        # overwrite both mid-session.
        self.model = model
        self.max_thinking_tokens = max_thinking_tokens

        # safe_mode makes the default permission check prompt the user for
        # tools that report needing permission.
        self.safe_mode = safe_mode
        self.verbose = verbose

        # NOTE(review): presumably set to request that the running query stop;
        # its readers/writers are not visible here.
        self.abort_controller = asyncio.Event()

        # Snapshots of previously read files, checked for on-disk changes at
        # the start of each query iteration.
        self.file_state_cache: Dict[str, FileSnapshot] = {}

        # Optional UI hooks; presumably used to suspend/restore the UI around
        # interactive prompts — confirm against callers.
        self.pause_ui = pause_ui
        self.resume_ui = resume_ui

    @property
    def tools(self) -> List[Tool[Any, Any]]:
        """Tools currently active for requests (deferred ones are excluded)."""
        registry = self.tool_registry
        return registry.active_tools

    @tools.setter
    def tools(self, new_tools: List[Tool[Any, Any]]) -> None:
        """Swap in a new tool inventory, recomputing active/deferred state."""
        registry = self.tool_registry
        registry.replace_tools(new_tools)

    def activate_tools(self, names: Iterable[str]) -> Tuple[List[str], List[str]]:
        """Promote deferred tools to active; returns ``(activated, missing)``."""
        registry = self.tool_registry
        return registry.activate_tools(names)

    def all_tools(self) -> List[Tool[Any, Any]]:
        """Every registered tool, active or deferred, in registration order."""
        registry = self.tool_registry
        return registry.all_tools
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
async def query_llm(
    messages: List[Union[UserMessage, AssistantMessage, ProgressMessage]],
    system_prompt: str,
    tools: List[Tool[Any, Any]],
    max_thinking_tokens: int = 0,
    model: str = "main",
    _abort_signal: Optional[asyncio.Event] = None,
    *,
    progress_callback: Optional[Callable[[str], Awaitable[None]]] = None,
    request_timeout: Optional[float] = None,
    max_retries: int = MAX_LLM_RETRIES,
    stream: bool = True,
) -> AssistantMessage:
    """Query the AI model and return the response.

    Args:
        messages: Conversation history
        system_prompt: System prompt for the model
        tools: Available tools
        max_thinking_tokens: Maximum tokens for thinking (0 = disabled)
        model: Model pointer to use
        _abort_signal: Event to signal abortion (currently unused, reserved for future)
        progress_callback: Optional async callback invoked with streamed text chunks
        request_timeout: Max seconds to wait for a provider response before retrying.
            ``None`` (or any other falsy value) uses ``DEFAULT_REQUEST_TIMEOUT_SEC``.
        max_retries: Number of retries on timeout/errors (total attempts = retries + 1)
        stream: Enable streaming for providers that support it (text-only mode)

    Returns:
        AssistantMessage with the model's response. Failures of the provider
        call are returned, not raised, as an AssistantMessage with
        ``is_api_error_message`` set. This covers a missing provider client,
        a provider error response, and RuntimeError / ValueError / TypeError /
        OSError / ConnectionError / TimeoutError. Context-length failures are
        flagged in the message metadata.

    Raises:
        asyncio.CancelledError: Propagated unchanged so task cancellation works.
        Errors from model-profile resolution or message normalization are not
        converted, because those steps run before the guarded provider call.
    """
    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
    model_profile = resolve_model_profile(model)

    # Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
    protocol = provider_protocol(model_profile.provider)
    tool_mode = determine_tool_mode(model_profile)
    messages_for_model: List[Union[UserMessage, AssistantMessage, ProgressMessage]]
    if tool_mode == "text":
        messages_for_model = cast(
            List[Union[UserMessage, AssistantMessage, ProgressMessage]],
            text_mode_history(messages),
        )
    else:
        messages_for_model = messages

    normalized_messages: List[Dict[str, Any]] = normalize_messages_for_api(
        messages_for_model, protocol=protocol, tool_mode=tool_mode
    )
    logger.info(
        "[query_llm] Preparing model request",
        extra={
            "model_pointer": model,
            "provider": getattr(model_profile.provider, "value", str(model_profile.provider)),
            "model": model_profile.model,
            "normalized_messages": len(normalized_messages),
            "tool_count": len(tools),
            "max_thinking_tokens": max_thinking_tokens,
            "tool_mode": tool_mode,
        },
    )

    if protocol == "openai":
        log_openai_messages(normalized_messages)

    logger.debug(
        f"[query_llm] Sending {len(normalized_messages)} messages to model pointer "
        f"'{model}' with {len(tools)} tool schemas; "
        f"max_thinking_tokens={max_thinking_tokens} protocol={protocol}"
    )

    # Make the API call
    start_time = time.time()

    try:
        client: Optional[ProviderClient] = get_provider_client(model_profile.provider)
        if client is None:
            # NOTE(review): the message names Gemini specifically; confirm that
            # Gemini is the only provider for which no client is returned.
            duration_ms = (time.time() - start_time) * 1000
            error_msg = create_assistant_message(
                content=(
                    "Gemini protocol is not supported yet in Ripperdoc. "
                    "Please configure an Anthropic or OpenAI-compatible model."
                ),
                duration_ms=duration_ms,
            )
            error_msg.is_api_error_message = True
            return error_msg

        provider_response = await client.call(
            model_profile=model_profile,
            system_prompt=system_prompt,
            normalized_messages=normalized_messages,
            tools=tools,
            tool_mode=tool_mode,
            stream=stream,
            progress_callback=progress_callback,
            request_timeout=request_timeout,
            max_retries=max_retries,
            max_thinking_tokens=max_thinking_tokens,
        )

        # Check if provider returned an error response
        if provider_response.is_error:
            logger.warning(
                "[query_llm] Provider returned error response",
                extra={
                    "model": model_profile.model,
                    "error_code": provider_response.error_code,
                    "error_message": provider_response.error_message,
                },
            )
            metadata: Dict[str, Any] = {
                "api_error": True,
                "error_code": provider_response.error_code,
                "error_message": provider_response.error_message,
            }
            # Add context length info if applicable
            if provider_response.error_code == "context_length_exceeded":
                metadata["context_length_exceeded"] = True

            error_msg = create_assistant_message(
                content=provider_response.content_blocks,
                duration_ms=provider_response.duration_ms,
                metadata=metadata,
            )
            error_msg.is_api_error_message = True
            return error_msg

        return create_assistant_message(
            content=provider_response.content_blocks,
            cost_usd=provider_response.cost_usd,
            duration_ms=provider_response.duration_ms,
            metadata=provider_response.metadata,
        )

    except CancelledError:
        raise  # Don't suppress task cancellation
    except (RuntimeError, ValueError, TypeError, OSError, ConnectionError, TimeoutError) as e:
        # Convert the exception into an error message for the caller.
        logger.warning(
            "Error querying AI model: %s: %s",
            type(e).__name__, e,
            extra={
                "model": getattr(model_profile, "model", None),
                "model_pointer": model,
                "provider": (
                    getattr(model_profile.provider, "value", None) if model_profile else None
                ),
            },
        )
        duration_ms = (time.time() - start_time) * 1000
        context_error = detect_context_length_error(e)
        error_metadata: Optional[Dict[str, Any]] = None
        content = f"Error querying AI model: {str(e)}"

        if context_error:
            content = f"The request exceeded the model's context window. {context_error.message}"
            error_metadata = {
                "context_length_exceeded": True,
                "context_length_provider": context_error.provider,
                "context_length_error_code": context_error.error_code,
                "context_length_status_code": context_error.status_code,
            }
            logger.info(
                "[query_llm] Detected context-length error; consider compacting history",
                extra={
                    "provider": context_error.provider,
                    "error_code": context_error.error_code,
                    "status_code": context_error.status_code,
                },
            )

        error_msg = create_assistant_message(
            content=content, duration_ms=duration_ms, metadata=error_metadata
        )
        error_msg.is_api_error_message = True
        return error_msg
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
# Iteration limit for the query loop (shown as the denominator in the
# per-iteration debug log); override with RIPPERDOC_MAX_QUERY_ITERATIONS.
MAX_QUERY_ITERATIONS: int = int(os.environ.get("RIPPERDOC_MAX_QUERY_ITERATIONS", "1024"))
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
@dataclass
class IterationResult:
    """Mutable out-parameter filled in by ``_run_query_iteration``.

    Async generators cannot return a value alongside what they yield, so the
    query loop hands one of these to each iteration and inspects it after the
    iteration's messages have been consumed.
    """

    # Assistant reply produced by this iteration; stays None if the iteration
    # was interrupted before the model answered.
    assistant_message: Optional[AssistantMessage] = field(default=None)
    # Tool-result user messages gathered while running this iteration's tool calls.
    tool_results: List[UserMessage] = field(default_factory=list)
    # When True, the caller ends the query loop instead of starting another iteration.
    should_stop: bool = field(default=False)
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
async def _run_query_iteration(
    messages: List[Union[UserMessage, AssistantMessage, ProgressMessage]],
    system_prompt: str,
    context: Dict[str, str],
    query_context: QueryContext,
    can_use_tool_fn: Optional[ToolPermissionCallable],
    iteration: int,
    result: IterationResult,
) -> AsyncGenerator[Union[UserMessage, AssistantMessage, ProgressMessage], None]:
    """Run a single iteration of the query loop.

    This function handles one round of:
    1. Calling the LLM (in a background task) while relaying streamed chunks
       as progress messages and polling the abort flag every 100ms
    2. Yielding the assistant message
    3. Validating, permission-checking and running any tool calls it contains

    Args:
        messages: Current conversation history. NOTE: mutated in place - a
            changed-file notice is appended when ``detect_changed_files``
            reports changes.
        system_prompt: Base system prompt
        context: Additional context dictionary
        query_context: Query configuration
        can_use_tool_fn: Optional function to check tool permissions
        iteration: Current iteration number (for logging)
        result: IterationResult object to store results; ``should_stop`` is
            set when the caller must end the loop (interrupt, no tool calls,
            or a permission denial).

    Yields:
        Messages (progress, assistant, tool results) as they are generated
    """
    logger.debug("[query] Iteration %s/%s", iteration, MAX_QUERY_ITERATIONS)

    # Check for file changes at the start of each iteration
    change_notices = detect_changed_files(query_context.file_state_cache)
    if change_notices:
        messages.append(create_user_message(_format_changed_file_notice(change_notices)))

    model_profile = resolve_model_profile(query_context.model)
    tool_mode = determine_tool_mode(model_profile)
    # In "text" tool mode no tool schemas are passed to query_llm; the full tool
    # list is still handed to the system-prompt builder below.
    tools_for_model: List[Tool[Any, Any]] = (
        [] if tool_mode == "text" else query_context.all_tools()
    )

    full_system_prompt = build_full_system_prompt(
        system_prompt, context, tool_mode, query_context.all_tools()
    )
    logger.debug(
        "[query] Built system prompt",
        extra={
            "prompt_chars": len(full_system_prompt),
            "context_entries": len(context),
            "tool_count": len(tools_for_model),
        },
    )

    # Stream LLM response
    progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue()

    async def _stream_progress(chunk: str) -> None:
        """Enqueue a non-empty streamed chunk as a progress message."""
        if not chunk:
            return
        try:
            await progress_queue.put(
                create_progress_message(
                    tool_use_id="stream",
                    sibling_tool_use_ids=set(),
                    content=chunk,
                )
            )
        except asyncio.QueueFull:
            logger.warning("[query] Progress queue full, dropping chunk")
        except (RuntimeError, ValueError) as exc:
            logger.warning("[query] Failed to enqueue stream progress chunk: %s", exc)

    assistant_task = asyncio.create_task(
        query_llm(
            messages,
            full_system_prompt,
            tools_for_model,
            query_context.max_thinking_tokens,
            query_context.model,
            query_context.abort_controller,
            progress_callback=_stream_progress,
            request_timeout=DEFAULT_REQUEST_TIMEOUT_SEC,
            max_retries=MAX_LLM_RETRIES,
            stream=True,
        )
    )

    assistant_message: Optional[AssistantMessage] = None
    waiter: Optional["asyncio.Task[Optional[ProgressMessage]]"] = None

    # Wait for LLM response while yielding progress
    try:
        while True:
            if query_context.abort_controller.is_set():
                assistant_task.cancel()
                try:
                    await assistant_task
                except CancelledError:
                    pass
                yield create_assistant_message(INTERRUPT_MESSAGE)
                result.should_stop = True
                return
            if assistant_task.done():
                assistant_message = await assistant_task
                break
            try:
                progress = progress_queue.get_nowait()
            except asyncio.QueueEmpty:
                waiter = asyncio.create_task(progress_queue.get())
                # Use timeout to periodically check abort_controller during LLM request
                done, _pending = await asyncio.wait(
                    {assistant_task, waiter},
                    return_when=asyncio.FIRST_COMPLETED,
                    timeout=0.1,  # Check abort_controller every 100ms
                )
                if waiter in done:
                    # Keep the chunk even if the LLM task finished in the same
                    # wakeup; discarding it would silently drop streamed output.
                    progress = waiter.result()
                else:
                    # Timeout, or only the LLM task finished. Cancelling a pending
                    # Queue.get() leaves any queued item in the queue.
                    waiter.cancel()
                    try:
                        await waiter
                    except CancelledError:
                        pass
                    progress = None
                waiter = None
            if progress:
                yield progress
            # The next pass re-checks the abort flag and assistant_task completion.
    finally:
        # Leaving early (consumer closed this generator, or this task was
        # cancelled mid-wait) must not leave the LLM request or the queue
        # waiter running unobserved in the background.
        for leftover in (assistant_task, waiter):
            if leftover is not None and not leftover.done():
                leftover.cancel()

    # Drain remaining progress messages
    while not progress_queue.empty():
        residual = progress_queue.get_nowait()
        if residual:
            yield residual

    assert assistant_message is not None
    result.assistant_message = assistant_message

    # Check for abort
    if query_context.abort_controller.is_set():
        yield create_assistant_message(INTERRUPT_MESSAGE)
        result.should_stop = True
        return

    yield assistant_message

    # Extract and process tool calls
    tool_use_blocks: List[MessageContent] = extract_tool_use_blocks(assistant_message)
    content = assistant_message.message.content
    # Number of content blocks of any kind when content is a list; 1 otherwise.
    text_blocks = len(content) if isinstance(content, list) else 1
    logger.debug(
        "[query] Assistant message received: text_blocks=%s, tool_use_blocks=%s",
        text_blocks,
        len(tool_use_blocks),
    )

    if not tool_use_blocks:
        logger.debug("[query] No tool_use blocks; returning response to user.")
        result.should_stop = True
        return

    # Process tool calls
    logger.debug("[query] Executing %s tool_use block(s).", len(tool_use_blocks))
    tool_results: List[UserMessage] = []
    permission_denied = False
    sibling_ids = {
        getattr(t, "tool_use_id", None) or getattr(t, "id", None) or ""
        for t in tool_use_blocks
    }
    prepared_calls: List[Dict[str, Any]] = []

    for tool_use in tool_use_blocks:
        tool_name = tool_use.name
        if not tool_name:
            continue
        tool_use_id = (
            getattr(tool_use, "tool_use_id", None) or getattr(tool_use, "id", None) or ""
        )
        tool_input = getattr(tool_use, "input", {}) or {}

        tool, missing_msg = _resolve_tool(
            query_context.tool_registry, tool_name, tool_use_id
        )
        if missing_msg:
            logger.warning(
                "[query] Tool '%s' not found for tool_use_id=%s", tool_name, tool_use_id
            )
            tool_results.append(missing_msg)
            yield missing_msg
            continue
        assert tool is not None

        try:
            parsed_input = tool.input_schema(**tool_input)
            logger.debug(
                "[query] tool_use_id=%s name=%s parsed_input=%s",
                tool_use_id,
                tool_name,
                str(parsed_input)[:500],
            )

            tool_context = ToolUseContext(
                safe_mode=query_context.safe_mode,
                verbose=query_context.verbose,
                permission_checker=can_use_tool_fn,
                tool_registry=query_context.tool_registry,
                file_state_cache=query_context.file_state_cache,
                abort_signal=query_context.abort_controller,
                pause_ui=query_context.pause_ui,
                resume_ui=query_context.resume_ui,
            )

            validation = await tool.validate_input(parsed_input, tool_context)
            if not validation.result:
                logger.debug(
                    "[query] Validation failed for tool_use_id=%s: %s",
                    tool_use_id,
                    validation.message,
                )
                result_msg = tool_result_message(
                    tool_use_id,
                    validation.message or "Tool input validation failed.",
                    is_error=True,
                )
                tool_results.append(result_msg)
                yield result_msg
                continue

            if query_context.safe_mode or can_use_tool_fn is not None:
                allowed, denial_message = await _check_tool_permissions(
                    tool, parsed_input, query_context, can_use_tool_fn
                )
                if not allowed:
                    logger.debug(
                        "[query] Permission denied for tool_use_id=%s: %s",
                        tool_use_id,
                        denial_message,
                    )
                    denial_text = (
                        denial_message or f"User aborted the tool invocation: {tool_name}"
                    )
                    denial_msg = tool_result_message(tool_use_id, denial_text, is_error=True)
                    tool_results.append(denial_msg)
                    yield denial_msg
                    permission_denied = True
                    # A denial ends the whole batch; remaining calls are not run.
                    break

            prepared_calls.append(
                {
                    "is_concurrency_safe": tool.is_concurrency_safe(),
                    "generator": _run_tool_use_generator(
                        tool,
                        tool_use_id,
                        tool_name,
                        parsed_input,
                        sibling_ids,
                        tool_context,
                    ),
                }
            )

        except ValidationError as ve:
            detail_text = format_pydantic_errors(ve)
            error_msg = tool_result_message(
                tool_use_id,
                f"Invalid input for tool '{tool_name}': {detail_text}",
                is_error=True,
            )
            tool_results.append(error_msg)
            yield error_msg
            continue
        except CancelledError:
            raise  # Don't suppress task cancellation
        except (
            RuntimeError,
            ValueError,
            TypeError,
            OSError,  # IOError is an alias of OSError
            AttributeError,
            KeyError,
        ) as e:
            logger.warning(
                "Error executing tool '%s': %s: %s",
                tool_name,
                type(e).__name__,
                e,
                extra={"tool": tool_name, "tool_use_id": tool_use_id},
            )
            error_msg = tool_result_message(
                tool_use_id, f"Error executing tool: {str(e)}", is_error=True
            )
            tool_results.append(error_msg)
            yield error_msg

    if permission_denied:
        result.tool_results = tool_results
        result.should_stop = True
        return

    if prepared_calls:
        async for message in _run_tools_concurrently(prepared_calls, tool_results):
            yield message

    _apply_skill_context_updates(tool_results, query_context)

    # Check for abort after tools
    if query_context.abort_controller.is_set():
        yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
        result.tool_results = tool_results
        result.should_stop = True
        return

    result.tool_results = tool_results
    # should_stop remains False, indicating the loop should continue
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
async def query(
    messages: List[Union[UserMessage, AssistantMessage, ProgressMessage]],
    system_prompt: str,
    context: Dict[str, str],
    query_context: QueryContext,
    can_use_tool_fn: Optional[ToolPermissionCallable] = None,
) -> AsyncGenerator[Union[UserMessage, AssistantMessage, ProgressMessage], None]:
    """Drive the model/tool conversation loop until it settles.

    Each pass delegates to ``_run_query_iteration`` (one LLM call plus any
    tool executions it requests) and re-yields everything it emits. The loop
    ends when an iteration sets ``should_stop``, or after
    ``MAX_QUERY_ITERATIONS`` passes, in which case a final assistant notice
    is yielded.

    Args:
        messages: Conversation history; copied, so the caller's list is
            never modified by this loop.
        system_prompt: Base system prompt.
        context: Additional context dictionary.
        query_context: Query configuration.
        can_use_tool_fn: Optional function to check tool permissions.

    Yields:
        User, assistant and progress messages in the order they are produced.
    """
    logger.info(
        "[query] Starting query loop",
        extra={
            "message_count": len(messages),
            "tool_count": len(query_context.tools),
            "safe_mode": query_context.safe_mode,
            "model_pointer": query_context.model,
        },
    )
    # Snapshot the history so that a consumer appending to its own list while
    # iterating cannot interfere with this loop.
    history = list(messages)

    iteration = 0
    while iteration < MAX_QUERY_ITERATIONS:
        iteration += 1
        outcome = IterationResult()

        async for emitted in _run_query_iteration(
            history,
            system_prompt,
            context,
            query_context,
            can_use_tool_fn,
            iteration,
            outcome,
        ):
            yield emitted

        if outcome.should_stop:
            return

        # Feed the assistant turn and its tool results into the next request.
        history = [*history, outcome.assistant_message, *outcome.tool_results]
        logger.debug(
            f"[query] Continuing loop with {len(history)} messages after tools; "
            f"tool_results_count={len(outcome.tool_results)}"
        )

    # Iteration budget exhausted without the model settling.
    logger.warning(
        f"[query] Reached maximum iterations ({MAX_QUERY_ITERATIONS}), stopping query loop"
    )
    yield create_assistant_message(
        f"Reached maximum query iterations ({MAX_QUERY_ITERATIONS}). "
        "Please continue the conversation to proceed."
    )
|