zrb 1.21.17__py3-none-any.whl → 1.21.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. zrb/attr/type.py +10 -7
  2. zrb/builtin/git.py +12 -1
  3. zrb/builtin/llm/chat_completion.py +287 -0
  4. zrb/builtin/llm/chat_session_cmd.py +90 -28
  5. zrb/builtin/llm/chat_trigger.py +6 -1
  6. zrb/builtin/llm/tool/cli.py +29 -13
  7. zrb/builtin/llm/tool/code.py +9 -1
  8. zrb/builtin/llm/tool/file.py +32 -6
  9. zrb/builtin/llm/tool/note.py +9 -9
  10. zrb/builtin/llm/tool/search/__init__.py +1 -0
  11. zrb/builtin/llm/tool/search/brave.py +66 -0
  12. zrb/builtin/llm/tool/search/searxng.py +61 -0
  13. zrb/builtin/llm/tool/search/serpapi.py +61 -0
  14. zrb/builtin/llm/tool/sub_agent.py +30 -10
  15. zrb/builtin/llm/tool/web.py +17 -72
  16. zrb/config/config.py +67 -26
  17. zrb/config/default_prompt/interactive_system_prompt.md +16 -13
  18. zrb/config/default_prompt/summarization_prompt.md +54 -8
  19. zrb/config/default_prompt/system_prompt.md +16 -18
  20. zrb/config/llm_rate_limitter.py +15 -6
  21. zrb/input/option_input.py +13 -1
  22. zrb/task/llm/agent.py +42 -143
  23. zrb/task/llm/agent_runner.py +152 -0
  24. zrb/task/llm/conversation_history.py +35 -24
  25. zrb/task/llm/conversation_history_model.py +4 -11
  26. zrb/task/llm/history_processor.py +206 -0
  27. zrb/task/llm/history_summarization.py +2 -179
  28. zrb/task/llm/print_node.py +14 -5
  29. zrb/task/llm/prompt.py +2 -17
  30. zrb/task/llm/subagent_conversation_history.py +41 -0
  31. zrb/task/llm/tool_confirmation_completer.py +41 -0
  32. zrb/task/llm/tool_wrapper.py +15 -11
  33. zrb/task/llm_task.py +41 -40
  34. zrb/util/attr.py +12 -7
  35. zrb/util/git.py +2 -2
  36. zrb/xcom/xcom.py +10 -0
  37. {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/METADATA +3 -3
  38. {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/RECORD +40 -32
  39. zrb/task/llm/history_summarization_tool.py +0 -24
  40. {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/WHEEL +0 -0
  41. {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/entry_points.txt +0 -0
zrb/task/llm/history_processor.py ADDED
@@ -0,0 +1,206 @@
+ import json
+ import sys
+ import traceback
+ from typing import TYPE_CHECKING, Any, Callable, Coroutine
+
+ from zrb.config.llm_config import llm_config
+ from zrb.config.llm_rate_limitter import LLMRateLimitter
+ from zrb.config.llm_rate_limitter import llm_rate_limitter as default_llm_rate_limitter
+ from zrb.context.any_context import AnyContext
+ from zrb.task.llm.agent_runner import run_agent_iteration
+ from zrb.util.cli.style import stylize_faint
+ from zrb.util.markdown import make_markdown_section
+
+ if sys.version_info >= (3, 12):
+     from typing import TypedDict
+ else:
+     from typing_extensions import TypedDict
+
+
+ if TYPE_CHECKING:
+     from pydantic_ai import ModelMessage
+     from pydantic_ai.models import Model
+     from pydantic_ai.settings import ModelSettings
+
+
+ class SingleMessage(TypedDict):
+     """
+     SingleConversation
+
+     Attributes:
+         role: Either AI, User, Tool Call, or Tool Result
+         time: Timestamp in yyyy-mm-ddTHH:MM:SSZ format
+         content: The content of the message (summarized if too long)
+     """
+
+     role: str
+     time: str
+     content: str
+
+
+ class ConversationSummary(TypedDict):
+     """
+     Conversation history
+
+     Attributes:
+         transcript: The last several messages of the conversation
+         summary: Descriptive conversation summary
+     """
+
+     transcript: list[SingleMessage]
+     summary: str
+
+
+ def save_conversation_summary(conversation_summary: ConversationSummary):
+     """
+     Write the conversation summary so the main assistant can continue the conversation.
+     """
+     return conversation_summary
+
+
+ def create_summarize_history_processor(
+     ctx: AnyContext,
+     system_prompt: str,
+     rate_limitter: LLMRateLimitter | None = None,
+     summarization_model: "Model | str | None" = None,
+     summarization_model_settings: "ModelSettings | None" = None,
+     summarization_system_prompt: str | None = None,
+     summarization_token_threshold: int | None = None,
+     summarization_retries: int = 2,
+ ) -> Callable[[list["ModelMessage"]], Coroutine[None, None, list["ModelMessage"]]]:
+     from pydantic_ai import Agent, ModelMessage, ModelRequest
+     from pydantic_ai.messages import ModelMessagesTypeAdapter, UserPromptPart
+
+     if rate_limitter is None:
+         rate_limitter = default_llm_rate_limitter
+     if summarization_model is None:
+         summarization_model = llm_config.default_small_model
+     if summarization_model_settings is None:
+         summarization_model_settings = llm_config.default_small_model_settings
+     if summarization_system_prompt is None:
+         summarization_system_prompt = llm_config.default_summarization_prompt
+     if summarization_token_threshold is None:
+         summarization_token_threshold = (
+             llm_config.default_history_summarization_token_threshold
+         )
+
+     async def maybe_summarize_history(
+         messages: list[ModelMessage],
+     ) -> list[ModelMessage]:
+         history_list = json.loads(ModelMessagesTypeAdapter.dump_json(messages))
+         history_json_str = json.dumps(history_list)
+         # Estimate token usage.
+         # Note: Pydantic AI has a run context parameter
+         # (https://ai.pydantic.dev/message-history/#runcontext-parameter),
+         # but we cannot use run_ctx.usage.total_tokens because the total token count
+         # keeps increasing even after summarization.
+         estimated_token_usage = rate_limitter.count_token(history_json_str)
+         _print_request_info(
+             ctx, estimated_token_usage, summarization_token_threshold, messages
+         )
+         if estimated_token_usage < summarization_token_threshold or len(messages) == 1:
+             return messages
+         history_list_without_instruction = [
+             {
+                 key: obj[key]
+                 for key in obj
+                 if index == len(history_list) - 1 or key != "instructions"
+             }
+             for index, obj in enumerate(history_list)
+         ]
+         history_json_str_without_instruction = json.dumps(
+             history_list_without_instruction
+         )
+         summarization_message = f"Summarize the following conversation: {history_json_str_without_instruction}"
+         summarization_agent = Agent[None, ConversationSummary](
+             model=summarization_model,
+             output_type=save_conversation_summary,
+             instructions=summarization_system_prompt,
+             model_settings=summarization_model_settings,
+             retries=summarization_retries,
+         )
+         try:
+             _print_info(ctx, "📝 Rollup Conversation", 2)
+             summary_run = await run_agent_iteration(
+                 ctx=ctx,
+                 agent=summarization_agent,
+                 user_prompt=summarization_message,
+                 attachments=[],
+                 history_list=[],
+                 rate_limitter=rate_limitter,
+                 log_indent_level=2,
+             )
+             if summary_run and summary_run.result and summary_run.result.output:
+                 usage = summary_run.result.usage()
+                 _print_info(ctx, f"📝 Rollup Conversation Token: {usage}", 2)
+                 ctx.print(plain=True)
+                 ctx.log_info("History summarized and updated.")
+                 condensed_message = make_markdown_section(
+                     header="Past Conversation",
+                     content="\n".join(
+                         [
+                             make_markdown_section(
+                                 "Summary", _extract_summary(summary_run.result.output)
+                             ),
+                             make_markdown_section(
+                                 "Past Transcript",
+                                 _extract_transcript(summary_run.result.output),
+                             ),
+                         ]
+                     ),
+                 )
+                 return [
+                     ModelRequest(
+                         instructions=system_prompt,
+                         parts=[UserPromptPart(condensed_message)],
+                     )
+                 ]
+             ctx.log_warning("History summarization failed or returned no data.")
+         except BaseException as e:
+             ctx.log_warning(f"Error during history summarization: {e}")
+             traceback.print_exc()
+         return messages
+
+     return maybe_summarize_history
+
+
+ def _print_request_info(
+     ctx: AnyContext,
+     estimated_token_usage: int,
+     summarization_token_threshold: int,
+     messages: list["ModelMessage"],
+ ):
+     _print_info(ctx, f"Current request token (estimated): {estimated_token_usage}")
+     _print_info(ctx, f"Summarization token threshold: {summarization_token_threshold}")
+     _print_info(ctx, f"History length: {len(messages)}")
+
+
+ def _print_info(ctx: AnyContext, text: str, log_indent_level: int = 0):
+     log_prefix = (2 * (log_indent_level + 1)) * " "
+     ctx.print(stylize_faint(f"{log_prefix}{text}"), plain=True)
+
+
+ def _extract_summary(summary_result_output: dict[str, Any] | str) -> str:
+     summary = (
+         summary_result_output.get("summary", "")
+         if isinstance(summary_result_output, dict)
+         else ""
+     )
+     return summary
+
+
+ def _extract_transcript(summary_result_output: dict[str, Any] | str) -> str:
+     transcript_list = (
+         summary_result_output.get("transcript", [])
+         if isinstance(summary_result_output, dict)
+         else []
+     )
+     transcript_list = [] if not isinstance(transcript_list, list) else transcript_list
+     return "\n".join(_format_transcript_message(message) for message in transcript_list)
+
+
+ def _format_transcript_message(message: dict[str, str]) -> str:
+     role = message.get("role", "Message")
+     time = message.get("time", "<unknown>")
+     content = message.get("content", "<empty>")
+     return f"{role} ({time}): {content}"
zrb/task/llm/history_summarization.py CHANGED
@@ -1,36 +1,7 @@
- import json
- import traceback
- from typing import TYPE_CHECKING
-
- from zrb.attr.type import BoolAttr, IntAttr
+ from zrb.attr.type import IntAttr
  from zrb.config.llm_config import llm_config
- from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
  from zrb.context.any_context import AnyContext
- from zrb.task.llm.agent import run_agent_iteration
- from zrb.task.llm.conversation_history import (
-     count_part_in_history_list,
-     inject_conversation_history_notes,
-     replace_system_prompt_in_history,
- )
- from zrb.task.llm.conversation_history_model import ConversationHistory
- from zrb.task.llm.history_summarization_tool import (
-     create_history_summarization_tool,
- )
- from zrb.task.llm.typing import ListOfDict
- from zrb.util.attr import get_bool_attr, get_int_attr
- from zrb.util.cli.style import stylize_faint
- from zrb.util.markdown import make_markdown_section
- from zrb.util.truncate import truncate_str
-
- if TYPE_CHECKING:
-     from pydantic_ai.models import Model
-     from pydantic_ai.settings import ModelSettings
-
-
- def _count_token_in_history(history_list: ListOfDict) -> int:
-     """Counts the total number of tokens in a conversation history list."""
-     text_to_count = json.dumps(history_list)
-     return llm_rate_limitter.count_token(text_to_count)
+ from zrb.util.attr import get_int_attr


  def get_history_summarization_token_threshold(
@@ -52,151 +23,3 @@ def get_history_summarization_token_threshold(
          "Defaulting to -1 (no threshold)."
      )
      return -1
-
-
- def should_summarize_history(
-     ctx: AnyContext,
-     history_list: ListOfDict,
-     should_summarize_history_attr: BoolAttr | None,
-     render_summarize_history: bool,
-     history_summarization_token_threshold_attr: IntAttr | None,
-     render_history_summarization_token_threshold: bool,
- ) -> bool:
-     """Determines if history summarization should occur based on token length and config."""
-     history_part_count = count_part_in_history_list(history_list)
-     if history_part_count == 0:
-         return False
-     summarization_token_threshold = get_history_summarization_token_threshold(
-         ctx,
-         history_summarization_token_threshold_attr,
-         render_history_summarization_token_threshold,
-     )
-     history_token_count = _count_token_in_history(history_list)
-     if (
-         summarization_token_threshold == -1
-         or summarization_token_threshold > history_token_count
-     ):
-         return False
-     return get_bool_attr(
-         ctx,
-         should_summarize_history_attr,
-         llm_config.default_summarize_history,
-         auto_render=render_summarize_history,
-     )
-
-
- async def summarize_history(
-     ctx: AnyContext,
-     model: "Model | str | None",
-     settings: "ModelSettings | None",
-     system_prompt: str,
-     conversation_history: ConversationHistory,
-     rate_limitter: LLMRateLimiter | None = None,
-     retries: int = 3,
- ) -> ConversationHistory:
-     """Runs an LLM call to update the conversation summary."""
-     from pydantic_ai import Agent
-
-     inject_conversation_history_notes(conversation_history)
-     ctx.log_info("Attempting to summarize conversation history...")
-     # Construct the user prompt for the summarization agent
-     user_prompt = "\n".join(
-         [
-             make_markdown_section(
-                 "Past Conversation",
-                 "\n".join(
-                     [
-                         make_markdown_section(
-                             "Summary",
-                             conversation_history.past_conversation_summary,
-                             as_code=True,
-                         ),
-                         make_markdown_section(
-                             "Last Transcript",
-                             conversation_history.past_conversation_transcript,
-                             as_code=True,
-                         ),
-                     ]
-                 ),
-             ),
-             make_markdown_section(
-                 "Recent Conversation (JSON)",
-                 json.dumps(truncate_str(conversation_history.history, 1000)),
-                 as_code=True,
-             ),
-         ]
-     )
-     summarize = create_history_summarization_tool(conversation_history)
-     summarization_agent = Agent[None, str](
-         model=model,
-         output_type=summarize,
-         system_prompt=system_prompt,
-         model_settings=settings,
-         retries=retries,
-     )
-     try:
-         ctx.print(stylize_faint(" 📝 Rollup Conversation"), plain=True)
-         summary_run = await run_agent_iteration(
-             ctx=ctx,
-             agent=summarization_agent,
-             user_prompt=user_prompt,
-             attachments=[],
-             history_list=[],
-             rate_limitter=rate_limitter,
-             log_indent_level=2,
-         )
-         if summary_run and summary_run.result and summary_run.result.output:
-             usage = summary_run.result.usage()
-             ctx.print(
-                 stylize_faint(f" 📝 Rollup Conversation Token: {usage}"), plain=True
-             )
-             ctx.print(plain=True)
-             ctx.log_info("History summarized and updated.")
-         else:
-             ctx.log_warning("History summarization failed or returned no data.")
-     except BaseException as e:
-         ctx.log_warning(f"Error during history summarization: {e}")
-         traceback.print_exc()
-     # Return the original summary if summarization fails
-     return conversation_history
-
-
- async def maybe_summarize_history(
-     ctx: AnyContext,
-     conversation_history: ConversationHistory,
-     should_summarize_history_attr: BoolAttr | None,
-     render_summarize_history: bool,
-     history_summarization_token_threshold_attr: IntAttr | None,
-     render_history_summarization_token_threshold: bool,
-     model: "str | Model | None",
-     model_settings: "ModelSettings | None",
-     summarization_prompt: str,
-     rate_limitter: LLMRateLimiter | None = None,
- ) -> ConversationHistory:
-     """Summarizes history and updates context if enabled and threshold met."""
-     shorten_history = replace_system_prompt_in_history(conversation_history.history)
-     if should_summarize_history(
-         ctx,
-         shorten_history,
-         should_summarize_history_attr,
-         render_summarize_history,
-         history_summarization_token_threshold_attr,
-         render_history_summarization_token_threshold,
-     ):
-         original_history = conversation_history.history
-         conversation_history.history = shorten_history
-         conversation_history = await summarize_history(
-             ctx=ctx,
-             model=model,
-             settings=model_settings,
-             system_prompt=summarization_prompt,
-             conversation_history=conversation_history,
-             rate_limitter=rate_limitter,
-         )
-         conversation_history.history = original_history
-         if (
-             conversation_history.past_conversation_summary != ""
-             and conversation_history.past_conversation_transcript != ""
-         ):
-             conversation_history.history = []
-     return conversation_history
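get_history_summarization_token_threshold still returns the -1 sentinel for "no threshold". The new processor itself only compares the estimated token count against the threshold it was given, so presumably the caller skips registering the processor when the threshold is -1 (that wiring is not part of this diff). A minimal sketch of the comparison the processor does perform, with a crude 4-characters-per-token estimate standing in for the rate limiter's count_token (whose tokenizer is not shown here):

def exceeds_threshold(history_json: str, threshold: int) -> bool:
    # Crude stand-in for the rate limiter's tokenizer.
    estimated_tokens = len(history_json) // 4
    # Mirrors the processor: summarize only once the estimate reaches the threshold.
    return estimated_tokens >= threshold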
zrb/task/llm/print_node.py CHANGED
@@ -2,6 +2,7 @@ import json
  from collections.abc import Callable
  from typing import Any

+ from zrb.config.config import CFG
  from zrb.util.cli.style import stylize_faint


@@ -104,12 +105,20 @@ async def print_node(
                  and event.tool_call_id
              ):
                  call_id = event.tool_call_id
-                 result_content = event.result.content
-                 print_func(
-                     _format_content(
-                         f"{call_id} | {result_content}", log_indent_level
+                 if CFG.LLM_SHOW_TOOL_CALL_RESULT:
+                     result_content = event.result.content
+                     print_func(
+                         _format_content(
+                             f"{call_id} | Return {result_content}",
+                             log_indent_level,
+                         )
+                     )
+                 else:
+                     print_func(
+                         _format_content(
+                             f"{call_id} | Executed", log_indent_level
+                         )
                      )
-                 )
      except UnexpectedModelBehavior as e:
          print_func("")  # ensure newline consistency
          print_func(
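The tool-result line is now gated on a config flag. A standalone sketch of the two output shapes (the real code routes through _format_content and zrb's CFG object; this helper is illustrative only):

def render_tool_result(call_id: str, content: str, show_result: bool) -> str:
    # Full payload when LLM_SHOW_TOOL_CALL_RESULT is enabled, terse marker otherwise.
    if show_result:
        return f"{call_id} | Return {content}"
    return f"{call_id} | Executed"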
zrb/task/llm/prompt.py CHANGED
@@ -115,11 +115,11 @@ def _construct_system_prompt(
              ),
          ),
          make_markdown_section(
-             "🧠 Long Term Note",
+             "🧠 Long Term Note Content",
              conversation_history.long_term_note,
          ),
          make_markdown_section(
-             "📝 Contextual Note",
+             "📝 Contextual Note Content",
              conversation_history.contextual_note,
          ),
          make_markdown_section(
@@ -129,21 +129,6 @@ def _construct_system_prompt(
                  ]
              ),
          ),
-         make_markdown_section(
-             "💬 PAST CONVERSATION",
-             "\n".join(
-                 [
-                     make_markdown_section(
-                         "Narrative Summary",
-                         conversation_history.past_conversation_summary,
-                     ),
-                     make_markdown_section(
-                         "Past Transcript",
-                         conversation_history.past_conversation_transcript,
-                     ),
-                 ]
-             ),
-         ),
      ]
  )

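make_markdown_section appears throughout these hunks, but its implementation is not part of this diff. Inferred from the call sites (positional and keyword header/content arguments, plus an as_code flag in the removed code), it plausibly behaves like the following sketch; the heading level and fence style are assumptions, not the actual zrb.util.markdown source:

def make_markdown_section(header: str, content: str, as_code: bool = False) -> str:
    # Assumed behavior: a markdown heading followed by the section body,
    # optionally wrapped in a code fence.
    body = f"```\n{content}\n```" if as_code else content
    return f"# {header}\n{body}"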
zrb/task/llm/subagent_conversation_history.py ADDED
@@ -0,0 +1,41 @@
+ from zrb.context.any_context import AnyContext
+ from zrb.task.llm.conversation_history_model import ConversationHistory
+ from zrb.task.llm.typing import ListOfDict
+ from zrb.xcom.xcom import Xcom
+
+
+ def inject_subagent_conversation_history_into_ctx(
+     ctx: AnyContext, conversation_history: ConversationHistory
+ ):
+     subagent_messages_xcom = _get_global_subagent_history_xcom(ctx)
+     existing_subagent_history = subagent_messages_xcom.get({})
+     subagent_messages_xcom.set(
+         {**existing_subagent_history, **conversation_history.subagent_history}
+     )
+
+
+ def extract_subagent_conversation_history_from_ctx(
+     ctx: AnyContext,
+ ) -> dict[str, ListOfDict]:
+     subagent_messages_xcom = _get_global_subagent_history_xcom(ctx)
+     return subagent_messages_xcom.get({})
+
+
+ def get_ctx_subagent_history(ctx: AnyContext, subagent_name: str) -> ListOfDict:
+     subagent_history = extract_subagent_conversation_history_from_ctx(ctx)
+     return subagent_history.get(subagent_name, [])
+
+
+ def set_ctx_subagent_history(ctx: AnyContext, subagent_name: str, messages: ListOfDict):
+     subagent_history = extract_subagent_conversation_history_from_ctx(ctx)
+     subagent_history[subagent_name] = messages
+     subagent_messages_xcom = _get_global_subagent_history_xcom(ctx)
+     subagent_messages_xcom.set(subagent_history)
+
+
+ def _get_global_subagent_history_xcom(ctx: AnyContext) -> Xcom:
+     if "_global_subagents" not in ctx.xcom:
+         ctx.xcom["_global_subagents"] = Xcom([{}])
+     if not isinstance(ctx.xcom["_global_subagents"], Xcom):
+         raise ValueError("ctx.xcom._global_subagents must be an Xcom")
+     return ctx.xcom["_global_subagents"]
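These helpers give each named sub-agent its own message log inside a shared Xcom slot, so sub-agent context can survive across tool invocations within a run. A short usage sketch (the sub-agent name and message payload are hypothetical):

# Persist a sub-agent's message dump, then read it back later in the same run.
set_ctx_subagent_history(ctx, "researcher", [{"role": "user", "content": "hi"}])
messages = get_ctx_subagent_history(ctx, "researcher")  # the list stored above
empty = get_ctx_subagent_history(ctx, "unseen-agent")   # [] for unknown names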
zrb/task/llm/tool_confirmation_completer.py ADDED
@@ -0,0 +1,41 @@
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from prompt_toolkit.completion import Completer
+
+
+ def get_tool_confirmation_completer(
+     options: list[str], meta_dict: dict[str, str]
+ ) -> "Completer":
+     from prompt_toolkit.completion import Completer, Completion
+
+     class ToolConfirmationCompleter(Completer):
+         """Custom completer for tool confirmation that doesn't auto-complete partial words."""
+
+         def __init__(self, options, meta_dict):
+             self.options = options
+             self.meta_dict = meta_dict
+
+         def get_completions(self, document, complete_event):
+             text = document.text.strip()
+             # Offer completions only when:
+             # 1. Input is empty, OR
+             # 2. Input exactly matches the beginning of an option
+             if text == "":
+                 # Show all options when nothing is typed
+                 for option in self.options:
+                     yield Completion(
+                         option,
+                         start_position=0,
+                         display_meta=self.meta_dict.get(option, ""),
+                     )
+                 return
+             # Only complete if text exactly matches the beginning of an option
+             for option in self.options:
+                 if option.startswith(text):
+                     yield Completion(
+                         option,
+                         start_position=-len(text),
+                         display_meta=self.meta_dict.get(option, ""),
+                     )
+
+     return ToolConfirmationCompleter(options, meta_dict)
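Because matching is a strict startswith on the stripped input, typing "y" still suggests "yes", but a transposition like "ys" suggests nothing instead of fuzzily completing. A minimal interactive sketch (synchronous for brevity; prompt_toolkit's PromptSession.prompt accepts the same completer keyword as prompt_async):

from prompt_toolkit import PromptSession

completer = get_tool_confirmation_completer(
    ["yes", "no", "edit"],
    {"yes": "Approve", "no": "Disapprove", "edit": "Edit parameters"},
)
# Pressing Tab on an empty prompt lists all options with their meta text.
answer = PromptSession().prompt("Approve tool call? ", completer=completer)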
zrb/task/llm/tool_wrapper.py CHANGED
@@ -11,11 +11,11 @@ from zrb.config.llm_rate_limitter import llm_rate_limitter
  from zrb.context.any_context import AnyContext
  from zrb.task.llm.error import ToolExecutionError
  from zrb.task.llm.file_replacement import edit_replacement, is_single_path_replacement
+ from zrb.task.llm.tool_confirmation_completer import get_tool_confirmation_completer
  from zrb.util.callable import get_callable_name
  from zrb.util.cli.markdown import render_markdown
  from zrb.util.cli.style import (
      stylize_blue,
-     stylize_error,
      stylize_faint,
      stylize_green,
      stylize_yellow,
@@ -185,7 +185,7 @@ async def _handle_user_response(
          ]
      )
      ctx.print(complete_confirmation_message, plain=True)
-     user_response = await _read_line()
+     user_response = await _read_line(args, kwargs)
      ctx.print("", plain=True)
      new_kwargs, is_edited = _get_edited_kwargs(ctx, user_response, kwargs)
      if is_edited:
@@ -250,13 +250,7 @@ def _get_user_approval_and_reason(
      try:
          approved = True if approval_str.strip() == "" else to_boolean(approval_str)
          if not approved and reason == "":
-             ctx.print(
-                 stylize_error(
-                     f"You must specify rejection reason (i.e., No, <why>) for {func_call_str}"  # noqa
-                 ),
-                 plain=True,
-             )
-             return None
+             reason = "User disapproved the tool execution"
          return approved, reason
      except Exception:
          return False, user_response
@@ -300,11 +294,21 @@ def _truncate_arg(arg: str, length: int = 19) -> str:
      return normalized_arg


- async def _read_line():
+ async def _read_line(args: list[Any] | tuple[Any], kwargs: dict[str, Any]):
      from prompt_toolkit import PromptSession

+     options = ["yes", "no", "edit"]
+     meta_dict = {
+         "yes": "Approve the execution",
+         "no": "Disapprove the execution",
+         "edit": "Edit tool execution parameters",
+     }
+     for key in kwargs:
+         options.append(f"edit {key}")
+         meta_dict[f"edit {key}"] = f"Edit tool execution parameter: {key}"
+     completer = get_tool_confirmation_completer(options, meta_dict)
      reader = PromptSession()
-     return await reader.prompt_async()
+     return await reader.prompt_async(completer=completer)


  def _adjust_signature(wrapper: Callable, original_sig: inspect.Signature):
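Net effect on the confirmation flow: empty input approves, and a bare "no" now gets a default rejection reason instead of an error-and-retry loop. A behavior sketch of _get_user_approval_and_reason after this change (the splitting of raw input into approval_str and reason happens earlier and is not shown in this diff):

# ""              -> (True, "")    empty input approves by default
# "yes"           -> (True, "")
# "no"            -> (False, "User disapproved the tool execution")  default reason
# "no, too risky" -> (False, "too risky")  caller-supplied reason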