ripperdoc-0.2.2-py3-none-any.whl → ripperdoc-0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ripperdoc/__init__.py +1 -1
  2. ripperdoc/cli/cli.py +9 -2
  3. ripperdoc/cli/commands/agents_cmd.py +8 -4
  4. ripperdoc/cli/commands/context_cmd.py +3 -3
  5. ripperdoc/cli/commands/cost_cmd.py +5 -0
  6. ripperdoc/cli/commands/doctor_cmd.py +12 -4
  7. ripperdoc/cli/commands/memory_cmd.py +6 -13
  8. ripperdoc/cli/commands/models_cmd.py +36 -6
  9. ripperdoc/cli/commands/resume_cmd.py +4 -2
  10. ripperdoc/cli/commands/status_cmd.py +1 -1
  11. ripperdoc/cli/ui/rich_ui.py +135 -2
  12. ripperdoc/cli/ui/thinking_spinner.py +128 -0
  13. ripperdoc/core/agents.py +174 -6
  14. ripperdoc/core/config.py +9 -1
  15. ripperdoc/core/default_tools.py +6 -0
  16. ripperdoc/core/providers/__init__.py +47 -0
  17. ripperdoc/core/providers/anthropic.py +147 -0
  18. ripperdoc/core/providers/base.py +236 -0
  19. ripperdoc/core/providers/gemini.py +496 -0
  20. ripperdoc/core/providers/openai.py +253 -0
  21. ripperdoc/core/query.py +337 -141
  22. ripperdoc/core/query_utils.py +65 -24
  23. ripperdoc/core/system_prompt.py +67 -61
  24. ripperdoc/core/tool.py +12 -3
  25. ripperdoc/sdk/client.py +12 -1
  26. ripperdoc/tools/ask_user_question_tool.py +433 -0
  27. ripperdoc/tools/background_shell.py +104 -18
  28. ripperdoc/tools/bash_tool.py +33 -13
  29. ripperdoc/tools/enter_plan_mode_tool.py +223 -0
  30. ripperdoc/tools/exit_plan_mode_tool.py +150 -0
  31. ripperdoc/tools/file_edit_tool.py +13 -0
  32. ripperdoc/tools/file_read_tool.py +16 -0
  33. ripperdoc/tools/file_write_tool.py +13 -0
  34. ripperdoc/tools/glob_tool.py +5 -1
  35. ripperdoc/tools/ls_tool.py +14 -10
  36. ripperdoc/tools/mcp_tools.py +113 -4
  37. ripperdoc/tools/multi_edit_tool.py +12 -0
  38. ripperdoc/tools/notebook_edit_tool.py +12 -0
  39. ripperdoc/tools/task_tool.py +88 -5
  40. ripperdoc/tools/todo_tool.py +1 -3
  41. ripperdoc/tools/tool_search_tool.py +8 -4
  42. ripperdoc/utils/file_watch.py +134 -0
  43. ripperdoc/utils/git_utils.py +36 -38
  44. ripperdoc/utils/json_utils.py +1 -2
  45. ripperdoc/utils/log.py +3 -4
  46. ripperdoc/utils/mcp.py +49 -10
  47. ripperdoc/utils/memory.py +1 -3
  48. ripperdoc/utils/message_compaction.py +5 -11
  49. ripperdoc/utils/messages.py +9 -13
  50. ripperdoc/utils/output_utils.py +1 -3
  51. ripperdoc/utils/prompt.py +17 -0
  52. ripperdoc/utils/session_usage.py +7 -0
  53. ripperdoc/utils/shell_utils.py +159 -0
  54. ripperdoc/utils/token_estimation.py +33 -0
  55. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/METADATA +3 -1
  56. ripperdoc-0.2.4.dist-info/RECORD +99 -0
  57. ripperdoc-0.2.2.dist-info/RECORD +0 -86
  58. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/WHEEL +0 -0
  59. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/entry_points.txt +0 -0
  60. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/licenses/LICENSE +0 -0
  61. {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.4.dist-info}/top_level.txt +0 -0
ripperdoc/core/query.py CHANGED
@@ -6,48 +6,60 @@ the query-response loop including tool execution.
 
 import asyncio
 import inspect
+import os
 import time
-from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union, cast
+from asyncio import CancelledError
+from typing import (
+    Any,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
-from anthropic import AsyncAnthropic
-from openai import AsyncOpenAI
 from pydantic import ValidationError
 
-from ripperdoc.core.config import ProviderType, provider_protocol
+from ripperdoc.core.config import provider_protocol
+from ripperdoc.core.providers import ProviderClient, get_provider_client
 from ripperdoc.core.permissions import PermissionResult
 from ripperdoc.core.query_utils import (
-    anthropic_usage_tokens,
-    build_anthropic_tool_schemas,
     build_full_system_prompt,
-    build_openai_tool_schemas,
-    content_blocks_from_anthropic_response,
-    content_blocks_from_openai_choice,
     determine_tool_mode,
     extract_tool_use_blocks,
     format_pydantic_errors,
     log_openai_messages,
-    openai_usage_tokens,
     resolve_model_profile,
     text_mode_history,
     tool_result_message,
 )
 from ripperdoc.core.tool import Tool, ToolProgress, ToolResult, ToolUseContext
+from ripperdoc.utils.file_watch import ChangedFileNotice, FileSnapshot, detect_changed_files
 from ripperdoc.utils.log import get_logger
 from ripperdoc.utils.messages import (
     AssistantMessage,
+    MessageContent,
     ProgressMessage,
     UserMessage,
     create_assistant_message,
+    create_user_message,
     create_progress_message,
     normalize_messages_for_api,
     INTERRUPT_MESSAGE,
     INTERRUPT_MESSAGE_FOR_TOOL_USE,
 )
-from ripperdoc.utils.session_usage import record_usage
 
 
 logger = get_logger()
 
+DEFAULT_REQUEST_TIMEOUT_SEC = float(os.getenv("RIPPERDOC_API_TIMEOUT", "120"))
+MAX_LLM_RETRIES = int(os.getenv("RIPPERDOC_MAX_RETRIES", "10"))
+
 
 def _resolve_tool(
     tool_registry: "ToolRegistry", tool_name: str, tool_use_id: str
@@ -62,11 +74,23 @@ def _resolve_tool(
     )
 
 
+ToolPermissionCallable = Callable[
+    [Tool[Any, Any], Any],
+    Union[
+        PermissionResult,
+        Dict[str, Any],
+        Tuple[bool, Optional[str]],
+        bool,
+        Awaitable[Union[PermissionResult, Dict[str, Any], Tuple[bool, Optional[str]], bool]],
+    ],
+]
+
+
 async def _check_tool_permissions(
     tool: Tool[Any, Any],
     parsed_input: Any,
     query_context: "QueryContext",
-    can_use_tool_fn: Optional[Any],
+    can_use_tool_fn: Optional[ToolPermissionCallable],
 ) -> tuple[bool, Optional[str]]:
     """Evaluate whether a tool call is allowed."""
     try:
@@ -102,6 +126,155 @@ async def _check_tool_permissions(
         return False, None
 
 
+def _format_changed_file_notice(notices: List[ChangedFileNotice]) -> str:
+    """Render a system notice about files that changed on disk."""
+    lines: List[str] = [
+        "System notice: Files you previously read have changed on disk.",
+        "Please re-read the affected files before making further edits.",
+        "",
+    ]
+    for notice in notices:
+        lines.append(f"- {notice.file_path}")
+        summary = (notice.summary or "").rstrip()
+        if summary:
+            indented = "\n".join(f"  {line}" for line in summary.splitlines())
+            lines.append(indented)
+    return "\n".join(lines)
+
+
+async def _run_tool_use_generator(
+    tool: Tool[Any, Any],
+    tool_use_id: str,
+    tool_name: str,
+    parsed_input: Any,
+    sibling_ids: set[str],
+    tool_context: ToolUseContext,
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Execute a single tool_use and yield progress/results."""
+    try:
+        async for output in tool.call(parsed_input, tool_context):
+            if isinstance(output, ToolProgress):
+                yield create_progress_message(
+                    tool_use_id=tool_use_id,
+                    sibling_tool_use_ids=sibling_ids,
+                    content=output.content,
+                )
+                logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
+            elif isinstance(output, ToolResult):
+                result_content = output.result_for_assistant or str(output.data)
+                result_msg = tool_result_message(
+                    tool_use_id, result_content, tool_use_result=output.data
+                )
+                yield result_msg
+                logger.debug(
+                    f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
+                    f"result_len={len(result_content)}"
+                )
+    except Exception as exc:
+        logger.exception(
+            f"Error executing tool '{tool_name}'",
+            extra={"tool": tool_name, "tool_use_id": tool_use_id},
+        )
+        yield tool_result_message(tool_use_id, f"Error executing tool: {str(exc)}", is_error=True)
+
+
+def _group_tool_calls_by_concurrency(prepared_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Group consecutive tool calls by their concurrency safety."""
+    groups: List[Dict[str, Any]] = []
+    for call in prepared_calls:
+        is_safe = bool(call.get("is_concurrency_safe"))
+        if groups and groups[-1]["is_concurrency_safe"] == is_safe:
+            groups[-1]["items"].append(call)
+        else:
+            groups.append({"is_concurrency_safe": is_safe, "items": [call]})
+    return groups
+
+
+async def _execute_tools_sequentially(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators one by one."""
+    for item in items:
+        gen = item.get("generator")
+        if not gen:
+            continue
+        async for message in gen:
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+
+
+async def _execute_tools_in_parallel(
+    items: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tool generators concurrently."""
+    generators = [call["generator"] for call in items if call.get("generator")]
+    async for message in _run_concurrent_tool_uses(generators, tool_results):
+        yield message
+
+
+async def _run_tools_concurrently(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run tools grouped by concurrency safety (parallel for safe groups)."""
+    for group in _group_tool_calls_by_concurrency(prepared_calls):
+        if group["is_concurrency_safe"]:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} concurrency-safe tool(s) in parallel"
+            )
+            async for message in _execute_tools_in_parallel(group["items"], tool_results):
+                yield message
+        else:
+            logger.debug(
+                f"[query] Executing {len(group['items'])} tool(s) sequentially (not concurrency safe)"
+            )
+            async for message in _run_tools_serially(group["items"], tool_results):
+                yield message
+
+
+async def _run_tools_serially(
+    prepared_calls: List[Dict[str, Any]], tool_results: List[UserMessage]
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Run all tools sequentially (helper for clarity)."""
+    async for message in _execute_tools_sequentially(prepared_calls, tool_results):
+        yield message
+
+
+async def _run_concurrent_tool_uses(
+    generators: List[AsyncGenerator[Union[UserMessage, ProgressMessage], None]],
+    tool_results: List[UserMessage],
+) -> AsyncGenerator[Union[UserMessage, ProgressMessage], None]:
+    """Drain multiple tool generators concurrently and stream outputs."""
+    if not generators:
+        return
+
+    queue: asyncio.Queue[Optional[Union[UserMessage, ProgressMessage]]] = asyncio.Queue()
+
+    async def _consume(gen: AsyncGenerator[Union[UserMessage, ProgressMessage], None]) -> None:
+        try:
+            async for message in gen:
+                await queue.put(message)
+        except Exception:
+            logger.exception("[query] Unexpected error while consuming tool generator")
+        finally:
+            await queue.put(None)
+
+    tasks = [asyncio.create_task(_consume(gen)) for gen in generators]
+    active = len(tasks)
+
+    try:
+        while active:
+            message = await queue.get()
+            if message is None:
+                active -= 1
+                continue
+            if isinstance(message, UserMessage):
+                tool_results.append(message)
+            yield message
+    finally:
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+
 class ToolRegistry:
     """Track available tools, including deferred ones, and expose search/activation helpers."""
 
@@ -204,6 +377,8 @@ class QueryContext:
         safe_mode: bool = False,
         model: str = "main",
         verbose: bool = False,
+        pause_ui: Optional[Callable[[], None]] = None,
+        resume_ui: Optional[Callable[[], None]] = None,
     ) -> None:
         self.tool_registry = ToolRegistry(tools)
         self.max_thinking_tokens = max_thinking_tokens
@@ -211,6 +386,9 @@ class QueryContext:
         self.model = model
         self.verbose = verbose
         self.abort_controller = asyncio.Event()
+        self.file_state_cache: Dict[str, FileSnapshot] = {}
+        self.pause_ui = pause_ui
+        self.resume_ui = resume_ui
 
     @property
     def tools(self) -> List[Tool[Any, Any]]:
@@ -238,6 +416,11 @@ async def query_llm(
     max_thinking_tokens: int = 0,
     model: str = "main",
     abort_signal: Optional[asyncio.Event] = None,
+    *,
+    progress_callback: Optional[Callable[[str], Awaitable[None]]] = None,
+    request_timeout: Optional[float] = None,
+    max_retries: int = MAX_LLM_RETRIES,
+    stream: bool = True,
 ) -> AssistantMessage:
     """Query the AI model and return the response.
 
@@ -248,10 +431,16 @@ async def query_llm(
         max_thinking_tokens: Maximum tokens for thinking (0 = disabled)
         model: Model pointer to use
         abort_signal: Event to signal abortion
+        progress_callback: Optional async callback invoked with streamed text chunks
+        request_timeout: Max seconds to wait for a provider response before retrying
+        max_retries: Number of retries on timeout/errors (total attempts = retries + 1)
+        stream: Enable streaming for providers that support it (text-only mode)
 
     Returns:
         AssistantMessage with the model's response
     """
+    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
+    request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT_SEC
     model_profile = resolve_model_profile(model)
 
     # Normalize messages based on protocol family (Anthropic allows tool blocks; OpenAI-style prefers text-only)
@@ -266,7 +455,7 @@ async def query_llm(
     else:
         messages_for_model = messages
 
-    normalized_messages = normalize_messages_for_api(
+    normalized_messages: List[Dict[str, Any]] = normalize_messages_for_api(
         messages_for_model, protocol=protocol, tool_mode=tool_mode
     )
     logger.info(
@@ -295,95 +484,36 @@ async def query_llm(
     start_time = time.time()
 
     try:
-        # Create the appropriate client based on provider
-        if model_profile.provider == ProviderType.ANTHROPIC:
-            async with AsyncAnthropic(api_key=model_profile.api_key) as client:
-                tool_schemas = await build_anthropic_tool_schemas(tools)
-                response = await client.messages.create(
-                    model=model_profile.model,
-                    max_tokens=model_profile.max_tokens,
-                    system=system_prompt,
-                    messages=normalized_messages,  # type: ignore[arg-type]
-                    tools=tool_schemas if tool_schemas else None,  # type: ignore
-                    temperature=model_profile.temperature,
-                )
-
-                duration_ms = (time.time() - start_time) * 1000
-
-                usage_tokens = anthropic_usage_tokens(getattr(response, "usage", None))
-                record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-
-                # Calculate cost (simplified, should use actual pricing)
-                cost_usd = 0.0  # TODO: Implement cost calculation
-
-                content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
-                tool_use_blocks = [
-                    block for block in response.content if getattr(block, "type", None) == "tool_use"
-                ]
-                logger.info(
-                    "[query_llm] Received response from Anthropic",
-                    extra={
-                        "model": model_profile.model,
-                        "duration_ms": round(duration_ms, 2),
-                        "usage_tokens": usage_tokens,
-                        "tool_use_blocks": len(tool_use_blocks),
-                    },
-                )
-
-                return create_assistant_message(
-                    content=content_blocks,
-                    cost_usd=cost_usd,
-                    duration_ms=duration_ms,
-                )
-
-        elif model_profile.provider == ProviderType.OPENAI_COMPATIBLE:
-            # OpenAI-compatible APIs (OpenAI, DeepSeek, Mistral, etc.)
-            async with AsyncOpenAI(api_key=model_profile.api_key, base_url=model_profile.api_base) as client:
-                openai_tools = await build_openai_tool_schemas(tools)
-
-                # Prepare messages for OpenAI format
-                openai_messages = [
-                    {"role": "system", "content": system_prompt}
-                ] + normalized_messages
-
-                # Make the API call
-                openai_response: Any = await client.chat.completions.create(
-                    model=model_profile.model,
-                    messages=openai_messages,
-                    tools=openai_tools if openai_tools else None,  # type: ignore[arg-type]
-                    temperature=model_profile.temperature,
-                    max_tokens=model_profile.max_tokens,
-                )
-
-                duration_ms = (time.time() - start_time) * 1000
-                usage_tokens = openai_usage_tokens(getattr(openai_response, "usage", None))
-                record_usage(model_profile.model, duration_ms=duration_ms, **usage_tokens)
-                cost_usd = 0.0  # TODO: Implement cost calculation
-
-                # Convert OpenAI response to our format
-                content_blocks = []
-                choice = openai_response.choices[0]
-
-                logger.info(
-                    "[query_llm] Received response from OpenAI-compatible provider",
-                    extra={
-                        "model": model_profile.model,
-                        "duration_ms": round(duration_ms, 2),
-                        "usage_tokens": usage_tokens,
-                        "finish_reason": getattr(choice, "finish_reason", None),
-                    },
-                )
-
-                content_blocks = content_blocks_from_openai_choice(choice, tool_mode)
-
-                return create_assistant_message(
-                    content=content_blocks, cost_usd=cost_usd, duration_ms=duration_ms
-                )
+        client: Optional[ProviderClient] = get_provider_client(model_profile.provider)
+        if client is None:
+            duration_ms = (time.time() - start_time) * 1000
+            error_msg = create_assistant_message(
+                content=(
+                    "Gemini protocol is not supported yet in Ripperdoc. "
+                    "Please configure an Anthropic or OpenAI-compatible model."
+                ),
+                duration_ms=duration_ms,
+            )
+            error_msg.is_api_error_message = True
+            return error_msg
+
+        provider_response = await client.call(
+            model_profile=model_profile,
+            system_prompt=system_prompt,
+            normalized_messages=normalized_messages,
+            tools=tools,
+            tool_mode=tool_mode,
+            stream=stream,
+            progress_callback=progress_callback,
+            request_timeout=request_timeout,
+            max_retries=max_retries,
+        )
 
-        elif model_profile.provider == ProviderType.GEMINI:
-            raise NotImplementedError("Gemini protocol is not yet supported.")
-        else:
-            raise NotImplementedError(f"Provider {model_profile.provider} not yet implemented")
+        return create_assistant_message(
+            content=provider_response.content_blocks,
+            cost_usd=provider_response.cost_usd,
+            duration_ms=provider_response.duration_ms,
+        )
 
     except Exception as e:
         # Return error message
@@ -392,9 +522,9 @@ async def query_llm(
             extra={
                 "model": getattr(model_profile, "model", None),
                 "model_pointer": model,
-                "provider": getattr(model_profile.provider, "value", None)
-                if model_profile
-                else None,
+                "provider": (
+                    getattr(model_profile.provider, "value", None) if model_profile else None
+                ),
             },
         )
         duration_ms = (time.time() - start_time) * 1000
@@ -410,7 +540,7 @@ async def query(
     system_prompt: str,
     context: Dict[str, str],
     query_context: QueryContext,
-    can_use_tool_fn: Optional[Any] = None,
+    can_use_tool_fn: Optional[ToolPermissionCallable] = None,
 ) -> AsyncGenerator[Union[UserMessage, AssistantMessage, ProgressMessage], None]:
     """Execute a query with tool support.
 
@@ -442,6 +572,9 @@ async def query(
     # Work on a copy so external mutations (e.g., UI appending messages while consuming)
    # do not interfere with recursion or normalization.
     messages = list(messages)
+    change_notices = detect_changed_files(query_context.file_state_cache)
+    if change_notices:
+        messages.append(create_user_message(_format_changed_file_notice(change_notices)))
     model_profile = resolve_model_profile(query_context.model)
     tool_mode = determine_tool_mode(model_profile)
     tools_for_model: List[Tool[Any, Any]] = [] if tool_mode == "text" else query_context.all_tools()
@@ -458,15 +591,74 @@ async def query(
         },
     )
 
-    assistant_message = await query_llm(
-        messages,
-        full_system_prompt,
-        tools_for_model,
-        query_context.max_thinking_tokens,
-        query_context.model,
-        query_context.abort_controller,
+    progress_queue: asyncio.Queue[Optional[ProgressMessage]] = asyncio.Queue()
+
+    async def _stream_progress(chunk: str) -> None:
+        if not chunk:
+            return
+        try:
+            await progress_queue.put(
+                create_progress_message(
+                    tool_use_id="stream",
+                    sibling_tool_use_ids=set(),
+                    content=chunk,
+                )
+            )
+        except Exception:
+            logger.exception("[query] Failed to enqueue stream progress chunk")
+
+    assistant_task = asyncio.create_task(
+        query_llm(
+            messages,
+            full_system_prompt,
+            tools_for_model,
+            query_context.max_thinking_tokens,
+            query_context.model,
+            query_context.abort_controller,
+            progress_callback=_stream_progress,
+            request_timeout=DEFAULT_REQUEST_TIMEOUT_SEC,
+            max_retries=MAX_LLM_RETRIES,
+            stream=True,
+        )
     )
 
+    assistant_message: Optional[AssistantMessage] = None
+
+    while True:
+        if query_context.abort_controller.is_set():
+            assistant_task.cancel()
+            try:
+                await assistant_task
+            except CancelledError:
+                pass
+            yield create_assistant_message(INTERRUPT_MESSAGE)
+            return
+        if assistant_task.done():
+            assistant_message = await assistant_task
+            break
+        try:
+            progress = progress_queue.get_nowait()
+        except asyncio.QueueEmpty:
+            waiter = asyncio.create_task(progress_queue.get())
+            done, pending = await asyncio.wait(
+                {assistant_task, waiter}, return_when=asyncio.FIRST_COMPLETED
+            )
+            if assistant_task in done:
+                for task in pending:
+                    task.cancel()
+                assistant_message = await assistant_task
                break
+            progress = waiter.result()
+        if progress:
+            yield progress
+
+    while not progress_queue.empty():
+        residual = progress_queue.get_nowait()
+        if residual:
+            yield residual
+
+    assert assistant_message is not None
+
     # Check for abort
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE)
@@ -474,7 +666,7 @@
 
     yield assistant_message
 
-    tool_use_blocks = extract_tool_use_blocks(assistant_message)
+    tool_use_blocks: List[MessageContent] = extract_tool_use_blocks(assistant_message)
     text_blocks = (
         len(assistant_message.message.content)
         if isinstance(assistant_message.message.content, list)
@@ -495,6 +687,7 @@
     sibling_ids = set(
        getattr(t, "tool_use_id", None) or getattr(t, "id", None) or "" for t in tool_use_blocks
     )
+    prepared_calls: List[Dict[str, Any]] = []
 
     for tool_use in tool_use_blocks:
         tool_name = tool_use.name
@@ -511,14 +704,6 @@
             continue
         assert tool is not None
 
-        tool_context = ToolUseContext(
-            safe_mode=query_context.safe_mode,
-            verbose=query_context.verbose,
-            permission_checker=can_use_tool_fn,
-            tool_registry=query_context.tool_registry,
-            abort_signal=query_context.abort_controller,
-        )
-
         try:
             parsed_input = tool.input_schema(**tool_input)
             logger.debug(
@@ -526,6 +711,17 @@
                 f"{str(parsed_input)[:500]}"
             )
 
+            tool_context = ToolUseContext(
+                safe_mode=query_context.safe_mode,
+                verbose=query_context.verbose,
+                permission_checker=can_use_tool_fn,
+                tool_registry=query_context.tool_registry,
+                file_state_cache=query_context.file_state_cache,
+                abort_signal=query_context.abort_controller,
+                pause_ui=query_context.pause_ui,
+                resume_ui=query_context.resume_ui,
+            )
+
             validation = await tool.validate_input(parsed_input, tool_context)
             if not validation.result:
                 logger.debug(
@@ -555,26 +751,19 @@
                 permission_denied = True
                 break
 
-            async for output in tool.call(parsed_input, tool_context):
-                if isinstance(output, ToolProgress):
-                    progress = create_progress_message(
-                        tool_use_id=tool_use_id,
-                        sibling_tool_use_ids=sibling_ids,
-                        content=output.content,
-                    )
-                    yield progress
-                    logger.debug(f"[query] Progress from tool_use_id={tool_use_id}: {output.content}")
-                elif isinstance(output, ToolResult):
-                    result_content = output.result_for_assistant or str(output.data)
-                    result_msg = tool_result_message(
-                        tool_use_id, result_content, tool_use_result=output.data
-                    )
-                    tool_results.append(result_msg)
-                    yield result_msg
-                    logger.debug(
-                        f"[query] Tool completed tool_use_id={tool_use_id} name={tool_name} "
-                        f"result_len={len(result_content)}"
-                    )
+            prepared_calls.append(
+                {
+                    "is_concurrency_safe": tool.is_concurrency_safe(),
+                    "generator": _run_tool_use_generator(
+                        tool,
+                        tool_use_id,
+                        tool_name,
+                        parsed_input,
+                        sibling_ids,
+                        tool_context,
+                    ),
+                }
+            )
 
         except ValidationError as ve:
             detail_text = format_pydantic_errors(ve)
@@ -600,6 +789,13 @@
         if permission_denied:
             break
 
+    if permission_denied:
+        return
+
+    if prepared_calls:
+        async for message in _run_tools_concurrently(prepared_calls, tool_results):
+            yield message
+
     # Check for abort after tools
     if query_context.abort_controller.is_set():
         yield create_assistant_message(INTERRUPT_MESSAGE_FOR_TOOL_USE)
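
Notes on ripperdoc/core/query.py (0.2.2 → 0.2.4):

The two new module-level knobs are read from the environment once, at import time: DEFAULT_REQUEST_TIMEOUT_SEC from RIPPERDOC_API_TIMEOUT and MAX_LLM_RETRIES from RIPPERDOC_MAX_RETRIES. Overrides therefore have to be in place before ripperdoc.core.query is first imported. A minimal sketch (the override values are illustrative):

    # Set the env vars before the first import; the module reads them once.
    import os

    os.environ["RIPPERDOC_API_TIMEOUT"] = "300"  # per-request timeout, seconds
    os.environ["RIPPERDOC_MAX_RETRIES"] = "3"    # retries on timeout/error

    from ripperdoc.core.query import DEFAULT_REQUEST_TIMEOUT_SEC, MAX_LLM_RETRIES

    print(DEFAULT_REQUEST_TIMEOUT_SEC)  # 300.0
    print(MAX_LLM_RETRIES)              # 3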
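
The rewritten tool-execution path batches consecutive tool_use blocks by their is_concurrency_safe flag instead of sorting them, so the relative order of safe and unsafe calls is preserved and only adjacent safe calls run in parallel. A self-contained sketch of the same grouping logic (toy call dicts; the real entries also carry a "generator" key):

    from typing import Any, Dict, List

    def group_by_concurrency(calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Merge runs of calls that share the same safety flag, preserving order.
        groups: List[Dict[str, Any]] = []
        for call in calls:
            is_safe = bool(call.get("is_concurrency_safe"))
            if groups and groups[-1]["is_concurrency_safe"] == is_safe:
                groups[-1]["items"].append(call)
            else:
                groups.append({"is_concurrency_safe": is_safe, "items": [call]})
        return groups

    calls = [
        {"name": "read_a", "is_concurrency_safe": True},
        {"name": "read_b", "is_concurrency_safe": True},
        {"name": "edit_a", "is_concurrency_safe": False},
        {"name": "read_c", "is_concurrency_safe": True},
    ]
    for group in group_by_concurrency(calls):
        print(group["is_concurrency_safe"], [c["name"] for c in group["items"]])
    # True ['read_a', 'read_b']   -> run in parallel
    # False ['edit_a']            -> run alone
    # True ['read_c']             -> run in parallel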
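
_run_concurrent_tool_uses fans several async generators into a single stream through one shared asyncio.Queue; each producer enqueues a None sentinel when it finishes, and the consumer exits once it has seen one sentinel per producer. The same pattern reduced to a runnable toy (names here are illustrative, not Ripperdoc APIs):

    import asyncio
    from typing import AsyncGenerator, Optional

    async def merge(*gens: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
        # Fan-in: one shared queue, one None sentinel per finished producer.
        queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

        async def consume(gen: AsyncGenerator[str, None]) -> None:
            try:
                async for item in gen:
                    await queue.put(item)
            finally:
                await queue.put(None)

        tasks = [asyncio.create_task(consume(g)) for g in gens]
        active = len(tasks)
        try:
            while active:
                item = await queue.get()
                if item is None:
                    active -= 1
                    continue
                yield item
        finally:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def ticker(name: str, n: int) -> AsyncGenerator[str, None]:
        for i in range(n):
            await asyncio.sleep(0.01)
            yield f"{name}:{i}"

    async def main() -> None:
        async for item in merge(ticker("a", 3), ticker("b", 3)):
            print(item)  # outputs from a and b interleave as they arrive

    asyncio.run(main())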
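
The streaming pump in query() races the query_llm task against the progress queue with asyncio.wait(return_when=FIRST_COMPLETED), yielding chunks as they arrive and draining whatever is still queued once the task completes. A reduced sketch of that race (toy worker; illustrative names, not Ripperdoc APIs):

    import asyncio
    from typing import Optional

    async def pump(task: "asyncio.Task[str]", queue: "asyncio.Queue[Optional[str]]") -> str:
        # Print progress chunks until the main task finishes, then return its result.
        while True:
            if task.done():
                return await task
            waiter = asyncio.create_task(queue.get())
            done, pending = await asyncio.wait(
                {task, waiter}, return_when=asyncio.FIRST_COMPLETED
            )
            if task in done:
                for t in pending:
                    t.cancel()
                return await task
            chunk = waiter.result()
            if chunk:
                print("progress:", chunk)

    async def main() -> None:
        queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

        async def work() -> str:
            for i in range(3):
                await queue.put(f"chunk {i}")
                await asyncio.sleep(0.01)
            return "final answer"

        result = await pump(asyncio.create_task(work()), queue)
        while not queue.empty():  # drain residual chunks, as the real loop does
            print("progress:", queue.get_nowait())
        print("result:", result)

    asyncio.run(main())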