ripperdoc-0.2.9-py3-none-any.whl → ripperdoc-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +1 -1
- ripperdoc/cli/cli.py +379 -51
- ripperdoc/cli/commands/__init__.py +6 -0
- ripperdoc/cli/commands/agents_cmd.py +128 -5
- ripperdoc/cli/commands/clear_cmd.py +8 -0
- ripperdoc/cli/commands/doctor_cmd.py +29 -0
- ripperdoc/cli/commands/exit_cmd.py +1 -0
- ripperdoc/cli/commands/memory_cmd.py +2 -1
- ripperdoc/cli/commands/models_cmd.py +63 -7
- ripperdoc/cli/commands/resume_cmd.py +5 -0
- ripperdoc/cli/commands/skills_cmd.py +103 -0
- ripperdoc/cli/commands/stats_cmd.py +244 -0
- ripperdoc/cli/commands/status_cmd.py +10 -0
- ripperdoc/cli/commands/tasks_cmd.py +6 -3
- ripperdoc/cli/commands/themes_cmd.py +139 -0
- ripperdoc/cli/ui/file_mention_completer.py +63 -13
- ripperdoc/cli/ui/helpers.py +6 -3
- ripperdoc/cli/ui/interrupt_handler.py +34 -0
- ripperdoc/cli/ui/panels.py +14 -8
- ripperdoc/cli/ui/rich_ui.py +737 -47
- ripperdoc/cli/ui/spinner.py +93 -18
- ripperdoc/cli/ui/thinking_spinner.py +1 -2
- ripperdoc/cli/ui/tool_renderers.py +10 -9
- ripperdoc/cli/ui/wizard.py +24 -19
- ripperdoc/core/agents.py +14 -3
- ripperdoc/core/config.py +238 -6
- ripperdoc/core/default_tools.py +91 -10
- ripperdoc/core/hooks/events.py +4 -0
- ripperdoc/core/hooks/llm_callback.py +58 -0
- ripperdoc/core/hooks/manager.py +6 -0
- ripperdoc/core/permissions.py +160 -9
- ripperdoc/core/providers/openai.py +84 -28
- ripperdoc/core/query.py +489 -87
- ripperdoc/core/query_utils.py +17 -14
- ripperdoc/core/skills.py +1 -0
- ripperdoc/core/theme.py +298 -0
- ripperdoc/core/tool.py +15 -5
- ripperdoc/protocol/__init__.py +14 -0
- ripperdoc/protocol/models.py +300 -0
- ripperdoc/protocol/stdio.py +1453 -0
- ripperdoc/tools/background_shell.py +354 -139
- ripperdoc/tools/bash_tool.py +117 -22
- ripperdoc/tools/file_edit_tool.py +228 -50
- ripperdoc/tools/file_read_tool.py +154 -3
- ripperdoc/tools/file_write_tool.py +53 -11
- ripperdoc/tools/grep_tool.py +98 -8
- ripperdoc/tools/lsp_tool.py +609 -0
- ripperdoc/tools/multi_edit_tool.py +26 -3
- ripperdoc/tools/skill_tool.py +52 -1
- ripperdoc/tools/task_tool.py +539 -65
- ripperdoc/utils/conversation_compaction.py +1 -1
- ripperdoc/utils/file_watch.py +216 -7
- ripperdoc/utils/image_utils.py +125 -0
- ripperdoc/utils/log.py +30 -3
- ripperdoc/utils/lsp.py +812 -0
- ripperdoc/utils/mcp.py +80 -18
- ripperdoc/utils/message_formatting.py +7 -4
- ripperdoc/utils/messages.py +198 -33
- ripperdoc/utils/pending_messages.py +50 -0
- ripperdoc/utils/permissions/shell_command_validation.py +3 -3
- ripperdoc/utils/permissions/tool_permission_utils.py +180 -15
- ripperdoc/utils/platform.py +198 -0
- ripperdoc/utils/session_heatmap.py +242 -0
- ripperdoc/utils/session_history.py +2 -2
- ripperdoc/utils/session_stats.py +294 -0
- ripperdoc/utils/shell_utils.py +8 -5
- ripperdoc/utils/todo.py +0 -6
- {ripperdoc-0.2.9.dist-info → ripperdoc-0.3.0.dist-info}/METADATA +55 -17
- ripperdoc-0.3.0.dist-info/RECORD +136 -0
- {ripperdoc-0.2.9.dist-info → ripperdoc-0.3.0.dist-info}/WHEEL +1 -1
- ripperdoc/sdk/__init__.py +0 -9
- ripperdoc/sdk/client.py +0 -333
- ripperdoc-0.2.9.dist-info/RECORD +0 -123
- {ripperdoc-0.2.9.dist-info → ripperdoc-0.3.0.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.9.dist-info → ripperdoc-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.9.dist-info → ripperdoc-0.3.0.dist-info}/top_level.txt +0 -0
ripperdoc/core/query.py
CHANGED
|
@@ -43,7 +43,12 @@ from ripperdoc.core.query_utils import (
|
|
|
43
43
|
from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
|
|
44
44
|
from ripperdoc.utils.coerce import parse_optional_int
|
|
45
45
|
from ripperdoc.utils.context_length_errors import detect_context_length_error
|
|
46
|
-
from ripperdoc.utils.file_watch import
|
|
46
|
+
from ripperdoc.utils.file_watch import (
|
|
47
|
+
BoundedFileCache,
|
|
48
|
+
ChangedFileNotice,
|
|
49
|
+
detect_changed_files,
|
|
50
|
+
)
|
|
51
|
+
from ripperdoc.utils.pending_messages import PendingMessageQueue
|
|
47
52
|
from ripperdoc.utils.log import get_logger
|
|
48
53
|
from ripperdoc.utils.messages import (
|
|
49
54
|
AssistantMessage,
|
|
@@ -63,6 +68,10 @@ logger = get_logger()
|
|
|
63
68
|
|
|
64
69
|
DEFAULT_REQUEST_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_API_TIMEOUT", "120"))
|
|
65
70
|
MAX_LLM_RETRIES = int(os.getenv("RIPPERDOC_MAX_RETRIES", "10"))
|
|
71
|
+
# Timeout for individual tool execution (can be overridden per tool if needed)
|
|
72
|
+
DEFAULT_TOOL_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_TOOL_TIMEOUT", "300")) # 5 minutes
|
|
73
|
+
# Timeout for concurrent tool execution (total for all tools)
|
|
74
|
+
DEFAULT_CONCURRENT_TOOL_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_CONCURRENT_TOOL_TIMEOUT", "600")) # 10 minutes
|
|
66
75
|
|
|
67
76
|
|
|
68
77
|
def infer_thinking_mode(model_profile: ModelProfile) -> Optional[str]:
|
|
@@ -81,6 +90,9 @@ def infer_thinking_mode(model_profile: ModelProfile) -> Optional[str]:
|
|
|
81
90
|
# Use explicit config if set
|
|
82
91
|
explicit_mode = model_profile.thinking_mode
|
|
83
92
|
if explicit_mode:
|
|
93
|
+
# "none", "disabled", "off" means thinking is explicitly disabled
|
|
94
|
+
if explicit_mode.lower() in ("disabled", "off"):
|
|
95
|
+
return None
|
|
84
96
|
return explicit_mode
|
|
85
97
|
|
|
86
98
|
# Auto-detect based on API base and model name
|
|
@@ -131,7 +143,7 @@ async def _check_tool_permissions(
|
|
|
131
143
|
parsed_input: Any,
|
|
132
144
|
query_context: "QueryContext",
|
|
133
145
|
can_use_tool_fn: Optional[ToolPermissionCallable],
|
|
134
|
-
) -> tuple[bool, Optional[str]]:
|
|
146
|
+
) -> tuple[bool, Optional[str], Optional[Any]]:
|
|
135
147
|
"""Evaluate whether a tool call is allowed."""
|
|
136
148
|
try:
|
|
137
149
|
if can_use_tool_fn is not None:
|
|
@@ -139,12 +151,16 @@ async def _check_tool_permissions(
|
|
|
139
151
|
if inspect.isawaitable(decision):
|
|
140
152
|
decision = await decision
|
|
141
153
|
if isinstance(decision, PermissionResult):
|
|
142
|
-
return decision.result, decision.message
|
|
154
|
+
return decision.result, decision.message, decision.updated_input
|
|
143
155
|
if isinstance(decision, dict) and "result" in decision:
|
|
144
|
-
return
|
|
156
|
+
return (
|
|
157
|
+
bool(decision.get("result")),
|
|
158
|
+
decision.get("message"),
|
|
159
|
+
decision.get("updated_input"),
|
|
160
|
+
)
|
|
145
161
|
if isinstance(decision, tuple) and len(decision) == 2:
|
|
146
|
-
return bool(decision[0]), decision[1]
|
|
147
|
-
return bool(decision), None
|
|
162
|
+
return bool(decision[0]), decision[1], None
|
|
163
|
+
return bool(decision), None, None
|
|
148
164
|
|
|
149
165
|
if not query_context.yolo_mode and tool.needs_permissions(parsed_input):
|
|
150
166
|
loop = asyncio.get_running_loop()
|
|
@@ -155,15 +171,15 @@ async def _check_tool_permissions(
|
|
|
155
171
|
)
|
|
156
172
|
prompt = f"Allow tool '{tool.name}' with input {input_preview}? [y/N]: "
|
|
157
173
|
response = await loop.run_in_executor(None, lambda: input(prompt))
|
|
158
|
-
return response.strip().lower() in ("y", "yes"), None
|
|
174
|
+
return response.strip().lower() in ("y", "yes"), None, None
|
|
159
175
|
|
|
160
|
-
return True, None
|
|
176
|
+
return True, None, None
|
|
161
177
|
except (TypeError, AttributeError, ValueError) as exc:
|
|
162
178
|
logger.warning(
|
|
163
179
|
f"Error checking permissions for tool '{tool.name}': {type(exc).__name__}: {exc}",
|
|
164
180
|
extra={"tool": getattr(tool, "name", None), "error_type": type(exc).__name__},
|
|
165
181
|
)
|
|
166
|
-
return False, None
|
|
182
|
+
return False, None, None
|
|
167
183
|
|
|
168
184
|
|
|
169
185
|
def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
|
|
@@ -182,6 +198,18 @@ def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
|
|
|
182
198
|
return "\n".join(lines)
|
|
183
199
|
|
|
184
200
|
|
|
201
|
+
def _append_hook_context(context: Dict[str, str], label: str, payload: Optional[str]) -> None:
|
|
202
|
+
"""Append hook-supplied context to the shared context dict."""
|
|
203
|
+
if not payload:
|
|
204
|
+
return
|
|
205
|
+
key = f"Hook:{label}"
|
|
206
|
+
existing = context.get(key)
|
|
207
|
+
if existing:
|
|
208
|
+
context[key] = f"{existing}\n{payload}"
|
|
209
|
+
else:
|
|
210
|
+
context[key] = payload
|
|
211
|
+
|
|
212
|
+
|
|
185
213
|
async def _run_tool_use_generator(
|
|
186
214
|
tool: Tool[Any, Any],
|
|
187
215
|
tool_use_id: str,
|
|
@@ -189,8 +217,14 @@ async def _run_tool_use_generator(
|
|
|
189
217
|
parsed_input: Any,
|
|
190
218
|
sibling_ids: set[str],
|
|
191
219
|
tool_context: ToolUseContext,
|
|
220
|
+
context: Dict[str, str],
|
|
192
221
|
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
|
|
193
222
|
"""Execute a single tool_use and yield progress/results."""
|
|
223
|
+
logger.debug(
|
|
224
|
+
"[query] _run_tool_use_generator ENTER: tool='%s' tool_use_id=%s",
|
|
225
|
+
tool_name,
|
|
226
|
+
tool_use_id,
|
|
227
|
+
)
|
|
194
228
|
# Get tool input as dict for hooks
|
|
195
229
|
tool_input_dict = (
|
|
196
230
|
parsed_input.model_dump()
|
|
@@ -221,8 +255,14 @@ async def _run_tool_use_generator(
|
|
|
221
255
|
)
|
|
222
256
|
# Re-parse the input with the updated values
|
|
223
257
|
try:
|
|
224
|
-
|
|
225
|
-
|
|
258
|
+
# Ensure updated_input is a dict, not a Pydantic model
|
|
259
|
+
updated_input = pre_result.updated_input
|
|
260
|
+
if hasattr(updated_input, "model_dump"):
|
|
261
|
+
updated_input = updated_input.model_dump()
|
|
262
|
+
elif not isinstance(updated_input, dict):
|
|
263
|
+
updated_input = {"value": str(updated_input)}
|
|
264
|
+
parsed_input = tool.input_schema(**updated_input)
|
|
265
|
+
tool_input_dict = updated_input
|
|
226
266
|
except (ValueError, TypeError) as exc:
|
|
227
267
|
logger.warning(
|
|
228
268
|
f"[query] Failed to apply updated input from hook: {exc}",
|
|
@@ -235,30 +275,58 @@ async def _run_tool_use_generator(
|
|
|
235
275
|
f"[query] PreToolUse hook added context for {tool_name}",
|
|
236
276
|
extra={"context": pre_result.additional_context[:100]},
|
|
237
277
|
)
|
|
278
|
+
_append_hook_context(context, f"PreToolUse:{tool_name}", pre_result.additional_context)
|
|
279
|
+
if pre_result.system_message:
|
|
280
|
+
_append_hook_context(context, f"PreToolUse:{tool_name}:system", pre_result.system_message)
|
|
238
281
|
|
|
239
282
|
tool_output = None
|
|
240
283
|
|
|
241
284
|
try:
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
285
|
+
logger.debug("[query] _run_tool_use_generator: BEFORE tool.call() for '%s'", tool_name)
|
|
286
|
+
# Wrap tool execution with timeout to prevent hangs
|
|
287
|
+
try:
|
|
288
|
+
async with asyncio.timeout(DEFAULT_TOOL_TIMEOUT_SEC):
|
|
289
|
+
async for output in tool.call(parsed_input, tool_context):
|
|
290
|
+
logger.debug(
|
|
291
|
+
"[query] _run_tool_use_generator: tool='%s' yielded output type=%s",
|
|
292
|
+
tool_name,
|
|
293
|
+
type(output).__name__,
|
|
294
|
+
)
|
|
295
|
+
if isinstance(output, ToolProgress):
|
|
296
|
+
yield create_progress_message(
|
|
297
|
+
tool_use_id=tool_use_id,
|
|
298
|
+
sibling_tool_use_ids=sibling_ids,
|
|
299
|
+
content=output.content,
|
|
300
|
+
is_subagent_message=getattr(output, 'is_subagent_message', False),
|
|
301
|
+
)
|
|
302
|
+
logger.debug(
|
|
303
|
+
f"[query] Progress from tool_use_id={tool_use_id}: {output.content}"
|
|
304
|
+
)
|
|
305
|
+
elif isinstance(output, ToolResult):
|
|
306
|
+
tool_output = output.data
|
|
307
|
+
result_content = output.result_for_assistant or str(output.data)
|
|
308
|
+
result_msg = tool_result_message(
|
|
309
|
+
tool_use_id, result_content, tool_use_result=output.data
|
|
310
|
+
)
|
|
311
|
+
yield result_msg
|
|
312
|
+
logger.debug(
|
|
313
|
+
f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
|
|
314
|
+
f"result_len={len(result_content)}"
|
|
315
|
+
)
|
|
316
|
+
except asyncio.TimeoutError:
|
|
317
|
+
logger.error(
|
|
318
|
+
f"[query] Tool '{tool_name}' timed out after {DEFAULT_TOOL_TIMEOUT_SEC}s",
|
|
319
|
+
extra={"tool": tool_name, "tool_use_id": tool_use_id},
|
|
320
|
+
)
|
|
321
|
+
yield tool_result_message(
|
|
322
|
+
tool_use_id,
|
|
323
|
+
f"Tool '{tool_name}' timed out after {DEFAULT_TOOL_TIMEOUT_SEC:.0f} seconds",
|
|
324
|
+
is_error=True,
|
|
325
|
+
)
|
|
326
|
+
return # Exit early on timeout
|
|
327
|
+
logger.debug("[query] _run_tool_use_generator: AFTER tool.call() loop for '%s'", tool_name)
|
|
261
328
|
except CancelledError:
|
|
329
|
+
logger.debug("[query] _run_tool_use_generator: tool='%s' CANCELLED", tool_name)
|
|
262
330
|
raise # Don't suppress task cancellation
|
|
263
331
|
except (RuntimeError, ValueError, TypeError, OSError, IOError, AttributeError, KeyError) as exc:
|
|
264
332
|
logger.warning(
|
|
@@ -271,9 +339,20 @@ async def _run_tool_use_generator(
|
|
|
271
339
|
yield tool_result_message(tool_use_id, f"Error executing tool: {str(exc)}", is_error=True)
|
|
272
340
|
|
|
273
341
|
# Run PostToolUse hooks
|
|
274
|
-
await hook_manager.run_post_tool_use_async(
|
|
342
|
+
post_result = await hook_manager.run_post_tool_use_async(
|
|
275
343
|
tool_name, tool_input_dict, tool_response=tool_output, tool_use_id=tool_use_id
|
|
276
344
|
)
|
|
345
|
+
if post_result.additional_context:
|
|
346
|
+
_append_hook_context(context, f"PostToolUse:{tool_name}", post_result.additional_context)
|
|
347
|
+
if post_result.system_message:
|
|
348
|
+
_append_hook_context(context, f"PostToolUse:{tool_name}:system", post_result.system_message)
|
|
349
|
+
if post_result.should_block:
|
|
350
|
+
reason = post_result.block_reason or post_result.stop_reason or "Blocked by hook."
|
|
351
|
+
yield create_user_message(f"PostToolUse hook blocked: {reason}")
|
|
352
|
+
|
|
353
|
+
logger.debug(
|
|
354
|
+
"[query] _run_tool_use_generator DONE: tool='%s' tool_use_id=%s", tool_name, tool_use_id
|
|
355
|
+
)
|
|
277
356
|
|
|
278
357
|
|
|
279
358
|
def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
@@ -306,9 +385,18 @@ async def _execute_tools_in_parallel(
|
|
|
306
385
|
items: List[Dict[str, Any]], tool_results: List[UserMessage]
|
|
307
386
|
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
|
|
308
387
|
"""Run tool generators concurrently."""
|
|
309
|
-
|
|
310
|
-
|
|
388
|
+
logger.debug("[query] _execute_tools_in_parallel ENTER: %d items", len(items))
|
|
389
|
+
valid_items = [call for call in items if call.get("generator")]
|
|
390
|
+
generators = [call["generator"] for call in valid_items]
|
|
391
|
+
tool_names = [call.get("tool_name", "unknown") for call in valid_items]
|
|
392
|
+
logger.debug(
|
|
393
|
+
"[query] _execute_tools_in_parallel: %d valid generators, tools=%s",
|
|
394
|
+
len(generators),
|
|
395
|
+
tool_names,
|
|
396
|
+
)
|
|
397
|
+
async for message in _run_concurrent_tool_uses(generators, tool_names, tool_results):
|
|
311
398
|
yield message
|
|
399
|
+
logger.debug("[query] _execute_tools_in_parallel DONE")
|
|
312
400
|
|
|
313
401
|
|
|
314
402
|
async def _run_tools_concurrently(
|
|
@@ -340,45 +428,164 @@ async def _run_tools_serially(
|
|
|
340
428
|
|
|
341
429
|
async def _run_concurrent_tool_uses(
|
|
342
430
|
generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
|
|
431
|
+
tool_names: List[str],
|
|
343
432
|
tool_results: List[UserMessage],
|
|
344
433
|
) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
|
|
345
|
-
"""Drain multiple tool generators concurrently and stream outputs."""
|
|
434
|
+
"""Drain multiple tool generators concurrently and stream outputs with overall timeout."""
|
|
435
|
+
logger.debug(
|
|
436
|
+
"[query] _run_concurrent_tool_uses ENTER: %d generators, tools=%s, timeout=%s",
|
|
437
|
+
len(generators),
|
|
438
|
+
tool_names,
|
|
439
|
+
DEFAULT_CONCURRENT_TOOL_TIMEOUT_SEC,
|
|
440
|
+
)
|
|
346
441
|
if not generators:
|
|
442
|
+
logger.debug("[query] _run_concurrent_tool_uses: no generators, returning")
|
|
347
443
|
return
|
|
444
|
+
yield # Make this a proper async generator that yields nothing (unreachable but required)
|
|
348
445
|
|
|
349
446
|
queue: asyncio.Queue[Optional[Union[UserMessage, ProgressMessage]]] = asyncio.Queue()
|
|
350
447
|
|
|
351
|
-
async def _consume(
|
|
448
|
+
async def _consume(
|
|
449
|
+
gen: AsyncGenerator[Union[UserMessage, ProgressMessage], None],
|
|
450
|
+
gen_index: int,
|
|
451
|
+
tool_name: str,
|
|
452
|
+
) -> Optional[Exception]:
|
|
453
|
+
"""Consume a tool generator and return any exception that occurred."""
|
|
454
|
+
logger.debug(
|
|
455
|
+
"[query] _consume START: tool='%s' index=%d gen=%s",
|
|
456
|
+
tool_name,
|
|
457
|
+
gen_index,
|
|
458
|
+
type(gen).__name__,
|
|
459
|
+
)
|
|
460
|
+
captured_exception: Optional[Exception] = None
|
|
461
|
+
message_count = 0
|
|
352
462
|
try:
|
|
463
|
+
logger.debug("[query] _consume: entering async for loop for '%s'", tool_name)
|
|
353
464
|
async for message in gen:
|
|
465
|
+
message_count += 1
|
|
466
|
+
msg_type = type(message).__name__
|
|
467
|
+
logger.debug(
|
|
468
|
+
"[query] _consume: tool='%s' received message #%d type=%s",
|
|
469
|
+
tool_name,
|
|
470
|
+
message_count,
|
|
471
|
+
msg_type,
|
|
472
|
+
)
|
|
354
473
|
await queue.put(message)
|
|
474
|
+
logger.debug("[query] _consume: tool='%s' put message to queue", tool_name)
|
|
475
|
+
logger.debug(
|
|
476
|
+
"[query] _consume: tool='%s' async for loop finished, total messages=%d",
|
|
477
|
+
tool_name,
|
|
478
|
+
message_count,
|
|
479
|
+
)
|
|
355
480
|
except asyncio.CancelledError:
|
|
481
|
+
logger.debug("[query] _consume: tool='%s' was CANCELLED", tool_name)
|
|
356
482
|
raise # Don't suppress cancellation
|
|
357
483
|
except (StopAsyncIteration, GeneratorExit):
|
|
484
|
+
logger.debug("[query] _consume: tool='%s' StopAsyncIteration/GeneratorExit", tool_name)
|
|
358
485
|
pass # Normal generator termination
|
|
359
|
-
except
|
|
486
|
+
except Exception as exc:
|
|
487
|
+
# Capture exception for reporting to caller
|
|
488
|
+
captured_exception = exc
|
|
360
489
|
logger.warning(
|
|
361
|
-
"[query] Error while consuming tool
|
|
490
|
+
"[query] Error while consuming tool '%s' (task %d): %s: %s",
|
|
491
|
+
tool_name,
|
|
492
|
+
gen_index,
|
|
362
493
|
type(exc).__name__,
|
|
363
494
|
exc,
|
|
364
495
|
)
|
|
365
496
|
finally:
|
|
497
|
+
logger.debug("[query] _consume FINALLY: tool='%s' putting None to queue", tool_name)
|
|
366
498
|
await queue.put(None)
|
|
499
|
+
logger.debug("[query] _consume DONE: tool='%s' messages=%d", tool_name, message_count)
|
|
500
|
+
return captured_exception
|
|
367
501
|
|
|
368
|
-
|
|
502
|
+
logger.debug("[query] _run_concurrent_tool_uses: creating %d tasks", len(generators))
|
|
503
|
+
tasks = [
|
|
504
|
+
asyncio.create_task(_consume(gen, i, tool_names[i])) for i, gen in enumerate(generators)
|
|
505
|
+
]
|
|
369
506
|
active = len(tasks)
|
|
507
|
+
logger.debug("[query] _run_concurrent_tool_uses: %d tasks created, entering while loop", active)
|
|
370
508
|
|
|
371
509
|
try:
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
510
|
+
# Add overall timeout for entire concurrent execution
|
|
511
|
+
async with asyncio.timeout(DEFAULT_CONCURRENT_TOOL_TIMEOUT_SEC):
|
|
512
|
+
while active:
|
|
513
|
+
logger.debug(
|
|
514
|
+
"[query] _run_concurrent_tool_uses: waiting for queue.get(), active=%d", active
|
|
515
|
+
)
|
|
516
|
+
try:
|
|
517
|
+
message = await asyncio.wait_for(
|
|
518
|
+
queue.get(), timeout=DEFAULT_CONCURRENT_TOOL_TIMEOUT_SEC
|
|
519
|
+
)
|
|
520
|
+
except asyncio.TimeoutError:
|
|
521
|
+
logger.error(
|
|
522
|
+
"[query] Concurrent tool execution timed out waiting for messages"
|
|
523
|
+
)
|
|
524
|
+
# Cancel all remaining tasks
|
|
525
|
+
for task in tasks:
|
|
526
|
+
if not task.done():
|
|
527
|
+
task.cancel()
|
|
528
|
+
raise
|
|
529
|
+
|
|
530
|
+
logger.debug(
|
|
531
|
+
"[query] _run_concurrent_tool_uses: got message type=%s, active=%d",
|
|
532
|
+
type(message).__name__ if message else "None",
|
|
533
|
+
active,
|
|
534
|
+
)
|
|
535
|
+
if message is None:
|
|
536
|
+
active -= 1
|
|
537
|
+
logger.debug(
|
|
538
|
+
"[query] _run_concurrent_tool_uses: None received, active now=%d", active
|
|
539
|
+
)
|
|
540
|
+
continue
|
|
541
|
+
if isinstance(message, UserMessage):
|
|
542
|
+
tool_results.append(message)
|
|
543
|
+
yield message
|
|
544
|
+
logger.debug("[query] _run_concurrent_tool_uses: while loop finished, all tools done")
|
|
545
|
+
except asyncio.TimeoutError:
|
|
546
|
+
logger.error(
|
|
547
|
+
f"[query] Concurrent tool execution timed out after {DEFAULT_CONCURRENT_TOOL_TIMEOUT_SEC}s",
|
|
548
|
+
extra={"tool_names": tool_names},
|
|
549
|
+
)
|
|
550
|
+
# Ensure all tasks are cancelled
|
|
551
|
+
for task in tasks:
|
|
552
|
+
if not task.done():
|
|
553
|
+
task.cancel()
|
|
554
|
+
raise
|
|
380
555
|
finally:
|
|
381
|
-
|
|
556
|
+
# Wait for all tasks and collect any exceptions
|
|
557
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
558
|
+
exceptions_found: List[tuple[int, str, BaseException]] = []
|
|
559
|
+
for i, result in enumerate(results):
|
|
560
|
+
if isinstance(result, asyncio.CancelledError):
|
|
561
|
+
continue
|
|
562
|
+
elif isinstance(result, Exception):
|
|
563
|
+
# Exception from gather itself (shouldn't happen with return_exceptions=True)
|
|
564
|
+
exceptions_found.append((i, tool_names[i], result))
|
|
565
|
+
elif result is not None:
|
|
566
|
+
# Exception returned by _consume
|
|
567
|
+
exceptions_found.append((i, tool_names[i], result))
|
|
568
|
+
|
|
569
|
+
# Log all exceptions for debugging
|
|
570
|
+
for i, name, exc in exceptions_found:
|
|
571
|
+
logger.warning(
|
|
572
|
+
"[query] Concurrent tool '%s' (task %d) failed: %s: %s",
|
|
573
|
+
name,
|
|
574
|
+
i,
|
|
575
|
+
type(exc).__name__,
|
|
576
|
+
exc,
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
# Re-raise first exception if any occurred, so caller knows something failed
|
|
580
|
+
if exceptions_found:
|
|
581
|
+
first_name = exceptions_found[0][1]
|
|
582
|
+
first_exc = exceptions_found[0][2]
|
|
583
|
+
logger.error(
|
|
584
|
+
"[query] %d tool(s) failed during concurrent execution, first error in '%s': %s",
|
|
585
|
+
len(exceptions_found),
|
|
586
|
+
first_name,
|
|
587
|
+
first_exc,
|
|
588
|
+
)
|
|
382
589
|
|
|
383
590
|
|
|
384
591
|
class ToolRegistry:
|
|
@@ -451,6 +658,9 @@ class ToolRegistry:
|
|
|
451
658
|
"""Activate deferred tools by name."""
|
|
452
659
|
activated: List[str] = []
|
|
453
660
|
missing: List[str] = []
|
|
661
|
+
|
|
662
|
+
# First pass: collect tools to activate (no mutations)
|
|
663
|
+
to_activate: List[str] = []
|
|
454
664
|
for raw_name in names:
|
|
455
665
|
name = (raw_name or "").strip()
|
|
456
666
|
if not name:
|
|
@@ -459,12 +669,17 @@ class ToolRegistry:
|
|
|
459
669
|
continue
|
|
460
670
|
tool = self._tool_map.get(name)
|
|
461
671
|
if tool:
|
|
462
|
-
|
|
463
|
-
self._active_set.add(name)
|
|
464
|
-
self._deferred.discard(name)
|
|
465
|
-
activated.append(name)
|
|
672
|
+
to_activate.append(name)
|
|
466
673
|
else:
|
|
467
674
|
missing.append(name)
|
|
675
|
+
|
|
676
|
+
# Second pass: atomically update all data structures
|
|
677
|
+
if to_activate:
|
|
678
|
+
self._active.extend(to_activate)
|
|
679
|
+
self._active_set.update(to_activate)
|
|
680
|
+
self._deferred.difference_update(to_activate)
|
|
681
|
+
activated.extend(to_activate)
|
|
682
|
+
|
|
468
683
|
return activated, missing
|
|
469
684
|
|
|
470
685
|
def iter_named_tools(self) -> Iterable[tuple[str, Tool[Any, Any]]]:
|
|
@@ -537,6 +752,12 @@ class QueryContext:
|
|
|
537
752
|
verbose: bool = False,
|
|
538
753
|
pause_ui: Optional[Callable[[], None]] = None,
|
|
539
754
|
resume_ui: Optional[Callable[[], None]] = None,
|
|
755
|
+
stop_hook: str = "stop",
|
|
756
|
+
file_cache_max_entries: int = 500,
|
|
757
|
+
file_cache_max_memory_mb: float = 50.0,
|
|
758
|
+
pending_message_queue: Optional[PendingMessageQueue] = None,
|
|
759
|
+
max_turns: Optional[int] = None,
|
|
760
|
+
permission_mode: str = "default",
|
|
540
761
|
) -> None:
|
|
541
762
|
self.tool_registry = ToolRegistry(tools)
|
|
542
763
|
self.max_thinking_tokens = max_thinking_tokens
|
|
@@ -544,9 +765,20 @@ class QueryContext:
|
|
|
544
765
|
self.model = model
|
|
545
766
|
self.verbose = verbose
|
|
546
767
|
self.abort_controller = asyncio.Event()
|
|
547
|
-
self.
|
|
768
|
+
self.pending_message_queue: PendingMessageQueue = (
|
|
769
|
+
pending_message_queue if pending_message_queue is not None else PendingMessageQueue()
|
|
770
|
+
)
|
|
771
|
+
# Use BoundedFileCache instead of plain Dict to prevent unbounded growth
|
|
772
|
+
self.file_state_cache: BoundedFileCache = BoundedFileCache(
|
|
773
|
+
max_entries=file_cache_max_entries,
|
|
774
|
+
max_memory_mb=file_cache_max_memory_mb,
|
|
775
|
+
)
|
|
548
776
|
self.pause_ui = pause_ui
|
|
549
777
|
self.resume_ui = resume_ui
|
|
778
|
+
self.stop_hook = stop_hook
|
|
779
|
+
self.stop_hook_active = False
|
|
780
|
+
self.max_turns = max_turns
|
|
781
|
+
self.permission_mode = permission_mode
|
|
550
782
|
|
|
551
783
|
@property
|
|
552
784
|
def tools(self) -> List[Tool[Any, Any]]:
|
|
@@ -566,6 +798,22 @@ class QueryContext:
|
|
|
566
798
|
"""Return all known tools (active + deferred)."""
|
|
567
799
|
return self.tool_registry.all_tools
|
|
568
800
|
|
|
801
|
+
def get_memory_stats(self) -> Dict[str, Any]:
|
|
802
|
+
"""Return memory usage statistics for monitoring."""
|
|
803
|
+
return {
|
|
804
|
+
"file_cache": self.file_state_cache.stats(),
|
|
805
|
+
"tool_count": len(self.tool_registry.all_tools),
|
|
806
|
+
"active_tool_count": len(self.tool_registry.active_tools),
|
|
807
|
+
}
|
|
808
|
+
|
|
809
|
+
def drain_pending_messages(self) -> List[UserMessage]:
|
|
810
|
+
"""Drain queued messages waiting to be injected into the conversation."""
|
|
811
|
+
return self.pending_message_queue.drain()
|
|
812
|
+
|
|
813
|
+
def enqueue_user_message(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> None:
|
|
814
|
+
"""Queue a user-style message to inject once the current loop finishes."""
|
|
815
|
+
self.pending_message_queue.enqueue_text(text, metadata=metadata)
|
|
816
|
+
|
|
569
817
|
|
|
570
818
|
async def query_llm(
|
|
571
819
|
messages: List[Union[UserMessage, AssistantMessage, ProgressMessage]],
|
|
@@ -598,7 +846,6 @@ async def query_llm(
|
|
|
598
846
|
AssistantMessage with the model's response
|
|
599
847
|
"""
|
|
600
848
|
request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
|
|
601
|
-
request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
|
|
602
849
|
model_profile = resolve_model_profile(model)
|
|
603
850
|
|
|
604
851
|
# Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
|
|
@@ -657,15 +904,29 @@ async def query_llm(
|
|
|
657
904
|
start_time = time.time()
|
|
658
905
|
|
|
659
906
|
try:
|
|
660
|
-
|
|
907
|
+
try:
|
|
908
|
+
client: Optional[ProviderClient] = get_provider_client(model_profile.provider)
|
|
909
|
+
except RuntimeError as exc:
|
|
910
|
+
duration_ms = (time.time() - start_time) * 1000
|
|
911
|
+
error_msg = create_assistant_message(
|
|
912
|
+
content=str(exc),
|
|
913
|
+
duration_ms=duration_ms,
|
|
914
|
+
model=model_profile.model,
|
|
915
|
+
)
|
|
916
|
+
error_msg.is_api_error_message = True
|
|
917
|
+
return error_msg
|
|
661
918
|
if client is None:
|
|
662
919
|
duration_ms = (time.time() - start_time) * 1000
|
|
920
|
+
provider_label = getattr(model_profile.provider, "value", None) or str(
|
|
921
|
+
model_profile.provider
|
|
922
|
+
)
|
|
663
923
|
error_msg = create_assistant_message(
|
|
664
924
|
content=(
|
|
665
|
-
"
|
|
666
|
-
"
|
|
925
|
+
f"No provider client available for '{provider_label}'. "
|
|
926
|
+
"Check your model configuration and provider dependencies."
|
|
667
927
|
),
|
|
668
928
|
duration_ms=duration_ms,
|
|
929
|
+
model=model_profile.model,
|
|
669
930
|
)
|
|
670
931
|
error_msg.is_api_error_message = True
|
|
671
932
|
return error_msg
|
|
@@ -706,6 +967,7 @@ async def query_llm(
|
|
|
706
967
|
content=provider_response.content_blocks,
|
|
707
968
|
duration_ms=provider_response.duration_ms,
|
|
708
969
|
metadata=metadata,
|
|
970
|
+
model=model_profile.model,
|
|
709
971
|
)
|
|
710
972
|
error_msg.is_api_error_message = True
|
|
711
973
|
return error_msg
|
|
@@ -715,6 +977,13 @@ async def query_llm(
|
|
|
715
977
|
cost_usd=provider_response.cost_usd,
|
|
716
978
|
duration_ms=provider_response.duration_ms,
|
|
717
979
|
metadata=provider_response.metadata,
|
|
980
|
+
model=model_profile.model,
|
|
981
|
+
input_tokens=provider_response.usage_tokens.get("input_tokens", 0),
|
|
982
|
+
output_tokens=provider_response.usage_tokens.get("output_tokens", 0),
|
|
983
|
+
cache_read_tokens=provider_response.usage_tokens.get("cache_read_input_tokens", 0),
|
|
984
|
+
cache_creation_tokens=provider_response.usage_tokens.get(
|
|
985
|
+
"cache_creation_input_tokens", 0
|
|
986
|
+
),
|
|
718
987
|
)
|
|
719
988
|
|
|
720
989
|
except CancelledError:
|
|
@@ -756,7 +1025,10 @@ async def query_llm(
|
|
|
756
1025
|
)
|
|
757
1026
|
|
|
758
1027
|
error_msg = create_assistant_message(
|
|
759
|
-
content=content,
|
|
1028
|
+
content=content,
|
|
1029
|
+
duration_ms=duration_ms,
|
|
1030
|
+
metadata=error_metadata,
|
|
1031
|
+
model=model_profile.model,
|
|
760
1032
|
)
|
|
761
1033
|
error_msg.is_api_error_message = True
|
|
762
1034
|
return error_msg
|
|
@@ -806,7 +1078,7 @@ async def _run_query_iteration(
|
|
|
806
1078
|
Yields:
|
|
807
1079
|
Messages (progress, assistant, tool results) as they are generated
|
|
808
1080
|
"""
|
|
809
|
-
logger.
|
|
1081
|
+
logger.info(f"[query] Starting iteration {iteration}/{MAX_QUERY_ITERATIONS}")
|
|
810
1082
|
|
|
811
1083
|
# Check for file changes at the start of each iteration
|
|
812
1084
|
change_notices = detect_changed_files(query_context.file_state_cache)
|
|
@@ -830,21 +1102,25 @@ async def _run_query_iteration(
|
|
|
830
1102
|
)
|
|
831
1103
|
|
|
832
1104
|
# Stream LLM response
|
|
833
|
-
progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue()
|
|
1105
|
+
progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue(maxsize=1000)
|
|
834
1106
|
|
|
835
1107
|
async def _stream_progress(chunk: str) -> None:
|
|
836
1108
|
if not chunk:
|
|
837
1109
|
return
|
|
838
1110
|
try:
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
content=chunk,
|
|
844
|
-
)
|
|
1111
|
+
msg = create_progress_message(
|
|
1112
|
+
tool_use_id="stream",
|
|
1113
|
+
sibling_tool_use_ids=set(),
|
|
1114
|
+
content=chunk,
|
|
845
1115
|
)
|
|
846
|
-
|
|
847
|
-
|
|
1116
|
+
try:
|
|
1117
|
+
progress_queue.put_nowait(msg)
|
|
1118
|
+
except asyncio.QueueFull:
|
|
1119
|
+
# Queue full - wait with timeout instead of dropping immediately
|
|
1120
|
+
try:
|
|
1121
|
+
await asyncio.wait_for(progress_queue.put(msg), timeout=0.5)
|
|
1122
|
+
except asyncio.TimeoutError:
|
|
1123
|
+
logger.warning("[query] Progress queue full after timeout, dropping chunk")
|
|
848
1124
|
except (RuntimeError, ValueError) as exc:
|
|
849
1125
|
logger.warning("[query] Failed to enqueue stream progress chunk: %s", exc)
|
|
850
1126
|
|
|
@@ -863,6 +1139,8 @@ async def _run_query_iteration(
|
|
|
863
1139
|
)
|
|
864
1140
|
)
|
|
865
1141
|
|
|
1142
|
+
logger.debug("[query] Created query_llm task, waiting for response...")
|
|
1143
|
+
|
|
866
1144
|
assistant_message: Optional[AssistantMessage] = None
|
|
867
1145
|
|
|
868
1146
|
# Wait for LLM response while yielding progress
|
|
@@ -873,7 +1151,7 @@ async def _run_query_iteration(
|
|
|
873
1151
|
await assistant_task
|
|
874
1152
|
except CancelledError:
|
|
875
1153
|
pass
|
|
876
|
-
yield create_assistant_message(INTERRUPT_MESSAGE)
|
|
1154
|
+
yield create_assistant_message(INTERRUPT_MESSAGE, model=model_profile.model)
|
|
877
1155
|
result.should_stop = True
|
|
878
1156
|
return
|
|
879
1157
|
if assistant_task.done():
|
|
@@ -883,23 +1161,23 @@ async def _run_query_iteration(
|
|
|
883
1161
|
progress = progress_queue.get_nowait()
|
|
884
1162
|
except asyncio.QueueEmpty:
|
|
885
1163
|
waiter = asyncio.create_task(progress_queue.get())
|
|
886
|
-
|
|
1164
|
+
abort_waiter = asyncio.create_task(query_context.abort_controller.wait())
|
|
887
1165
|
done, pending = await asyncio.wait(
|
|
888
|
-
{assistant_task, waiter},
|
|
1166
|
+
{assistant_task, waiter, abort_waiter},
|
|
889
1167
|
return_when=asyncio.FIRST_COMPLETED,
|
|
890
|
-
timeout=0.1, # Check abort_controller every 100ms
|
|
891
1168
|
)
|
|
892
|
-
|
|
893
|
-
#
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
1169
|
+
for task in pending:
|
|
1170
|
+
# Don't cancel assistant_task here - it should only be cancelled
|
|
1171
|
+
# through abort_controller in the main loop
|
|
1172
|
+
if task is not assistant_task:
|
|
1173
|
+
task.cancel()
|
|
1174
|
+
try:
|
|
1175
|
+
await task
|
|
1176
|
+
except asyncio.CancelledError:
|
|
1177
|
+
pass
|
|
1178
|
+
if abort_waiter in done:
|
|
899
1179
|
continue
|
|
900
1180
|
if assistant_task in done:
|
|
901
|
-
for task in pending:
|
|
902
|
-
task.cancel()
|
|
903
1181
|
assistant_message = await assistant_task
|
|
904
1182
|
break
|
|
905
1183
|
progress = waiter.result()
|
|
@@ -912,12 +1190,13 @@ async def _run_query_iteration(
|
|
|
912
1190
|
if residual:
|
|
913
1191
|
yield residual
|
|
914
1192
|
|
|
915
|
-
|
|
1193
|
+
if assistant_message is None:
|
|
1194
|
+
raise RuntimeError("assistant_message was unexpectedly None after LLM query")
|
|
916
1195
|
result.assistant_message = assistant_message
|
|
917
1196
|
|
|
918
1197
|
# Check for abort
|
|
919
1198
|
if query_context.abort_controller.is_set():
|
|
920
|
-
yield create_assistant_message(INTERRUPT_MESSAGE)
|
|
1199
|
+
yield create_assistant_message(INTERRUPT_MESSAGE, model=model_profile.model)
|
|
921
1200
|
result.should_stop = True
|
|
922
1201
|
return
|
|
923
1202
|
|
|
@@ -936,7 +1215,39 @@ async def _run_query_iteration(
|
|
|
936
1215
|
)
|
|
937
1216
|
|
|
938
1217
|
if not tool_use_blocks:
|
|
939
|
-
logger.debug(
|
|
1218
|
+
logger.debug(
|
|
1219
|
+
"[query] No tool_use blocks; running stop hook and returning response to user."
|
|
1220
|
+
)
|
|
1221
|
+
stop_hook = query_context.stop_hook
|
|
1222
|
+
logger.debug(
|
|
1223
|
+
f"[query] stop_hook={stop_hook}, stop_hook_active={query_context.stop_hook_active}"
|
|
1224
|
+
)
|
|
1225
|
+
logger.debug("[query] BEFORE calling hook_manager.run_stop_async")
|
|
1226
|
+
stop_result = (
|
|
1227
|
+
await hook_manager.run_subagent_stop_async(
|
|
1228
|
+
stop_hook_active=query_context.stop_hook_active
|
|
1229
|
+
)
|
|
1230
|
+
if stop_hook == "subagent"
|
|
1231
|
+
else await hook_manager.run_stop_async(stop_hook_active=query_context.stop_hook_active)
|
|
1232
|
+
)
|
|
1233
|
+
logger.debug("[query] AFTER calling hook_manager.run_stop_async")
|
|
1234
|
+
logger.debug("[query] Checking additional_context")
|
|
1235
|
+
if stop_result.additional_context:
|
|
1236
|
+
_append_hook_context(context, f"{stop_hook}:context", stop_result.additional_context)
|
|
1237
|
+
logger.debug("[query] Checking system_message")
|
|
1238
|
+
if stop_result.system_message:
|
|
1239
|
+
_append_hook_context(context, f"{stop_hook}:system", stop_result.system_message)
|
|
1240
|
+
logger.debug("[query] Checking should_block")
|
|
1241
|
+
if stop_result.should_block:
|
|
1242
|
+
reason = stop_result.block_reason or stop_result.stop_reason or "Blocked by hook."
|
|
1243
|
+
result.tool_results = [create_user_message(f"{stop_hook} hook blocked: {reason}")]
|
|
1244
|
+
for msg in result.tool_results:
|
|
1245
|
+
yield msg
|
|
1246
|
+
query_context.stop_hook_active = True
|
|
1247
|
+
result.should_stop = False
|
|
1248
|
+
return
|
|
1249
|
+
logger.debug("[query] Setting should_stop=True and returning")
|
|
1250
|
+
query_context.stop_hook_active = False
|
|
940
1251
|
result.should_stop = True
|
|
941
1252
|
return
|
|
942
1253
|
|
|
@@ -956,13 +1267,25 @@ async def _run_query_iteration(
|
|
|
956
1267
|
tool_use_id = getattr(tool_use, "tool_use_id", None) or getattr(tool_use, "id", None) or ""
|
|
957
1268
|
tool_input = getattr(tool_use, "input", {}) or {}
|
|
958
1269
|
|
|
1270
|
+
# Handle case where input is a Pydantic model instead of a dict
|
|
1271
|
+
# This can happen when the API response contains structured tool input objects
|
|
1272
|
+
# Always try to convert if it has model_dump or dict methods
|
|
1273
|
+
if tool_input and hasattr(tool_input, "model_dump"):
|
|
1274
|
+
tool_input = tool_input.model_dump()
|
|
1275
|
+
elif tool_input and hasattr(tool_input, "dict") and callable(getattr(tool_input, "dict")):
|
|
1276
|
+
tool_input = tool_input.dict()
|
|
1277
|
+
elif tool_input and not isinstance(tool_input, dict):
|
|
1278
|
+
# Last resort: convert unknown type to string representation
|
|
1279
|
+
tool_input = {"value": str(tool_input)}
|
|
1280
|
+
|
|
959
1281
|
tool, missing_msg = _resolve_tool(query_context.tool_registry, tool_name, tool_use_id)
|
|
960
1282
|
if missing_msg:
|
|
961
1283
|
logger.warning(f"[query] Tool '{tool_name}' not found for tool_use_id={tool_use_id}")
|
|
962
1284
|
tool_results.append(missing_msg)
|
|
963
1285
|
yield missing_msg
|
|
964
1286
|
continue
|
|
965
|
-
|
|
1287
|
+
if tool is None:
|
|
1288
|
+
raise RuntimeError(f"Tool '{tool_name}' resolved to None unexpectedly")
|
|
966
1289
|
|
|
967
1290
|
try:
|
|
968
1291
|
parsed_input = tool.input_schema(**tool_input)
|
|
@@ -972,14 +1295,17 @@ async def _run_query_iteration(
|
|
|
972
1295
|
)
|
|
973
1296
|
|
|
974
1297
|
tool_context = ToolUseContext(
|
|
1298
|
+
message_id=tool_use_id, # Set message_id for parent_tool_use_id tracking
|
|
975
1299
|
yolo_mode=query_context.yolo_mode,
|
|
976
1300
|
verbose=query_context.verbose,
|
|
977
1301
|
permission_checker=can_use_tool_fn,
|
|
978
1302
|
tool_registry=query_context.tool_registry,
|
|
979
1303
|
file_state_cache=query_context.file_state_cache,
|
|
1304
|
+
conversation_messages=messages,
|
|
980
1305
|
abort_signal=query_context.abort_controller,
|
|
981
1306
|
pause_ui=query_context.pause_ui,
|
|
982
1307
|
resume_ui=query_context.resume_ui,
|
|
1308
|
+
pending_message_queue=query_context.pending_message_queue,
|
|
983
1309
|
)
|
|
984
1310
|
|
|
985
1311
|
validation = await tool.validate_input(parsed_input, tool_context)
|
|
@@ -997,7 +1323,7 @@ async def _run_query_iteration(
|
|
|
997
1323
|
continue
|
|
998
1324
|
|
|
999
1325
|
if not query_context.yolo_mode or can_use_tool_fn is not None:
|
|
1000
|
-
allowed, denial_message = await _check_tool_permissions(
|
|
1326
|
+
allowed, denial_message, updated_input = await _check_tool_permissions(
|
|
1001
1327
|
tool, parsed_input, query_context, can_use_tool_fn
|
|
1002
1328
|
)
|
|
1003
1329
|
if not allowed:
|
|
@@ -1010,9 +1336,39 @@ async def _run_query_iteration(
|
|
|
1010
1336
|
yield denial_msg
|
|
1011
1337
|
permission_denied = True
|
|
1012
1338
|
break
|
|
1339
|
+
if updated_input:
|
|
1340
|
+
try:
|
|
1341
|
+
# Ensure updated_input is a dict, not a Pydantic model
|
|
1342
|
+
normalized_input = updated_input
|
|
1343
|
+
if hasattr(normalized_input, "model_dump"):
|
|
1344
|
+
normalized_input = normalized_input.model_dump()
|
|
1345
|
+
elif not isinstance(normalized_input, dict):
|
|
1346
|
+
normalized_input = {"value": str(normalized_input)}
|
|
1347
|
+
parsed_input = tool.input_schema(**normalized_input)
|
|
1348
|
+
except ValidationError as ve:
|
|
1349
|
+
detail_text = format_pydantic_errors(ve)
|
|
1350
|
+
error_msg = tool_result_message(
|
|
1351
|
+
tool_use_id,
|
|
1352
|
+
f"Invalid permission-updated input for tool '{tool_name}': {detail_text}",
|
|
1353
|
+
is_error=True,
|
|
1354
|
+
)
|
|
1355
|
+
tool_results.append(error_msg)
|
|
1356
|
+
yield error_msg
|
|
1357
|
+
continue
|
|
1358
|
+
validation = await tool.validate_input(parsed_input, tool_context)
|
|
1359
|
+
if not validation.result:
|
|
1360
|
+
error_msg = tool_result_message(
|
|
1361
|
+
tool_use_id,
|
|
1362
|
+
validation.message or "Tool input validation failed.",
|
|
1363
|
+
is_error=True,
|
|
1364
|
+
)
|
|
1365
|
+
tool_results.append(error_msg)
|
|
1366
|
+
yield error_msg
|
|
1367
|
+
continue
|
|
1013
1368
|
|
|
1014
1369
|
prepared_calls.append(
|
|
1015
1370
|
{
|
|
1371
|
+
"tool_name": tool_name,
|
|
1016
1372
|
"is_concurrency_safe": tool.is_concurrency_safe(),
|
|
1017
1373
|
"generator": _run_tool_use_generator(
|
|
1018
1374
|
tool,
|
|
@@ -1021,6 +1377,7 @@ async def _run_query_iteration(
|
|
|
1021
1377
|
parsed_input,
|
|
1022
1378
|
sibling_ids,
|
|
1023
1379
|
tool_context,
|
|
1380
|
+
context,
|
|
1024
1381
|
),
|
|
1025
1382
|
}
|
|
1026
1383
|
)
|
|
@@ -1075,7 +1432,7 @@ async def _run_query_iteration(
|
|
|
1075
1432
|
|
|
1076
1433
|
# Check for abort after tools
|
|
1077
1434
|
if query_context.abort_controller.is_set():
|
|
1078
|
-
yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
|
|
1435
|
+
yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE, model=model_profile.model)
|
|
1079
1436
|
result.tool_results = tool_results
|
|
1080
1437
|
result.should_stop = True
|
|
1081
1438
|
return
|
|
@@ -1099,6 +1456,26 @@ async def query(
|
|
|
1099
1456
|
3. Executes tools
|
|
1100
1457
|
4. Continues the conversation in a loop until no more tool calls
|
|
1101
1458
|
|
|
1459
|
+
Args:
|
|
1460
|
+
messages: Conversation history
|
|
1461
|
+
system_prompt: Base system prompt
|
|
1462
|
+
context: Additional context dictionary
|
|
1463
|
+
query_context: Query configuration
|
|
1464
|
+
can_use_tool_fn: Optional function to check tool permissions
|
|
1465
|
+
|
|
1466
|
+
Yields:
|
|
1467
|
+
Messages (user, assistant, progress) as they are generated
|
|
1468
|
+
"""
|
|
1469
|
+
# Resolve model once for use in messages (e.g., max iterations, errors)
|
|
1470
|
+
model_profile = resolve_model_profile(query_context.model)
|
|
1471
|
+
"""Execute a query with tool support.
|
|
1472
|
+
|
|
1473
|
+
This is the main query loop that:
|
|
1474
|
+
1. Sends messages to the AI
|
|
1475
|
+
2. Handles tool use responses
|
|
1476
|
+
3. Executes tools
|
|
1477
|
+
4. Continues the conversation in a loop until no more tool calls
|
|
1478
|
+
|
|
1102
1479
|
Args:
|
|
1103
1480
|
messages: Conversation history
|
|
1104
1481
|
system_prompt: Base system prompt
|
|
@@ -1116,6 +1493,8 @@ async def query(
|
|
|
1116
1493
|
"tool_count": len(query_context.tools),
|
|
1117
1494
|
"yolo_mode": query_context.yolo_mode,
|
|
1118
1495
|
"model_pointer": query_context.model,
|
|
1496
|
+
"max_turns": query_context.max_turns,
|
|
1497
|
+
"permission_mode": query_context.permission_mode,
|
|
1119
1498
|
},
|
|
1120
1499
|
)
|
|
1121
1500
|
# Work on a copy so external mutations (e.g., UI appending messages while consuming)
|
|
@@ -1123,6 +1502,13 @@ async def query(
|
|
|
1123
1502
|
messages = list(messages)
|
|
1124
1503
|
|
|
1125
1504
|
for iteration in range(1, MAX_QUERY_ITERATIONS + 1):
|
|
1505
|
+
# Inject any pending messages queued by background events or user interjections
|
|
1506
|
+
pending_messages = query_context.drain_pending_messages()
|
|
1507
|
+
if pending_messages:
|
|
1508
|
+
messages.extend(pending_messages)
|
|
1509
|
+
for pending in pending_messages:
|
|
1510
|
+
yield pending
|
|
1511
|
+
|
|
1126
1512
|
result = IterationResult()
|
|
1127
1513
|
|
|
1128
1514
|
async for msg in _run_query_iteration(
|
|
@@ -1137,6 +1523,20 @@ async def query(
|
|
|
1137
1523
|
yield msg
|
|
1138
1524
|
|
|
1139
1525
|
if result.should_stop:
|
|
1526
|
+
# Before stopping, check if new pending messages arrived during this iteration.
|
|
1527
|
+
trailing_pending = query_context.drain_pending_messages()
|
|
1528
|
+
if trailing_pending:
|
|
1529
|
+
# type: ignore[operator,list-item]
|
|
1530
|
+
next_messages = (
|
|
1531
|
+
messages + [result.assistant_message] + result.tool_results
|
|
1532
|
+
if result.assistant_message is not None
|
|
1533
|
+
else messages + result.tool_results # type: ignore[operator]
|
|
1534
|
+
) # type: ignore[operator]
|
|
1535
|
+
next_messages = next_messages + trailing_pending # type: ignore[operator,list-item]
|
|
1536
|
+
for pending in trailing_pending:
|
|
1537
|
+
yield pending
|
|
1538
|
+
messages = next_messages
|
|
1539
|
+
continue
|
|
1140
1540
|
return
|
|
1141
1541
|
|
|
1142
1542
|
# Update messages for next iteration
|
|
@@ -1144,6 +1544,7 @@ async def query(
|
|
|
1144
1544
|
messages = messages + [result.assistant_message] + result.tool_results # type: ignore[operator]
|
|
1145
1545
|
else:
|
|
1146
1546
|
messages = messages + result.tool_results # type: ignore[operator]
|
|
1547
|
+
|
|
1147
1548
|
logger.debug(
|
|
1148
1549
|
f"[query] Continuing loop with {len(messages)} messages after tools; "
|
|
1149
1550
|
f"tool_results_count={len(result.tool_results)}"
|
|
@@ -1155,5 +1556,6 @@ async def query(
|
|
|
1155
1556
|
)
|
|
1156
1557
|
yield create_assistant_message(
|
|
1157
1558
|
f"Reached maximum query iterations ({MAX_QUERY_ITERATIONS}). "
|
|
1158
|
-
"Please continue the conversation to proceed."
|
|
1559
|
+
"Please continue the conversation to proceed.",
|
|
1560
|
+
model=model_profile.model,
|
|
1159
1561
|
)
|