ripperdoc 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ripperdoc/__init__.py +1 -1
  2. ripperdoc/cli/cli.py +9 -2
  3. ripperdoc/cli/commands/agents_cmd.py +8 -4
  4. ripperdoc/cli/commands/cost_cmd.py +5 -0
  5. ripperdoc/cli/commands/doctor_cmd.py +12 -4
  6. ripperdoc/cli/commands/memory_cmd.py +6 -13
  7. ripperdoc/cli/commands/models_cmd.py +36 -6
  8. ripperdoc/cli/commands/resume_cmd.py +4 -2
  9. ripperdoc/cli/commands/status_cmd.py +1 -1
  10. ripperdoc/cli/ui/rich_ui.py +102 -2
  11. ripperdoc/cli/ui/thinking_spinner.py +128 -0
  12. ripperdoc/core/agents.py +13 -5
  13. ripperdoc/core/config.py +9 -1
  14. ripperdoc/core/providers/__init__.py +31 -0
  15. ripperdoc/core/providers/anthropic.py +136 -0
  16. ripperdoc/core/providers/base.py +187 -0
  17. ripperdoc/core/providers/gemini.py +172 -0
  18. ripperdoc/core/providers/openai.py +142 -0
  19. ripperdoc/core/query.py +331 -141
  20. ripperdoc/core/query_utils.py +64 -23
  21. ripperdoc/core/tool.py +5 -3
  22. ripperdoc/sdk/client.py +12 -1
  23. ripperdoc/tools/background_shell.py +54 -18
  24. ripperdoc/tools/bash_tool.py +33 -13
  25. ripperdoc/tools/file_edit_tool.py +13 -0
  26. ripperdoc/tools/file_read_tool.py +16 -0
  27. ripperdoc/tools/file_write_tool.py +13 -0
  28. ripperdoc/tools/glob_tool.py +5 -1
  29. ripperdoc/tools/ls_tool.py +14 -10
  30. ripperdoc/tools/multi_edit_tool.py +12 -0
  31. ripperdoc/tools/notebook_edit_tool.py +12 -0
  32. ripperdoc/tools/todo_tool.py +1 -3
  33. ripperdoc/tools/tool_search_tool.py +8 -4
  34. ripperdoc/utils/file_watch.py +134 -0
  35. ripperdoc/utils/git_utils.py +36 -38
  36. ripperdoc/utils/json_utils.py +1 -2
  37. ripperdoc/utils/log.py +3 -4
  38. ripperdoc/utils/memory.py +1 -3
  39. ripperdoc/utils/message_compaction.py +2 -6
  40. ripperdoc/utils/messages.py +9 -13
  41. ripperdoc/utils/output_utils.py +1 -3
  42. ripperdoc/utils/prompt.py +17 -0
  43. ripperdoc/utils/session_usage.py +7 -0
  44. ripperdoc/utils/shell_utils.py +159 -0
  45. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/METADATA +1 -1
  46. ripperdoc-0.2.3.dist-info/RECORD +95 -0
  47. ripperdoc-0.2.2.dist-info/RECORD +0 -86
  48. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/WHEEL +0 -0
  49. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/entry_points.txt +0 -0
  50. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/licenses/LICENSE +0 -0
  51. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/top_level.txt +0 -0
ripperdoc/core/query.py CHANGED
@@ -6,48 +6,60 @@ the query-response loop including tool execution.
 
 import asyncio
 import inspect
+import os
 import time
-from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union, cast
+from asyncio import CancelledError
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
-from anthropic import AsyncAnthropic
-from openai import AsyncOpenAI
 from pydantic import ValidationError
 
-from ripperdoc.core.config import ProviderType, provider_protocol
+from ripperdoc.core.config import provider_protocol
+from ripperdoc.core.providers import ProviderClient, get_provider_client
 from ripperdoc.core.permissions import PermissionResult
 from ripperdoc.core.query_utils import (
-    anthropic_usage_tokens,
-    build_anthropic_tool_schemas,
     build_full_system_prompt,
-    build_openai_tool_schemas,
-    content_blocks_from_anthropic_response,
-    content_blocks_from_openai_choice,
     determine_tool_mode,
     extract_tool_use_blocks,
     format_pydantic_errors,
     log_openai_messages,
-    openai_usage_tokens,
     resolve_model_profile,
     text_mode_history,
     tool_result_message,
 )
 from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
+from ripperdoc.utils.file_watch import ChangedFileNotice, FileSnapshot, detect_changed_files
 from ripperdoc.utils.log import get_logger
 from ripperdoc.utils.messages import (
     AssistantMessage,
+    MessageContent,
     ProgressMessage,
     UserMessage,
     create_assistant_message,
+    create_user_message,
     create_progress_message,
     normalize_messages_for_api,
     INTERRUPT_MESSAGE,
     INTERRUPT_MESSAGE_FOR_TOOL_USE,
 )
-from ripperdoc.utils.session_usage import record_usage
 
 
 logger = get_logger()
 
+DEFAULT_REQUEST_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_API_TIMEOUT", "120"))
+MAX_LLM_RETRIES = 1
+
 
 def _resolve_tool(
     tool_registry: "ToolRegistry", tool_name: str, tool_use_id: str
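
Note that the timeout is resolved once, at import time. A standalone sketch of how the new `RIPPERDOC_API_TIMEOUT` knob behaves (illustration only, not part of the diff):

```python
import os

# Mirrors DEFAULT_REQUEST_TIMEOUT_SEC above: read once at import time,
# falling back to 120 seconds when the variable is unset.
timeout_sec = float(os.getenv("RIPPERDOC_API_TIMEOUT", "120"))
print(f"per-request LLM timeout: {timeout_sec:.0f}s")

# e.g. `export RIPPERDOC_API_TIMEOUT=300` before launching gives slow
# providers five minutes; changing the variable after import has no effect.
```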
@@ -62,11 +74,23 @@ def _resolve_tool(
     )
 
 
+ToolPermissionCallable = Callable[
+    [Tool[Any, Any], Any],
+    Union[
+        PermissionResult,
+        Dict[str, Any],
+        Tuple[bool, Optional[str]],
+        bool,
+        Awaitable[Union[PermissionResult, Dict[str, Any], Tuple[bool, Optional[str]], bool]],
+    ],
+]
+
+
 async def _check_tool_permissions(
     tool: Tool[Any, Any],
     parsed_input: Any,
     query_context: "QueryContext",
-    can_use_tool_fn: Optional[Any],
+    can_use_tool_fn: Optional[ToolPermissionCallable],
 ) -> tuple[bool, Optional[str]]:
     """Evaluate whether a tool call is allowed."""
     try:
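
`ToolPermissionCallable` deliberately accepts four result shapes, plus awaitables of each. The body of `_check_tool_permissions` is mostly elided from this hunk, so the helper below is a hypothetical sketch of how such heterogeneous returns could be coerced, not Ripperdoc's actual code (the dict keys are assumptions):

```python
import inspect
from typing import Any, Optional, Tuple

async def normalize_permission_result(raw: Any) -> Tuple[bool, Optional[str]]:
    """Coerce a permission callback's return value to (allowed, denial_reason)."""
    if inspect.isawaitable(raw):   # Awaitable[...] variant
        raw = await raw
    if isinstance(raw, bool):      # plain bool
        return raw, None
    if isinstance(raw, tuple):     # (allowed, reason) pair
        allowed, reason = raw
        return bool(allowed), reason
    if isinstance(raw, dict):      # dict-style result (keys assumed here)
        return bool(raw.get("result")), raw.get("message")
    # Fall back to PermissionResult-like objects with a boolean `result` field.
    return bool(getattr(raw, "result", False)), getattr(raw, "message", None)
```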
@@ -102,6 +126,155 @@ async def _check_tool_permissions(
         return False, None
 
 
+def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
+    """Render a system notice about files that changed on disk."""
+    lines: List[str] = [
+        "System notice: Files you previously read have changed on disk.",
+        "Please re-read the affected files before making further edits.",
+        "",
+    ]
+    for notice in notices:
+        lines.append(f"- {notice.file_path}")
+        summary = (notice.summary or "").rstrip()
+        if summary:
+            indented = "\n".join(f" {line}" for line in summary.splitlines())
+            lines.append(indented)
+    return "\n".join(lines)
+
+
+async def _run_tool_use_generator(
+    tool: Tool[Any, Any],
+    tool_use_id: str,
+    tool_name: str,
+    parsed_input: Any,
+    sibling_ids: set[str],
+    tool_context: ToolUseContext,
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Execute a single tool_use and yield progress/results."""
+    try:
+        async for output in tool.call(parsed_input, tool_context):
+            if isinstance(output, ToolProgress):
+                yield create_progress_message(
+                    tool_use_id=tool_use_id,
+                    sibling_tool_use_ids=sibling_ids,
+                    content=output.content,
+                )
+                logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
+            elif isinstance(output, ToolResult):
+                result_content = output.result_for_assistant or str(output.data)
+                result_msg = tool_result_message(
+                    tool_use_id, result_content, tool_use_result=output.data
+                )
+                yield result_msg
+                logger.debug(
+                    f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
+                    f"result_len={len(result_content)}"
+                )
+    except Exception as exc:
+        logger.exception(
+            f"Error executing tool '{tool_name}'",
+            extra={"tool": tool_name, "tool_use_id": tool_use_id},
+        )
+        yield tool_result_message(tool_use_id, f"Error executing tool: {str(exc)}", is_error=True)
+
+
+def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Group consecutive tool calls by their concurrency safety."""
+    groups: List[Dict[str, Any]] = []
+    for call in prepared_calls:
+        is_safe = bool(call.get("is_concurrency_safe"))
+        if groups and groups[-1]["is_concurrency_safe"] == is_safe:
+            groups[-1]["items"].append(call)
+        else:
+            groups.append({"is_concurrency_safe": is_safe, "items": [call]})
+    return groups
+
+
+async def _execute_tools_sequentially(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators one by one."""
+    for item in items:
+        gen = item.get("generator")
+        if not gen:
+            continue
+        async for message in gen:
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+
+
+async def _execute_tools_in_parallel(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators concurrently."""
+    generators = [call["generator"] for call in items if call.get("generator")]
+    async for message in _run_concurrent_tool_uses(generators, tool_results):
+        yield message
+
+
+async def _run_tools_concurrently(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tools grouped by concurrency safety (parallel for safe groups)."""
+    for group in _group_tool_calls_by_concurrency(prepared_calls):
+        if group["is_concurrency_safe"]:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} concurrency-safe tool(s) in parallel"
+            )
+            async for message in _execute_tools_in_parallel(group["items"], tool_results):
+                yield message
+        else:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} tool(s) sequentially (not concurrency safe)"
+            )
+            async for message in _run_tools_serially(group["items"], tool_results):
+                yield message
+
+
+async def _run_tools_serially(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run all tools sequentially (helper for clarity)."""
+    async for message in _execute_tools_sequentially(prepared_calls, tool_results):
+        yield message
+
+
+async def _run_concurrent_tool_uses(
+    generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
+    tool_results: List[UserMessage],
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Drain multiple tool generators concurrently and stream outputs."""
+    if not generators:
+        return
+
+    queue: asyncio.Queue[Optional[Union[UserMessage, ProgressMessage]]] = asyncio.Queue()
+
+    async def _consume(gen: AsyncGenerator[Union[UserMessage, ProgressMessage], None]) -> None:
+        try:
+            async for message in gen:
+                await queue.put(message)
+        except Exception:
+            logger.exception("[query] Unexpected error while consuming tool generator")
+        finally:
+            await queue.put(None)
+
+    tasks = [asyncio.create_task(_consume(gen)) for gen in generators]
+    active = len(tasks)
+
+    try:
+        while active:
+            message = await queue.get()
+            if message is None:
+                active -= 1
+                continue
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+    finally:
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+
 class ToolRegistry:
     """Track available tools, including deferred ones, and expose search/activation helpers."""
 
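
Since `_group_tool_calls_by_concurrency` only merges *consecutive* calls with the same safety flag, ordering matters. A standalone rerun of the logic shown above on sample data:

```python
from typing import Any, Dict, List

def group_by_concurrency(calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Same grouping logic as _group_tool_calls_by_concurrency above.
    groups: List[Dict[str, Any]] = []
    for call in calls:
        is_safe = bool(call.get("is_concurrency_safe"))
        if groups and groups[-1]["is_concurrency_safe"] == is_safe:
            groups[-1]["items"].append(call)
        else:
            groups.append({"is_concurrency_safe": is_safe, "items": [call]})
    return groups

calls = [
    {"name": "read_a", "is_concurrency_safe": True},
    {"name": "read_b", "is_concurrency_safe": True},
    {"name": "edit_a", "is_concurrency_safe": False},
    {"name": "read_c", "is_concurrency_safe": True},
]
for g in group_by_concurrency(calls):
    names = [c["name"] for c in g["items"]]
    print("parallel" if g["is_concurrency_safe"] else "sequential", names)
# parallel ['read_a', 'read_b']
# sequential ['edit_a']
# parallel ['read_c']
```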
@@ -211,6 +384,7 @@ class QueryContext:
         self.model = model
         self.verbose = verbose
         self.abort_controller = asyncio.Event()
+        self.file_state_cache: Dict[str, FileSnapshot] = {}
 
     @property
     def tools(self) -> List[Tool[Any, Any]]:
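
`file_state_cache` maps file paths to `FileSnapshot` records from the new `ripperdoc/utils/file_watch.py` (file 34 in the list), which is not shown in this diff. A minimal sketch of the stat-based staleness check such a cache implies; the `Snapshot` fields and `detect_changes` helper are assumptions, not the real `file_watch` API:

```python
import os
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Snapshot:
    """Hypothetical stand-in for file_watch.FileSnapshot."""
    mtime: float
    size: int

def detect_changes(cache: Dict[str, Snapshot]) -> List[str]:
    """Return paths whose on-disk state no longer matches the snapshot."""
    changed: List[str] = []
    for path, snap in cache.items():
        try:
            st = os.stat(path)
        except FileNotFoundError:
            changed.append(path)  # deleted since it was read
            continue
        if st.st_mtime != snap.mtime or st.st_size != snap.size:
            changed.append(path)  # modified since it was read
    return changed
```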
@@ -238,6 +412,11 @@ async def query_llm(
     max_thinking_tokens: int = 0,
     model: str = "main",
     abort_signal: Optional[asyncio.Event] = None,
+    *,
+    progress_callback: Optional[Callable[[str], Awaitable[None]]] = None,
+    request_timeout: Optional[float] = None,
+    max_retries: int = MAX_LLM_RETRIES,
+    stream: bool = True,
 ) -> AssistantMessage:
     """Query the AI model and return the response.
 
@@ -248,10 +427,16 @@ async def query_llm(
         max_thinking_tokens: Maximum tokens for thinking (0 = disabled)
         model: Model pointer to use
         abort_signal: Event to signal abortion
+        progress_callback: Optional async callback invoked with streamed text chunks
+        request_timeout: Max seconds to wait for a provider response before retrying
+        max_retries: Number of retries on timeout/errors (total attempts = retries + 1)
+        stream: Enable streaming for providers that support it (text-only mode)
 
     Returns:
         AssistantMessage with the model's response
     """
+    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
     model_profile = resolve_model_profile(model)
 
     # Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
@@ -266,7 +451,7 @@ async def query_llm(
     else:
         messages_for_model = messages
 
-    normalized_messages = normalize_messages_for_api(
+    normalized_messages: List[Dict[str, Any]] = normalize_messages_for_api(
         messages_for_model, protocol=protocol, tool_mode=tool_mode
     )
     logger.info(
@@ -295,95 +480,36 @@ async def query_llm(
     start_time = time.time()
 
     try:
-        # Create the appropriate client based on provider
-        if model_profile.provider == ProviderType.ANTHROPIC:
-            async with AsyncAnthropic(api_key=model_profile.api_key) as client:
-                tool_schemas = await build_anthropic_tool_schemas(tools)
-                response = await client.messages.create(
-                    model=model_profile.model,
-                    max_tokens=model_profile.max_tokens,
-                    system=system_prompt,
-                    messages=normalized_messages,  # type: ignore[arg-type]
-                    tools=tool_schemas if tool_schemas else None,  # type: ignore
-                    temperature=model_profile.temperature,
-                )
-
-                duration_ms = (time.time() - start_time) * 1000
-
-                usage_tokens = anthropic_usage_tokens(getattr(response, "usage", None))
-                record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-
-                # Calculate cost (simplified, should use actual pricing)
-                cost_usd = 0.0  # TODO: Implement cost calculation
-
-                content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
-                tool_use_blocks = [
-                    block for block in response.content if getattr(block, "type", None) == "tool_use"
-                ]
-                logger.info(
-                    "[query_llm] Received response from Anthropic",
-                    extra={
-                        "model": model_profile.model,
-                        "duration_ms": round(duration_ms, 2),
-                        "usage_tokens": usage_tokens,
-                        "tool_use_blocks": len(tool_use_blocks),
-                    },
-                )
-
-                return create_assistant_message(
-                    content=content_blocks,
-                    cost_usd=cost_usd,
-                    duration_ms=duration_ms,
-                )
-
-        elif model_profile.provider == ProviderType.OPENAI_COMPATIBLE:
-            # OpenAI-compatible APIs (OpenAI, DeepSeek, Mistral, etc.)
-            async with AsyncOpenAI(api_key=model_profile.api_key, base_url=model_profile.api_base) as client:
-                openai_tools = await build_openai_tool_schemas(tools)
-
-                # Prepare messages for OpenAI format
-                openai_messages = [
-                    {"role": "system", "content": system_prompt}
-                ] + normalized_messages
-
-                # Make the API call
-                openai_response: Any = await client.chat.completions.create(
-                    model=model_profile.model,
-                    messages=openai_messages,
-                    tools=openai_tools if openai_tools else None,  # type: ignore[arg-type]
-                    temperature=model_profile.temperature,
-                    max_tokens=model_profile.max_tokens,
-                )
-
-                duration_ms = (time.time() - start_time) * 1000
-                usage_tokens = openai_usage_tokens(getattr(openai_response, "usage", None))
-                record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-                cost_usd = 0.0  # TODO: Implement cost calculation
-
-                # Convert OpenAI response to our format
-                content_blocks = []
-                choice = openai_response.choices[0]
-
-                logger.info(
-                    "[query_llm] Received response from OpenAI-compatible provider",
-                    extra={
-                        "model": model_profile.model,
-                        "duration_ms": round(duration_ms, 2),
-                        "usage_tokens": usage_tokens,
-                        "finish_reason": getattr(choice, "finish_reason", None),
-                    },
-                )
-
-                content_blocks = content_blocks_from_openai_choice(choice, tool_mode)
-
-                return create_assistant_message(
-                    content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
-                )
+        client: Optional[ProviderClient] = get_provider_client(model_profile.provider)
+        if client is None:
+            duration_ms = (time.time() - start_time) * 1000
+            error_msg = create_assistant_message(
+                content=(
+                    "Gemini protocol is not supported yet in Ripperdoc. "
+                    "Please configure an Anthropic or OpenAI-compatible model."
+                ),
+                duration_ms=duration_ms,
+            )
+            error_msg.is_api_error_message = True
+            return error_msg
+
+        provider_response = await client.call(
+            model_profile=model_profile,
+            system_prompt=system_prompt,
+            normalized_messages=normalized_messages,
+            tools=tools,
+            tool_mode=tool_mode,
+            stream=stream,
+            progress_callback=progress_callback,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+        )
 
-        elif model_profile.provider == ProviderType.GEMINI:
-            raise NotImplementedError("Gemini protocol is not yet supported.")
-        else:
-            raise NotImplementedError(f"Provider {model_profile.provider} not yet implemented")
+        return create_assistant_message(
+            content=provider_response.content_blocks,
+            cost_usd=provider_response.cost_usd,
+            duration_ms=provider_response.duration_ms,
+        )
 
     except Exception as e:
         # Return error message
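
The provider-specific branches above collapse into the new `ripperdoc/core/providers` package (files 14-18 in the list), whose modules are not shown in this diff. A rough sketch of the dispatch shape the refactor implies; the class and field names below are inferred from the call sites above, not the actual implementation:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol

@dataclass
class ProviderResponse:
    """Shape implied by the create_assistant_message(...) call above."""
    content_blocks: List[Any]
    cost_usd: float = 0.0
    duration_ms: float = 0.0

class ProviderClient(Protocol):
    """Interface implied by the client.call(...) keyword arguments above."""
    async def call(
        self, *, model_profile: Any, system_prompt: str,
        normalized_messages: List[Dict[str, Any]], tools: List[Any],
        tool_mode: str, stream: bool, progress_callback: Any,
        request_timeout: float, max_retries: int,
    ) -> ProviderResponse: ...

_REGISTRY: Dict[str, ProviderClient] = {}  # e.g. filled by anthropic/openai modules

def get_provider_client(provider: Any) -> Optional[ProviderClient]:
    # Returning None (rather than raising) lets the caller emit the
    # "not supported yet" assistant message shown in the hunk above.
    return _REGISTRY.get(getattr(provider, "value", provider))
```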
@@ -392,9 +518,9 @@ async def query_llm(
         extra={
             "model": getattr(model_profile, "model", None),
             "model_pointer": model,
-            "provider": getattr(model_profile.provider, "value", None)
-            if model_profile
-            else None,
+            "provider": (
+                getattr(model_profile.provider, "value", None) if model_profile else None
+            ),
         },
     )
     duration_ms = (time.time() - start_time) * 1000
@@ -410,7 +536,7 @@ async def query(
     system_prompt: str,
     context: Dict[str, str],
     query_context: QueryContext,
-    can_use_tool_fn: Optional[Any] = None,
+    can_use_tool_fn: Optional[ToolPermissionCallable] = None,
 ) -> AsyncGenerator[Union[UserMessage, AssistantMessage, ProgressMessage], None]:
     """Execute a query with tool support.
 
@@ -442,6 +568,9 @@ async def query(
     # Work on a copy so external mutations (e.g., UI appending messages while consuming)
     # do not interfere with recursion or normalization.
     messages = list(messages)
+    change_notices = detect_changed_files(query_context.file_state_cache)
+    if change_notices:
+        messages.append(create_user_message(_format_changed_file_notice(change_notices)))
     model_profile = resolve_model_profile(query_context.model)
     tool_mode = determine_tool_mode(model_profile)
     tools_for_model: List[Tool[Any, Any]] = [] if tool_mode == "text" else query_context.all_tools()
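
With the cache in place, every `query(...)` turn now starts by checking for external edits and, if any are found, appends a synthetic user message. A runnable demo of the notice format, using a stand-in for `ChangedFileNotice` (the real class lives in `ripperdoc/utils/file_watch.py`; only the `file_path` and `summary` fields are inferred from the code above):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Notice:  # stand-in for file_watch.ChangedFileNotice
    file_path: str
    summary: Optional[str] = None

def format_notice(notices: List[Notice]) -> str:
    # Same rendering as _format_changed_file_notice above.
    lines = [
        "System notice: Files you previously read have changed on disk.",
        "Please re-read the affected files before making further edits.",
        "",
    ]
    for n in notices:
        lines.append(f"- {n.file_path}")
        summary = (n.summary or "").rstrip()
        if summary:
            lines.append("\n".join(f" {line}" for line in summary.splitlines()))
    return "\n".join(lines)

print(format_notice([Notice("src/app.py", "+12 -3 lines since last read")]))
```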
@@ -458,15 +587,74 @@ async def query(
         },
     )
 
-    assistant_message = await query_llm(
-        messages,
-        full_system_prompt,
-        tools_for_model,
-        query_context.max_thinking_tokens,
-        query_context.model,
-        query_context.abort_controller,
+    progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue()
+
+    async def _stream_progress(chunk: str) -> None:
+        if not chunk:
+            return
+        try:
+            await progress_queue.put(
+                create_progress_message(
+                    tool_use_id="stream",
+                    sibling_tool_use_ids=set(),
+                    content=chunk,
+                )
+            )
+        except Exception:
+            logger.exception("[query] Failed to enqueue stream progress chunk")
+
+    assistant_task = asyncio.create_task(
+        query_llm(
+            messages,
+            full_system_prompt,
+            tools_for_model,
+            query_context.max_thinking_tokens,
+            query_context.model,
+            query_context.abort_controller,
+            progress_callback=_stream_progress,
+            request_timeout=DEFAULT_REQUEST_TIMEOUT_SEC,
+            max_retries=MAX_LLM_RETRIES,
+            stream=True,
+        )
     )
 
+    assistant_message: Optional[AssistantMessage] = None
+
+    while True:
+        if query_context.abort_controller.is_set():
+            assistant_task.cancel()
+            try:
+                await assistant_task
+            except CancelledError:
+                pass
+            yield create_assistant_message(INTERRUPT_MESSAGE)
+            return
+        if assistant_task.done():
+            assistant_message = await assistant_task
+            break
+        try:
+            progress = progress_queue.get_nowait()
+        except asyncio.QueueEmpty:
+            waiter = asyncio.create_task(progress_queue.get())
+            done, pending = await asyncio.wait(
+                {assistant_task, waiter}, return_when=asyncio.FIRST_COMPLETED
+            )
+            if assistant_task in done:
+                for task in pending:
+                    task.cancel()
+                assistant_message = await assistant_task
+                break
+            progress = waiter.result()
+        if progress:
+            yield progress
+
+    while not progress_queue.empty():
+        residual = progress_queue.get_nowait()
+        if residual:
+            yield residual
+
+    assert assistant_message is not None
+
     # Check for abort
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE)
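
The loop above multiplexes two event sources: completion of `assistant_task` and arrival of streamed progress chunks. The same race can be exercised in isolation; a minimal self-contained sketch of the `asyncio.wait(..., FIRST_COMPLETED)` pattern (not Ripperdoc code):

```python
import asyncio
from typing import Optional

async def main() -> None:
    queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

    async def producer() -> str:
        for chunk in ("Hel", "lo ", "world"):
            await queue.put(chunk)
            await asyncio.sleep(0.01)
        return "done"

    task = asyncio.create_task(producer())
    while not task.done():
        waiter = asyncio.create_task(queue.get())
        done, pending = await asyncio.wait({task, waiter}, return_when=asyncio.FIRST_COMPLETED)
        if task in done:
            # Final result arrived: cancel the orphaned queue.get() waiter,
            # exactly as the query loop does, then drain leftovers below.
            for p in pending:
                p.cancel()
            break
        print("chunk:", waiter.result())
    while not queue.empty():
        print("residual:", queue.get_nowait())
    print("final:", await task)

asyncio.run(main())
```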
@@ -474,7 +662,7 @@ async def query(
 
     yield assistant_message
 
-    tool_use_blocks = extract_tool_use_blocks(assistant_message)
+    tool_use_blocks: List[MessageContent] = extract_tool_use_blocks(assistant_message)
     text_blocks = (
         len(assistant_message.message.content)
         if isinstance(assistant_message.message.content, list)
@@ -495,6 +683,7 @@ async def query(
     sibling_ids = set(
         getattr(t, "tool_use_id", None) or getattr(t, "id", None) or "" for t in tool_use_blocks
     )
+    prepared_calls: List[Dict[str, Any]] = []
 
     for tool_use in tool_use_blocks:
         tool_name = tool_use.name
@@ -511,14 +700,6 @@ async def query(
             continue
         assert tool is not None
 
-        tool_context = ToolUseContext(
-            safe_mode=query_context.safe_mode,
-            verbose=query_context.verbose,
-            permission_checker=can_use_tool_fn,
-            tool_registry=query_context.tool_registry,
-            abort_signal=query_context.abort_controller,
-        )
-
         try:
             parsed_input = tool.input_schema(**tool_input)
             logger.debug(
@@ -526,6 +707,15 @@ async def query(
                 f"{str(parsed_input)[:500]}"
             )
 
+            tool_context = ToolUseContext(
+                safe_mode=query_context.safe_mode,
+                verbose=query_context.verbose,
+                permission_checker=can_use_tool_fn,
+                tool_registry=query_context.tool_registry,
+                file_state_cache=query_context.file_state_cache,
+                abort_signal=query_context.abort_controller,
+            )
+
             validation = await tool.validate_input(parsed_input, tool_context)
             if not validation.result:
                 logger.debug(
@@ -555,26 +745,19 @@ async def query(
                 permission_denied = True
                 break
 
-            async for output in tool.call(parsed_input, tool_context):
-                if isinstance(output, ToolProgress):
-                    progress = create_progress_message(
-                        tool_use_id=tool_use_id,
-                        sibling_tool_use_ids=sibling_ids,
-                        content=output.content,
-                    )
-                    yield progress
-                    logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
-                elif isinstance(output, ToolResult):
-                    result_content = output.result_for_assistant or str(output.data)
-                    result_msg = tool_result_message(
-                        tool_use_id, result_content, tool_use_result=output.data
-                    )
-                    tool_results.append(result_msg)
-                    yield result_msg
-                    logger.debug(
-                        f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
-                        f"result_len={len(result_content)}"
-                    )
+            prepared_calls.append(
+                {
+                    "is_concurrency_safe": tool.is_concurrency_safe(),
+                    "generator": _run_tool_use_generator(
+                        tool,
+                        tool_use_id,
+                        tool_name,
+                        parsed_input,
+                        sibling_ids,
+                        tool_context,
+                    ),
+                }
+            )
 
         except ValidationError as ve:
             detail_text = format_pydantic_errors(ve)
@@ -600,6 +783,13 @@ async def query(
         if permission_denied:
             break
 
+    if permission_denied:
+        return
+
+    if prepared_calls:
+        async for message in _run_tools_concurrently(prepared_calls, tool_results):
+            yield message
+
     # Check for abort after tools
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
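
The net effect of `prepared_calls` plus `_run_tools_concurrently` is that results stream as soon as any concurrency-safe tool finishes. The queue-with-sentinel fan-in used by `_run_concurrent_tool_uses` reduces to this self-contained sketch (illustration only):

```python
import asyncio
from typing import AsyncGenerator, List, Optional

async def merge(gens: List[AsyncGenerator[str, None]]) -> AsyncGenerator[str, None]:
    """Drain several async generators concurrently, yielding items as they arrive."""
    queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

    async def consume(gen: AsyncGenerator[str, None]) -> None:
        try:
            async for item in gen:
                await queue.put(item)
        finally:
            await queue.put(None)  # sentinel: this generator is exhausted

    tasks = [asyncio.create_task(consume(g)) for g in gens]
    active = len(tasks)
    try:
        while active:
            item = await queue.get()
            if item is None:
                active -= 1
            else:
                yield item
    finally:
        await asyncio.gather(*tasks, return_exceptions=True)

async def fake_tool(name: str, delay: float) -> AsyncGenerator[str, None]:
    await asyncio.sleep(delay)
    yield f"{name}: result"

async def main() -> None:
    async for out in merge([fake_tool("glob", 0.02), fake_tool("read", 0.01)]):
        print(out)  # "read: result" arrives before "glob: result"

asyncio.run(main())
```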