ripperdoc 0.1.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. ripperdoc/__init__.py +1 -1
  2. ripperdoc/cli/cli.py +75 -15
  3. ripperdoc/cli/commands/__init__.py +4 -0
  4. ripperdoc/cli/commands/agents_cmd.py +23 -1
  5. ripperdoc/cli/commands/context_cmd.py +13 -3
  6. ripperdoc/cli/commands/cost_cmd.py +1 -1
  7. ripperdoc/cli/commands/doctor_cmd.py +200 -0
  8. ripperdoc/cli/commands/memory_cmd.py +209 -0
  9. ripperdoc/cli/commands/models_cmd.py +25 -0
  10. ripperdoc/cli/commands/resume_cmd.py +3 -3
  11. ripperdoc/cli/commands/status_cmd.py +5 -5
  12. ripperdoc/cli/commands/tasks_cmd.py +32 -5
  13. ripperdoc/cli/ui/context_display.py +4 -3
  14. ripperdoc/cli/ui/rich_ui.py +205 -43
  15. ripperdoc/cli/ui/spinner.py +3 -4
  16. ripperdoc/core/agents.py +10 -6
  17. ripperdoc/core/config.py +48 -3
  18. ripperdoc/core/default_tools.py +26 -6
  19. ripperdoc/core/permissions.py +19 -0
  20. ripperdoc/core/query.py +238 -302
  21. ripperdoc/core/query_utils.py +537 -0
  22. ripperdoc/core/system_prompt.py +2 -1
  23. ripperdoc/core/tool.py +14 -1
  24. ripperdoc/sdk/client.py +1 -1
  25. ripperdoc/tools/background_shell.py +9 -3
  26. ripperdoc/tools/bash_tool.py +19 -4
  27. ripperdoc/tools/file_edit_tool.py +9 -2
  28. ripperdoc/tools/file_read_tool.py +9 -2
  29. ripperdoc/tools/file_write_tool.py +15 -2
  30. ripperdoc/tools/glob_tool.py +57 -17
  31. ripperdoc/tools/grep_tool.py +9 -2
  32. ripperdoc/tools/ls_tool.py +244 -75
  33. ripperdoc/tools/mcp_tools.py +47 -19
  34. ripperdoc/tools/multi_edit_tool.py +13 -2
  35. ripperdoc/tools/notebook_edit_tool.py +9 -6
  36. ripperdoc/tools/task_tool.py +20 -5
  37. ripperdoc/tools/todo_tool.py +163 -29
  38. ripperdoc/tools/tool_search_tool.py +15 -4
  39. ripperdoc/utils/git_utils.py +276 -0
  40. ripperdoc/utils/json_utils.py +28 -0
  41. ripperdoc/utils/log.py +130 -29
  42. ripperdoc/utils/mcp.py +83 -10
  43. ripperdoc/utils/memory.py +14 -1
  44. ripperdoc/utils/message_compaction.py +51 -14
  45. ripperdoc/utils/messages.py +63 -4
  46. ripperdoc/utils/output_utils.py +36 -9
  47. ripperdoc/utils/permissions/path_validation_utils.py +6 -0
  48. ripperdoc/utils/safe_get_cwd.py +4 -0
  49. ripperdoc/utils/session_history.py +27 -9
  50. ripperdoc/utils/todo.py +2 -2
  51. {ripperdoc-0.1.0.dist-info → ripperdoc-0.2.2.dist-info}/METADATA +4 -2
  52. ripperdoc-0.2.2.dist-info/RECORD +86 -0
  53. ripperdoc-0.1.0.dist-info/RECORD +0 -81
  54. {ripperdoc-0.1.0.dist-info → ripperdoc-0.2.2.dist-info}/WHEEL +0 -0
  55. {ripperdoc-0.1.0.dist-info → ripperdoc-0.2.2.dist-info}/entry_points.txt +0 -0
  56. {ripperdoc-0.1.0.dist-info → ripperdoc-0.2.2.dist-info}/licenses/LICENSE +0 -0
  57. {ripperdoc-0.1.0.dist-info → ripperdoc-0.2.2.dist-info}/top_level.txt +0 -0
ripperdoc/core/query.py CHANGED
@@ -6,86 +6,100 @@ the query-response loop including tool execution.
  
  import asyncio
  import inspect
- from typing import AsyncGenerator, List, Optional, Dict, Any, Union, Iterable, Tuple
+ import time
+ from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union, cast
+
  from anthropic import AsyncAnthropic
  from openai import AsyncOpenAI
+ from pydantic import ValidationError
  
- from ripperdoc.core.tool import (
-     Tool,
-     ToolUseContext,
-     ToolResult,
-     ToolProgress,
-     build_tool_description,
-     tool_input_examples,
+ from ripperdoc.core.config import ProviderType, provider_protocol
+ from ripperdoc.core.permissions import PermissionResult
+ from ripperdoc.core.query_utils import (
+     anthropic_usage_tokens,
+     build_anthropic_tool_schemas,
+     build_full_system_prompt,
+     build_openai_tool_schemas,
+     content_blocks_from_anthropic_response,
+     content_blocks_from_openai_choice,
+     determine_tool_mode,
+     extract_tool_use_blocks,
+     format_pydantic_errors,
+     log_openai_messages,
+     openai_usage_tokens,
+     resolve_model_profile,
+     text_mode_history,
+     tool_result_message,
  )
+ from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
  from ripperdoc.utils.log import get_logger
  from ripperdoc.utils.messages import (
-     MessageContent,
-     UserMessage,
      AssistantMessage,
      ProgressMessage,
-     create_user_message,
+     UserMessage,
      create_assistant_message,
      create_progress_message,
      normalize_messages_for_api,
      INTERRUPT_MESSAGE,
      INTERRUPT_MESSAGE_FOR_TOOL_USE,
  )
- from ripperdoc.core.permissions import PermissionResult
- from ripperdoc.core.config import get_global_config, ProviderType, provider_protocol
  from ripperdoc.utils.session_usage import record_usage
  
- import time
-
  
  logger = get_logger()
  
  
- def _safe_int(value: Any) -> int:
-     """Best-effort int conversion for usage counters."""
-     try:
-         if value is None:
-             return 0
-         return int(value)
-     except (TypeError, ValueError):
-         return 0
-
-
- def _get_usage_field(usage: Any, field: str) -> int:
-     """Fetch a usage field from either a dict or object."""
-     if usage is None:
-         return 0
-     if isinstance(usage, dict):
-         return _safe_int(usage.get(field))
-     return _safe_int(getattr(usage, field, 0))
-
-
- def _anthropic_usage_tokens(usage: Any) -> Dict[str, int]:
-     """Extract token counts from an Anthropic response usage payload."""
-     return {
-         "input_tokens": _get_usage_field(usage, "input_tokens"),
-         "output_tokens": _get_usage_field(usage, "output_tokens"),
-         "cache_read_input_tokens": _get_usage_field(usage, "cache_read_input_tokens"),
-         "cache_creation_input_tokens": _get_usage_field(usage, "cache_creation_input_tokens"),
-     }
-
-
- def _openai_usage_tokens(usage: Any) -> Dict[str, int]:
-     """Extract token counts from an OpenAI-compatible response usage payload."""
-     prompt_details = None
-     if isinstance(usage, dict):
-         prompt_details = usage.get("prompt_tokens_details")
-     else:
-         prompt_details = getattr(usage, "prompt_tokens_details", None)
+ def _resolve_tool(
+     tool_registry: "ToolRegistry", tool_name: str, tool_use_id: str
+ ) -> tuple[Optional[Tool[Any, Any]], Optional[UserMessage]]:
+     """Find a tool by name and return an error message if missing."""
+     tool = tool_registry.get(tool_name)
+     if tool:
+         tool_registry.activate_tools([tool_name])
+         return tool, None
+     return None, tool_result_message(
+         tool_use_id, f"Error: Tool '{tool_name}' not found", is_error=True
+     )
  
-     cache_read_tokens = _get_usage_field(prompt_details, "cached_tokens") if prompt_details else 0
  
-     return {
-         "input_tokens": _get_usage_field(usage, "prompt_tokens"),
-         "output_tokens": _get_usage_field(usage, "completion_tokens"),
-         "cache_read_input_tokens": cache_read_tokens,
-         "cache_creation_input_tokens": 0,
-     }
+ async def _check_tool_permissions(
+     tool: Tool[Any, Any],
+     parsed_input: Any,
+     query_context: "QueryContext",
+     can_use_tool_fn: Optional[Any],
+ ) -> tuple[bool, Optional[str]]:
+     """Evaluate whether a tool call is allowed."""
+     try:
+         if can_use_tool_fn is not None:
+             decision = can_use_tool_fn(tool, parsed_input)
+             if inspect.isawaitable(decision):
+                 decision = await decision
+             if isinstance(decision, PermissionResult):
+                 return decision.result, decision.message
+             if isinstance(decision, dict) and "result" in decision:
+                 return bool(decision.get("result")), decision.get("message")
+             if isinstance(decision, tuple) and len(decision) == 2:
+                 return bool(decision[0]), decision[1]
+             return bool(decision), None
+
+         if query_context.safe_mode and tool.needs_permissions(parsed_input):
+             loop = asyncio.get_running_loop()
+             input_preview = (
+                 parsed_input.model_dump()
+                 if hasattr(parsed_input, "model_dump")
+                 else str(parsed_input)
+             )
+             prompt = f"Allow tool '{tool.name}' with input {input_preview}? [y/N]: "
+             response = await loop.run_in_executor(None, lambda: input(prompt))
+             return response.strip().lower() in ("y", "yes"), None
+
+         return True, None
+     except Exception:
+         logger.exception(
+             f"Error checking permissions for tool '{tool.name}'",
+             extra={"tool": getattr(tool, "name", None)},
+         )
+         return False, None
  
  
  class ToolRegistry:
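
Note: _resolve_tool above reports a missing tool through tool_result_message, which 0.2.2 imports from the new ripperdoc/core/query_utils.py instead of building create_user_message payloads inline. A minimal sketch of what that helper presumably looks like, reconstructed from the removed payload dicts and the call sites later in this diff (the real implementation may differ):

    from typing import Any, Optional

    from ripperdoc.utils.messages import UserMessage, create_user_message


    def tool_result_message(
        tool_use_id: str,
        text: str,
        is_error: bool = False,
        tool_use_result: Optional[Any] = None,
    ) -> UserMessage:
        """Wrap a tool result (or error) in the user-message envelope the query loop yields."""
        block: dict[str, Any] = {"type": "tool_result", "tool_use_id": tool_use_id, "text": text}
        if is_error:
            block["is_error"] = True
        return create_user_message([block], tool_use_result=tool_use_result)

Call sites in this file use it as tool_result_message(tool_use_id, text, is_error=True) for failures and tool_result_message(tool_use_id, content, tool_use_result=data) for successful results.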
@@ -118,6 +132,10 @@ class ToolRegistry:
          try:
              deferred = tool.defer_loading()
          except Exception:
+             logger.exception(
+                 "[tool_registry] Tool.defer_loading failed",
+                 extra={"tool": getattr(tool, "name", None)},
+             )
              deferred = False
          if deferred:
              self._deferred.add(name)
@@ -234,43 +252,38 @@ async def query_llm(
      Returns:
          AssistantMessage with the model's response
      """
-     config = get_global_config()
-
-     # Get the model profile
-     profile_name = getattr(config.model_pointers, model, None)
-     if profile_name is None:
-         profile_name = model
-
-     model_profile = config.model_profiles.get(profile_name)
-     if model_profile is None:
-         fallback_profile = getattr(config.model_pointers, "main", "default")
-         model_profile = config.model_profiles.get(fallback_profile)
-
-     if not model_profile:
-         raise ValueError(f"No model profile found for pointer: {model}")
+     model_profile = resolve_model_profile(model)
  
      # Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
      protocol = provider_protocol(model_profile.provider)
+     tool_mode = determine_tool_mode(model_profile)
+     messages_for_model: List[Union[UserMessage, AssistantMessage, ProgressMessage]]
+     if tool_mode == "text":
+         messages_for_model = cast(
+             List[Union[UserMessage, AssistantMessage, ProgressMessage]],
+             text_mode_history(messages),
+         )
+     else:
+         messages_for_model = messages
+
      normalized_messages = normalize_messages_for_api(
-         messages,
-         protocol=protocol,
+         messages_for_model, protocol=protocol, tool_mode=tool_mode
+     )
+     logger.info(
+         "[query_llm] Preparing model request",
+         extra={
+             "model_pointer": model,
+             "provider": getattr(model_profile.provider, "value", str(model_profile.provider)),
+             "model": model_profile.model,
+             "normalized_messages": len(normalized_messages),
+             "tool_count": len(tools),
+             "max_thinking_tokens": max_thinking_tokens,
+             "tool_mode": tool_mode,
+         },
      )
  
      if protocol == "openai":
-         summary_parts = []
-         for idx, m in enumerate(normalized_messages):
-             role = m.get("role")
-             tool_calls = m.get("tool_calls")
-             tc_ids = []
-             if tool_calls:
-                 tc_ids = [tc.get("id") for tc in tool_calls]
-             tool_call_id = m.get("tool_call_id")
-             summary_parts.append(
-                 f"{idx}:{role}"
-                 + (f" tool_calls={tc_ids}" if tc_ids else "")
-                 + (f" tool_call_id={tool_call_id}" if tool_call_id else "")
-             )
-         logger.debug(f"[query_llm] OpenAI normalized messages: {' | '.join(summary_parts)}")
+         log_openai_messages(normalized_messages)
  
      logger.debug(
          f"[query_llm] Sending {len(normalized_messages)} messages to model pointer "
@@ -285,81 +298,48 @@ async def query_llm(
          # Create the appropriate client based on provider
          if model_profile.provider == ProviderType.ANTHROPIC:
              async with AsyncAnthropic(api_key=model_profile.api_key) as client:
-                 # Build tool schemas
-                 tool_schemas = []
-                 for tool in tools:
-                     description = await build_tool_description(
-                         tool, include_examples=True, max_examples=2
-                     )
-                     tool_schema = {
-                         "name": tool.name,
-                         "description": description,
-                         "input_schema": tool.input_schema.model_json_schema(),
-                         "defer_loading": bool(getattr(tool, "defer_loading", lambda: False)()),
-                     }
-                     examples = tool_input_examples(tool, limit=5)
-                     if examples:
-                         tool_schema["input_examples"] = examples
-                     tool_schemas.append(tool_schema)
-
-                 # Make the API call
+                 tool_schemas = await build_anthropic_tool_schemas(tools)
                  response = await client.messages.create(
                      model=model_profile.model,
                      max_tokens=model_profile.max_tokens,
                      system=system_prompt,
-                     messages=normalized_messages,
+                     messages=normalized_messages,  # type: ignore[arg-type]
                      tools=tool_schemas if tool_schemas else None,  # type: ignore
                      temperature=model_profile.temperature,
                  )
  
                  duration_ms = (time.time() - start_time) * 1000
  
-                 usage_tokens = _anthropic_usage_tokens(getattr(response, "usage", None))
+                 usage_tokens = anthropic_usage_tokens(getattr(response, "usage", None))
                  record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
  
                  # Calculate cost (simplified, should use actual pricing)
                  cost_usd = 0.0  # TODO: Implement cost calculation
  
-                 # Convert response to our format
-                 content_blocks = []
-                 for block in response.content:
-                     if block.type == "text":
-                         content_blocks.append({"type": "text", "text": block.text})
-                     elif block.type == "tool_use":
-                         content_blocks.append(
-                             {
-                                 "type": "tool_use",
-                                 "tool_use_id": block.id,
-                                 "name": block.name,
-                                 "input": block.input,
-                             }
-                         )
+                 content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
+                 tool_use_blocks = [
+                     block for block in response.content if getattr(block, "type", None) == "tool_use"
+                 ]
+                 logger.info(
+                     "[query_llm] Received response from Anthropic",
+                     extra={
+                         "model": model_profile.model,
+                         "duration_ms": round(duration_ms, 2),
+                         "usage_tokens": usage_tokens,
+                         "tool_use_blocks": len(tool_use_blocks),
+                     },
+                 )
  
                  return create_assistant_message(
-                     content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
+                     content=content_blocks,
+                     cost_usd=cost_usd,
+                     duration_ms=duration_ms,
                  )
  
          elif model_profile.provider == ProviderType.OPENAI_COMPATIBLE:
              # OpenAI-compatible APIs (OpenAI, DeepSeek, Mistral, etc.)
-             async with AsyncOpenAI(
-                 api_key=model_profile.api_key, base_url=model_profile.api_base
-             ) as client:
-                 # Build tool schemas for OpenAI format
-                 openai_tools = []
-                 for tool in tools:
-                     description = await build_tool_description(
-                         tool, include_examples=True, max_examples=2
-                     )
-                     openai_tools.append(
-                         {
-                             "type": "function",
-                             "function": {
-                                 "name": tool.name,
-                                 "description": description,
-                                 "parameters": tool.input_schema.model_json_schema(),
-                             },
-                         }
-                     )
+             async with AsyncOpenAI(api_key=model_profile.api_key, base_url=model_profile.api_base) as client:
+                 openai_tools = await build_openai_tool_schemas(tools)
  
                  # Prepare messages for OpenAI format
                  openai_messages = [
@@ -367,38 +347,34 @@ async def query_llm(
                  ] + normalized_messages
  
                  # Make the API call
-                 response = await client.chat.completions.create(
+                 openai_response: Any = await client.chat.completions.create(
                      model=model_profile.model,
                      messages=openai_messages,
-                     tools=openai_tools if openai_tools else None,
+                     tools=openai_tools if openai_tools else None,  # type: ignore[arg-type]
                      temperature=model_profile.temperature,
                      max_tokens=model_profile.max_tokens,
                  )
  
                  duration_ms = (time.time() - start_time) * 1000
-                 usage_tokens = _openai_usage_tokens(getattr(response, "usage", None))
+                 usage_tokens = openai_usage_tokens(getattr(openai_response, "usage", None))
                  record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
                  cost_usd = 0.0  # TODO: Implement cost calculation
  
                  # Convert OpenAI response to our format
                  content_blocks = []
-                 choice = response.choices[0]
-
-                 if choice.message.content:
-                     content_blocks.append({"type": "text", "text": choice.message.content})
-
-                 if choice.message.tool_calls:
-                     for tool_call in choice.message.tool_calls:
-                         import json
+                 choice = openai_response.choices[0]
+
+                 logger.info(
+                     "[query_llm] Received response from OpenAI-compatible provider",
+                     extra={
+                         "model": model_profile.model,
+                         "duration_ms": round(duration_ms, 2),
+                         "usage_tokens": usage_tokens,
+                         "finish_reason": getattr(choice, "finish_reason", None),
+                     },
+                 )
  
-                         content_blocks.append(
-                             {
-                                 "type": "tool_use",
-                                 "tool_use_id": tool_call.id,
-                                 "name": tool_call.function.name,
-                                 "input": json.loads(tool_call.function.arguments),
-                             }
-                         )
+                 content_blocks = content_blocks_from_openai_choice(choice, tool_mode)
  
                  return create_assistant_message(
                      content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
@@ -411,7 +387,16 @@ async def query_llm(
  
      except Exception as e:
          # Return error message
-         logger.error(f"Error querying AI model: {e}")
+         logger.exception(
+             "Error querying AI model",
+             extra={
+                 "model": getattr(model_profile, "model", None),
+                 "model_pointer": model,
+                 "provider": getattr(model_profile.provider, "value", None)
+                 if model_profile
+                 else None,
+             },
+         )
          duration_ms = (time.time() - start_time) * 1000
          error_msg = create_assistant_message(
              content=f"Error querying AI model: {str(e)}", duration_ms=duration_ms
@@ -445,54 +430,38 @@ async def query(
      Yields:
          Messages (user, assistant, progress) as they are generated
      """
+     logger.info(
+         "[query] Starting query loop",
+         extra={
+             "message_count": len(messages),
+             "tool_count": len(query_context.tools),
+             "safe_mode": query_context.safe_mode,
+             "model_pointer": query_context.model,
+         },
+     )
      # Work on a copy so external mutations (e.g., UI appending messages while consuming)
      # do not interfere with recursion or normalization.
      messages = list(messages)
+     model_profile = resolve_model_profile(query_context.model)
+     tool_mode = determine_tool_mode(model_profile)
+     tools_for_model: List[Tool[Any, Any]] = [] if tool_mode == "text" else query_context.all_tools()
  
-     async def _check_permissions(
-         tool: Tool[Any, Any], parsed_input: Any
-     ) -> tuple[bool, Optional[str]]:
-         """Check permissions for tool execution."""
-         try:
-             if can_use_tool_fn is not None:
-                 decision = can_use_tool_fn(tool, parsed_input)
-                 if inspect.isawaitable(decision):
-                     decision = await decision
-                 if isinstance(decision, PermissionResult):
-                     return decision.result, decision.message
-                 if isinstance(decision, dict) and "result" in decision:
-                     return bool(decision.get("result")), decision.get("message")
-                 if isinstance(decision, tuple) and len(decision) == 2:
-                     return bool(decision[0]), decision[1]
-                 return bool(decision), None
-
-             if query_context.safe_mode and tool.needs_permissions(parsed_input):
-                 loop = asyncio.get_running_loop()
-                 input_preview = (
-                     parsed_input.model_dump()
-                     if hasattr(parsed_input, "model_dump")
-                     else str(parsed_input)
-                 )
-                 prompt = f"Allow tool '{tool.name}' with input {input_preview}? [y/N]: "
-                 response = await loop.run_in_executor(None, lambda: input(prompt))
-                 return response.strip().lower() in ("y", "yes"), None
-
-             return True, None
-         except Exception as exc:
-             # Fail closed on any errors
-             logger.error(f"Error checking permissions for tool '{tool.name}': {exc}")
-             return False, None
-
-     # Build full system prompt with context
-     full_system_prompt = system_prompt
-     if context:
-         context_str = "\n".join(f"{k}: {v}" for k, v in context.items())
-         full_system_prompt = f"{system_prompt}\n\nContext:\n{context_str}"
+     full_system_prompt = build_full_system_prompt(
+         system_prompt, context, tool_mode, query_context.all_tools()
+     )
+     logger.debug(
+         "[query] Built system prompt",
+         extra={
+             "prompt_chars": len(full_system_prompt),
+             "context_entries": len(context),
+             "tool_count": len(tools_for_model),
+         },
+     )
  
      assistant_message = await query_llm(
          messages,
          full_system_prompt,
-         query_context.all_tools(),
+         tools_for_model,
          query_context.max_thinking_tokens,
          query_context.model,
          query_context.abort_controller,
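
Note: the nested _check_permissions removed above is replaced by the module-level _check_tool_permissions added earlier in this diff; both normalize whatever the caller-supplied can_use_tool_fn returns (a PermissionResult, a {"result": ..., "message": ...} dict, a (bool, message) tuple, or a bare bool). An illustrative callback follows — the "Bash" tool name and the PermissionResult(result=..., message=...) constructor are assumptions, not confirmed by this diff:

    from ripperdoc.core.permissions import PermissionResult


    async def can_use_tool(tool, parsed_input):
        # Deny a hypothetical shell tool; the message becomes the is_error
        # tool_result text that the query loop yields back to the model.
        if tool.name == "Bash":
            return PermissionResult(result=False, message="Shell access is disabled in this session.")
        # Equivalent decision shapes the checker also accepts:
        #   {"result": True, "message": None}
        #   (True, None)
        return True

In the hunk below, a denial now sets permission_denied, breaks out of the remaining tool calls, and returns without recursing; 0.1.0 only skipped the single denied call and kept going.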
@@ -505,173 +474,140 @@ async def query(
  
      yield assistant_message
  
-     tool_block_count = 0
-     if isinstance(assistant_message.message.content, list):
-         tool_block_count = sum(
-             1
-             for block in assistant_message.message.content
-             if hasattr(block, "type") and block.type == "tool_use"
-         )
+     tool_use_blocks = extract_tool_use_blocks(assistant_message)
+     text_blocks = (
+         len(assistant_message.message.content)
+         if isinstance(assistant_message.message.content, list)
+         else 1
+     )
      logger.debug(
-         f"[query] Assistant message received: "
-         f"text_blocks={len(assistant_message.message.content) if isinstance(assistant_message.message.content, list) else 1}, "
-         f"tool_use_blocks={tool_block_count}"
+         f"[query] Assistant message received: text_blocks={text_blocks}, "
+         f"tool_use_blocks={len(tool_use_blocks)}"
      )
  
-     # Check for tool use
-     tool_use_blocks = []
-     if isinstance(assistant_message.message.content, list):
-         for block in assistant_message.message.content:
-             normalized_block = MessageContent(**block) if isinstance(block, dict) else block
-             if hasattr(normalized_block, "type") and normalized_block.type == "tool_use":
-                 tool_use_blocks.append(normalized_block)
-
-     # If no tool use, we're done
      if not tool_use_blocks:
          logger.debug("[query] No tool_use blocks; returning response to user.")
          return
  
-     # Execute tools
-     tool_results: List[UserMessage] = []
-
      logger.debug(f"[query] Executing {len(tool_use_blocks)} tool_use block(s).")
+     tool_results: List[UserMessage] = []
+     permission_denied = False
+     sibling_ids = set(
+         getattr(t, "tool_use_id", None) or getattr(t, "id", None) or "" for t in tool_use_blocks
+     )
  
      for tool_use in tool_use_blocks:
          tool_name = tool_use.name
-         tool_id = getattr(tool_use, "tool_use_id", None) or getattr(tool_use, "id", None) or ""
+         if not tool_name:
+             continue
+         tool_use_id = getattr(tool_use, "tool_use_id", None) or getattr(tool_use, "id", None) or ""
          tool_input = getattr(tool_use, "input", {}) or {}
  
-         # Find the tool
-         tool = query_context.tool_registry.get(tool_name)
-         # Auto-activate when used so subsequent rounds list it as active.
-         if tool:
-             query_context.activate_tools([tool_name])
-
-         if not tool:
-             # Tool not found
-             logger.warning(f"[query] Tool '{tool_name}' not found for tool_use_id={tool_id}")
-             result_msg = create_user_message(
-                 [
-                     {
-                         "type": "tool_result",
-                         "tool_use_id": tool_id,
-                         "text": f"Error: Tool '{tool_name}' not found",
-                         "is_error": True,
-                     }
-                 ]
-             )
-             tool_results.append(result_msg)
-             yield result_msg
+         tool, missing_msg = _resolve_tool(query_context.tool_registry, tool_name, tool_use_id)
+         if missing_msg:
+             logger.warning(f"[query] Tool '{tool_name}' not found for tool_use_id={tool_use_id}")
+             tool_results.append(missing_msg)
+             yield missing_msg
              continue
+         assert tool is not None
  
-         # Execute the tool
          tool_context = ToolUseContext(
              safe_mode=query_context.safe_mode,
              verbose=query_context.verbose,
              permission_checker=can_use_tool_fn,
              tool_registry=query_context.tool_registry,
+             abort_signal=query_context.abort_controller,
          )
  
          try:
-             # Parse input using tool's schema
              parsed_input = tool.input_schema(**tool_input)
              logger.debug(
-                 f"[query] tool_use_id={tool_id} name={tool_name} parsed_input="
+                 f"[query] tool_use_id={tool_use_id} name={tool_name} parsed_input="
                  f"{str(parsed_input)[:500]}"
              )
  
-             # Validate input before execution
              validation = await tool.validate_input(parsed_input, tool_context)
              if not validation.result:
                  logger.debug(
-                     f"[query] Validation failed for tool_use_id={tool_id}: {validation.message}"
+                     f"[query] Validation failed for tool_use_id={tool_use_id}: {validation.message}"
                  )
-                 result_msg = create_user_message(
-                     [
-                         {
-                             "type": "tool_result",
-                             "tool_use_id": tool_id,
-                             "text": validation.message or "Tool input validation failed.",
-                             "is_error": True,
-                         }
-                     ]
+                 result_msg = tool_result_message(
+                     tool_use_id,
+                     validation.message or "Tool input validation failed.",
+                     is_error=True,
                  )
                  tool_results.append(result_msg)
                  yield result_msg
                  continue
  
-             # Permission check (safe mode or custom checker)
              if query_context.safe_mode or can_use_tool_fn is not None:
-                 allowed, denial_message = await _check_permissions(tool, parsed_input)
+                 allowed, denial_message = await _check_tool_permissions(
+                     tool, parsed_input, query_context, can_use_tool_fn
+                 )
                  if not allowed:
                      logger.debug(
-                         f"[query] Permission denied for tool_use_id={tool_id}: {denial_message}"
+                         f"[query] Permission denied for tool_use_id={tool_use_id}: {denial_message}"
                      )
-                     denial_text = denial_message or f"Permission denied for tool '{tool_name}'."
-                     result_msg = create_user_message(
-                         [
-                             {
-                                 "type": "tool_result",
-                                 "tool_use_id": tool_id,
-                                 "text": denial_text,
-                                 "is_error": True,
-                             }
-                         ]
-                     )
-                     tool_results.append(result_msg)
-                     yield result_msg
-                     continue
+                     denial_text = denial_message or f"User aborted the tool invocation: {tool_name}"
+                     denial_msg = tool_result_message(tool_use_id, denial_text, is_error=True)
+                     tool_results.append(denial_msg)
+                     yield denial_msg
+                     permission_denied = True
+                     break
  
-             # Execute tool
              async for output in tool.call(parsed_input, tool_context):
                  if isinstance(output, ToolProgress):
-                     # Yield progress
                      progress = create_progress_message(
-                         tool_use_id=tool_id,
-                         sibling_tool_use_ids=set(
-                             getattr(t, "tool_use_id", None) or getattr(t, "id", None) or ""
-                             for t in tool_use_blocks
-                         ),
+                         tool_use_id=tool_use_id,
+                         sibling_tool_use_ids=sibling_ids,
                          content=output.content,
                      )
                      yield progress
-                     logger.debug(f"[query] Progress from tool_use_id={tool_id}: {output.content}")
+                     logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
                  elif isinstance(output, ToolResult):
-                     # Tool completed
                      result_content = output.result_for_assistant or str(output.data)
-                     result_msg = create_user_message(
-                         [{"type": "tool_result", "tool_use_id": tool_id, "text": result_content}],
-                         tool_use_result=output.data,
+                     result_msg = tool_result_message(
+                         tool_use_id, result_content, tool_use_result=output.data
                      )
                      tool_results.append(result_msg)
                      yield result_msg
                      logger.debug(
-                         f"[query] Tool completed tool_use_id={tool_id} name={tool_name} "
+                         f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
                          f"result_len={len(result_content)}"
                      )
  
+         except ValidationError as ve:
+             detail_text = format_pydantic_errors(ve)
+             error_msg = tool_result_message(
+                 tool_use_id,
+                 f"Invalid input for tool '{tool_name}': {detail_text}",
+                 is_error=True,
+             )
+             tool_results.append(error_msg)
+             yield error_msg
+             continue
          except Exception as e:
-             # Tool execution failed
-             logger.error(f"Error executing tool '{tool_name}': {e}")
-             error_msg = create_user_message(
-                 [
-                     {
-                         "type": "tool_result",
-                         "tool_use_id": tool_id,
-                         "text": f"Error executing tool: {str(e)}",
-                         "is_error": True,
-                     }
-                 ]
+             logger.exception(
+                 f"Error executing tool '{tool_name}'",
+                 extra={"tool": tool_name, "tool_use_id": tool_use_id},
+             )
+             error_msg = tool_result_message(
+                 tool_use_id, f"Error executing tool: {str(e)}", is_error=True
              )
              tool_results.append(error_msg)
              yield error_msg
  
+         if permission_denied:
+             break
+
      # Check for abort after tools
      if query_context.abort_controller.is_set():
          yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
          return
  
-     # Continue conversation with tool results
+     if permission_denied:
+         return
+
      new_messages = messages + [assistant_message] + tool_results
      logger.debug(
  f"[query] Recursing with {len(new_messages)} messages after tools; "