zrb 1.15.3__py3-none-any.whl → 2.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +118 -133
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +55 -1
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/chat.py +147 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/callback/callback.py +8 -1
- zrb/cmd/cmd_result.py +2 -1
- zrb/config/config.py +555 -169
- zrb/config/helper.py +84 -0
- zrb/config/web_auth_config.py +50 -35
- zrb/context/any_shared_context.py +20 -3
- zrb/context/context.py +39 -5
- zrb/context/print_fn.py +13 -0
- zrb/context/shared_context.py +17 -8
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +41 -1
- zrb/input/text_input.py +7 -24
- zrb/llm/agent/__init__.py +9 -0
- zrb/llm/agent/agent.py +215 -0
- zrb/llm/agent/summarizer.py +20 -0
- zrb/llm/app/__init__.py +10 -0
- zrb/llm/app/completion.py +281 -0
- zrb/llm/app/confirmation/allow_tool.py +66 -0
- zrb/llm/app/confirmation/handler.py +178 -0
- zrb/llm/app/confirmation/replace_confirmation.py +77 -0
- zrb/llm/app/keybinding.py +34 -0
- zrb/llm/app/layout.py +117 -0
- zrb/llm/app/lexer.py +155 -0
- zrb/llm/app/redirection.py +28 -0
- zrb/llm/app/style.py +16 -0
- zrb/llm/app/ui.py +733 -0
- zrb/llm/config/__init__.py +4 -0
- zrb/llm/config/config.py +122 -0
- zrb/llm/config/limiter.py +247 -0
- zrb/llm/history_manager/__init__.py +4 -0
- zrb/llm/history_manager/any_history_manager.py +23 -0
- zrb/llm/history_manager/file_history_manager.py +91 -0
- zrb/llm/history_processor/summarizer.py +108 -0
- zrb/llm/note/__init__.py +3 -0
- zrb/llm/note/manager.py +122 -0
- zrb/llm/prompt/__init__.py +29 -0
- zrb/llm/prompt/claude_compatibility.py +92 -0
- zrb/llm/prompt/compose.py +55 -0
- zrb/llm/prompt/default.py +51 -0
- zrb/llm/prompt/markdown/file_extractor.md +112 -0
- zrb/llm/prompt/markdown/mandate.md +23 -0
- zrb/llm/prompt/markdown/persona.md +3 -0
- zrb/llm/prompt/markdown/repo_extractor.md +112 -0
- zrb/llm/prompt/markdown/repo_summarizer.md +29 -0
- zrb/llm/prompt/markdown/summarizer.md +21 -0
- zrb/llm/prompt/note.py +41 -0
- zrb/llm/prompt/system_context.py +46 -0
- zrb/llm/prompt/zrb.py +41 -0
- zrb/llm/skill/__init__.py +3 -0
- zrb/llm/skill/manager.py +86 -0
- zrb/llm/task/__init__.py +4 -0
- zrb/llm/task/llm_chat_task.py +316 -0
- zrb/llm/task/llm_task.py +245 -0
- zrb/llm/tool/__init__.py +39 -0
- zrb/llm/tool/bash.py +75 -0
- zrb/llm/tool/code.py +266 -0
- zrb/llm/tool/file.py +419 -0
- zrb/llm/tool/note.py +70 -0
- zrb/{builtin/llm → llm}/tool/rag.py +33 -37
- zrb/llm/tool/search/brave.py +53 -0
- zrb/llm/tool/search/searxng.py +47 -0
- zrb/llm/tool/search/serpapi.py +47 -0
- zrb/llm/tool/skill.py +19 -0
- zrb/llm/tool/sub_agent.py +70 -0
- zrb/llm/tool/web.py +97 -0
- zrb/llm/tool/zrb_task.py +66 -0
- zrb/llm/util/attachment.py +101 -0
- zrb/llm/util/prompt.py +104 -0
- zrb/llm/util/stream_response.py +178 -0
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -9
- zrb/session/session.py +38 -17
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +42 -22
- zrb/task/base/execution.py +67 -55
- zrb/task/base/lifecycle.py +14 -7
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +113 -50
- zrb/task/base_trigger.py +16 -6
- zrb/task/cmd_task.py +6 -0
- zrb/task/http_check.py +11 -5
- zrb/task/make_task.py +5 -3
- zrb/task/rsync_task.py +30 -10
- zrb/task/scaffolder.py +7 -4
- zrb/task/scheduler.py +7 -4
- zrb/task/tcp_check.py +6 -4
- zrb/util/ascii_art/art/bee.txt +17 -0
- zrb/util/ascii_art/art/cat.txt +9 -0
- zrb/util/ascii_art/art/ghost.txt +16 -0
- zrb/util/ascii_art/art/panda.txt +17 -0
- zrb/util/ascii_art/art/rose.txt +14 -0
- zrb/util/ascii_art/art/unicorn.txt +15 -0
- zrb/util/ascii_art/banner.py +92 -0
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +32 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/cmd/command.py +33 -10
- zrb/util/file.py +61 -33
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/match.py +78 -0
- zrb/util/run.py +3 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/METADATA +41 -27
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/RECORD +129 -131
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/WHEEL +1 -1
- zrb/attr/__init__.py +0 -0
- zrb/builtin/llm/chat_session.py +0 -311
- zrb/builtin/llm/history.py +0 -71
- zrb/builtin/llm/input.py +0 -27
- zrb/builtin/llm/llm_ask.py +0 -187
- zrb/builtin/llm/previous-session.js +0 -21
- zrb/builtin/llm/tool/__init__.py +0 -0
- zrb/builtin/llm/tool/api.py +0 -71
- zrb/builtin/llm/tool/cli.py +0 -38
- zrb/builtin/llm/tool/code.py +0 -254
- zrb/builtin/llm/tool/file.py +0 -626
- zrb/builtin/llm/tool/sub_agent.py +0 -137
- zrb/builtin/llm/tool/web.py +0 -195
- zrb/builtin/project/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/__init__.py +0 -0
- zrb/builtin/project/create/__init__.py +0 -0
- zrb/builtin/shell/__init__.py +0 -0
- zrb/builtin/shell/autocomplete/__init__.py +0 -0
- zrb/callback/__init__.py +0 -0
- zrb/cmd/__init__.py +0 -0
- zrb/config/default_prompt/file_extractor_system_prompt.md +0 -12
- zrb/config/default_prompt/interactive_system_prompt.md +0 -35
- zrb/config/default_prompt/persona.md +0 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +0 -112
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +0 -10
- zrb/config/default_prompt/summarization_prompt.md +0 -16
- zrb/config/default_prompt/system_prompt.md +0 -32
- zrb/config/llm_config.py +0 -243
- zrb/config/llm_context/config.py +0 -129
- zrb/config/llm_context/config_parser.py +0 -46
- zrb/config/llm_rate_limitter.py +0 -137
- zrb/content_transformer/__init__.py +0 -0
- zrb/context/__init__.py +0 -0
- zrb/dot_dict/__init__.py +0 -0
- zrb/env/__init__.py +0 -0
- zrb/group/__init__.py +0 -0
- zrb/input/__init__.py +0 -0
- zrb/runner/__init__.py +0 -0
- zrb/runner/web_route/__init__.py +0 -0
- zrb/runner/web_route/home_page/__init__.py +0 -0
- zrb/session/__init__.py +0 -0
- zrb/session_state_log/__init__.py +0 -0
- zrb/session_state_logger/__init__.py +0 -0
- zrb/task/__init__.py +0 -0
- zrb/task/base/__init__.py +0 -0
- zrb/task/llm/__init__.py +0 -0
- zrb/task/llm/agent.py +0 -243
- zrb/task/llm/config.py +0 -103
- zrb/task/llm/conversation_history.py +0 -128
- zrb/task/llm/conversation_history_model.py +0 -242
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- zrb/task/llm/error.py +0 -95
- zrb/task/llm/history_summarization.py +0 -216
- zrb/task/llm/print_node.py +0 -101
- zrb/task/llm/prompt.py +0 -325
- zrb/task/llm/tool_wrapper.py +0 -220
- zrb/task/llm/typing.py +0 -3
- zrb/task/llm_task.py +0 -341
- zrb/task_status/__init__.py +0 -0
- zrb/util/__init__.py +0 -0
- zrb/util/cli/__init__.py +0 -0
- zrb/util/cmd/__init__.py +0 -0
- zrb/util/codemod/__init__.py +0 -0
- zrb/util/string/__init__.py +0 -0
- zrb/xcom/__init__.py +0 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/entry_points.txt +0 -0
zrb/llm/agent/agent.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from contextvars import ContextVar
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
3
|
+
|
|
4
|
+
from zrb.llm.config.config import llm_config as default_llm_config
|
|
5
|
+
from zrb.llm.config.limiter import LLMLimiter
|
|
6
|
+
from zrb.llm.util.attachment import normalize_attachments
|
|
7
|
+
from zrb.llm.util.prompt import expand_prompt
|
|
8
|
+
|
|
9
|
+
# Context variable to propagate tool confirmation callback to sub-agents
|
|
10
|
+
tool_confirmation_var: ContextVar[Callable[[Any], Any] | None] = ContextVar(
|
|
11
|
+
"tool_confirmation", default=None
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from pydantic_ai import Agent, DeferredToolRequests, DeferredToolResults, Tool
|
|
16
|
+
from pydantic_ai._agent_graph import HistoryProcessor
|
|
17
|
+
from pydantic_ai.messages import UserPromptPart
|
|
18
|
+
from pydantic_ai.models import Model
|
|
19
|
+
from pydantic_ai.output import OutputDataT, OutputSpec
|
|
20
|
+
from pydantic_ai.settings import ModelSettings
|
|
21
|
+
from pydantic_ai.tools import ToolFuncEither
|
|
22
|
+
from pydantic_ai.toolsets import AbstractToolset
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_agent(
    model: "Model | str | None" = None,
    system_prompt: str = "",
    tools: list["Tool | ToolFuncEither"] | None = None,
    toolsets: list["AbstractToolset[None]"] | None = None,
    model_settings: "ModelSettings | None" = None,
    history_processors: list["HistoryProcessor"] | None = None,
    output_type: "OutputSpec[OutputDataT]" = str,
    retries: int = 1,
    yolo: bool = False,
) -> "Agent[None, Any]":
    """Build a pydantic-ai ``Agent`` configured with zrb defaults.

    Args:
        model: Model instance or model name. When ``None``,
            ``default_llm_config.model`` is used.
        system_prompt: Agent instructions; references in it are expanded with
            ``expand_prompt`` before being passed to the agent.
        tools: Plain tools/functions. When given they are bundled into a single
            ``FunctionToolset``.
        toolsets: Extra toolsets. The list is copied, never mutated.
        model_settings: Forwarded to ``Agent``.
        history_processors: Forwarded to ``Agent``.
        output_type: Output spec: a single type/marker or a list/tuple of them.
        retries: Forwarded to ``Agent``.
        yolo: When ``False``, every toolset is wrapped with
            ``approval_required()`` and ``DeferredToolRequests`` is added as an
            allowed output, so tool calls surface to the caller for approval.

    Returns:
        The configured ``Agent``.
    """
    from pydantic_ai import Agent, DeferredToolRequests
    from pydantic_ai.toolsets import FunctionToolset

    # Expand system prompt with references
    effective_system_prompt = expand_prompt(system_prompt)

    final_output_type = output_type
    effective_toolsets = list(toolsets) if toolsets else []
    if tools:
        effective_toolsets.append(FunctionToolset(tools=list(tools)))

    if not yolo:
        # Use the list form of the output spec: ``output_type | DeferredToolRequests``
        # raises TypeError when output_type is itself a list/tuple spec or an
        # output-marker object, while a list accepts any of those elements.
        if isinstance(output_type, (list, tuple)):
            final_output_type = [*output_type, DeferredToolRequests]
        else:
            final_output_type = [output_type, DeferredToolRequests]
        effective_toolsets = [ts.approval_required() for ts in effective_toolsets]

    if model is None:
        model = default_llm_config.model

    return Agent(
        model=model,
        output_type=final_output_type,
        instructions=effective_system_prompt,
        toolsets=effective_toolsets,
        model_settings=model_settings,
        history_processors=history_processors,
        retries=retries,
    )
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
async def run_agent(
    agent: "Agent[None, Any]",
    message: str | None,
    message_history: list[Any],
    limiter: LLMLimiter,
    attachments: list[Any] | None = None,
    print_fn: Callable[[str], Any] = print,
    event_handler: Callable[[Any], Any] | None = None,
    tool_confirmation: Callable[[Any], Any] | None = None,
) -> tuple[Any, list[Any]]:
    """
    Run the agent with rate limiting, history pruning and a tool-approval loop.

    The tool-confirmation callback is resolved as: ``tool_confirmation``
    argument, else the value of ``tool_confirmation_var``. It is stored in
    ``tool_confirmation_var`` for the duration of the run so nested agent runs
    can pick it up, and restored afterwards.

    ``event_handler`` receives every streamed event; it may be a plain
    function or a coroutine function.

    Whenever the run ends with ``DeferredToolRequests``, the requests are
    resolved via ``_process_deferred_requests`` and the agent is resumed with
    the results. The loop stops when the output is not deferred or when there
    is nothing to resolve.

    Returns (result_output, new_message_history).
    """
    import asyncio
    import inspect

    from pydantic_ai import AgentRunResultEvent, DeferredToolRequests

    # Resolve tool confirmation callback (Arg > Context > None)
    effective_tool_confirmation = tool_confirmation
    if effective_tool_confirmation is None:
        effective_tool_confirmation = tool_confirmation_var.get()

    # Set context var for sub-agents
    token = tool_confirmation_var.set(effective_tool_confirmation)

    try:
        # Expand user message with references
        effective_message = expand_prompt(message) if message else message

        # Prepare Prompt Content
        prompt_content = _get_prompt_content(effective_message, attachments, print_fn)

        # 1. Prune & Throttle
        current_history = await _acquire_rate_limit(
            limiter, prompt_content, message_history, print_fn
        )
        current_message = prompt_content
        current_results = None

        # 2. Execution Loop
        while True:
            result_output = None
            run_history: list[Any] = []

            async for event in agent.run_stream_events(
                current_message,
                message_history=current_history,
                deferred_tool_results=current_results,
            ):
                # Yield to the event loop so UI tasks stay responsive.
                await asyncio.sleep(0)
                if isinstance(event, AgentRunResultEvent):
                    result = event.result
                    result_output = result.output
                    run_history = result.all_messages()
                if event_handler:
                    # Support both sync and async handlers.
                    handled = event_handler(event)
                    if inspect.isawaitable(handled):
                        await handled

            if not isinstance(result_output, DeferredToolRequests):
                return result_output, run_history

            # Handle Deferred Calls
            current_results = await _process_deferred_requests(
                result_output, effective_tool_confirmation
            )
            if current_results is None:
                return result_output, run_history

            # Resume the same run: no new user message, full history so far.
            current_message = None
            current_history = run_history
    finally:
        tool_confirmation_var.reset(token)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _get_prompt_content(
|
|
140
|
+
message: str | None, attachments: list[Any] | None, print_fn: Callable[[str], Any]
|
|
141
|
+
) -> "list[UserPromptPart] | str | None":
|
|
142
|
+
from pydantic_ai.messages import UserPromptPart
|
|
143
|
+
|
|
144
|
+
prompt_content = message
|
|
145
|
+
if attachments:
|
|
146
|
+
attachments = normalize_attachments(attachments, print_fn)
|
|
147
|
+
parts: list[UserPromptPart] = []
|
|
148
|
+
if message:
|
|
149
|
+
parts.append(UserPromptPart(content=message))
|
|
150
|
+
parts.extend(attachments)
|
|
151
|
+
prompt_content = parts
|
|
152
|
+
return prompt_content
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
async def _acquire_rate_limit(
    limiter: LLMLimiter,
    message: str | None,
    message_history: list[Any],
    print_fn: Callable[[str], Any],
) -> list[Any]:
    """Trim history to the limiter's context window, then wait for capacity.

    The token estimate passed to ``limiter.acquire`` is the token count of the
    trimmed history plus that of the message. Limiter notifications are
    forwarded to ``print_fn`` unless empty. When ``message`` is falsy nothing
    is trimmed or throttled and the history is returned as given.
    """
    if message:

        def _notify(text):
            return print_fn(text) if text else None

        trimmed = limiter.fit_context_window(message_history, message)
        token_estimate = limiter.count_tokens(trimmed) + limiter.count_tokens(message)
        await limiter.acquire(token_estimate, notifier=_notify)
        return trimmed

    return message_history
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
async def _process_deferred_requests(
    result_output: "DeferredToolRequests",
    effective_tool_confirmation: Callable[[Any], Any] | None,
) -> "DeferredToolResults | None":
    """Resolve pending tool requests via a callback or an interactive prompt.

    Every request in ``result_output.calls`` followed by every request in
    ``result_output.approvals`` is resolved in order:

    - with a callback, its return value (awaited if awaitable) is stored as the
      approval result;
    - without one, the user is asked on stdin. Only "y"/"yes" (any case)
      approves. Any other answer, including end of input, denies the call.

    Returns:
        A ``DeferredToolResults`` keyed by ``tool_call_id``, or ``None`` when
        there are no requests to resolve.
    """
    import asyncio
    import inspect

    from pydantic_ai import DeferredToolResults, ToolApproved, ToolDenied

    all_requests = (result_output.calls or []) + (result_output.approvals or [])
    if not all_requests:
        return None

    current_results = DeferredToolResults()

    for call in all_requests:
        # NOTE(review): entries from ``calls`` are recorded under ``approvals``
        # as well; confirm this is what pydantic-ai expects for them.
        if effective_tool_confirmation:
            result = effective_tool_confirmation(call)
            if inspect.isawaitable(result):
                result = await result
            current_results.approvals[call.tool_call_id] = result
            continue

        # CLI Fallback
        prompt_text = f"Execute tool '{call.tool_name}' with args {call.args}?"
        prompt_cli = f"\n[?] {prompt_text} (y/N) "
        try:
            # Run the blocking input() in a worker thread to keep the loop free.
            user_input = await asyncio.to_thread(input, prompt_cli)
        except EOFError:
            # stdin closed / non-interactive: fall back to the safe default.
            user_input = ""

        if user_input.strip().lower() in ("y", "yes"):
            current_results.approvals[call.tool_call_id] = ToolApproved()
        else:
            current_results.approvals[call.tool_call_id] = ToolDenied("User denied")

    return current_results
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
from zrb.llm.agent.agent import create_agent
|
|
4
|
+
from zrb.llm.prompt.default import get_summarizer_system_prompt
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from pydantic_ai import Agent
|
|
8
|
+
from pydantic_ai.models import Model
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_summarizer_agent(
    model: "str | None | Model" = None,
    system_prompt: str | None = None,
) -> "Agent[None, str]":
    """Create an agent intended for summarization.

    :param model: Model instance or name, forwarded to ``create_agent``.
    :param system_prompt: Custom instructions. When empty or ``None``, the
        prompt from ``get_summarizer_system_prompt()`` is used.
    :return: The agent built by ``create_agent``.
    """
    if system_prompt:
        prompt = system_prompt
    else:
        prompt = get_summarizer_system_prompt()
    return create_agent(model=model, system_prompt=prompt)
|
zrb/llm/app/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from zrb.llm.app.confirmation.allow_tool import allow_tool_usage
|
|
2
|
+
from zrb.llm.app.confirmation.handler import ConfirmationMiddleware, last_confirmation
|
|
3
|
+
from zrb.llm.app.confirmation.replace_confirmation import replace_confirmation
|
|
4
|
+
|
|
5
|
+
# Names re-exported as the public API of ``zrb.llm.app``.
__all__ = [
    "allow_tool_usage",
    "ConfirmationMiddleware",
    "last_confirmation",
    "replace_confirmation",
]
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Iterable
|
|
4
|
+
|
|
5
|
+
from prompt_toolkit.completion import (
|
|
6
|
+
CompleteEvent,
|
|
7
|
+
Completer,
|
|
8
|
+
Completion,
|
|
9
|
+
PathCompleter,
|
|
10
|
+
)
|
|
11
|
+
from prompt_toolkit.document import Document
|
|
12
|
+
|
|
13
|
+
from zrb.llm.history_manager.any_history_manager import AnyHistoryManager
|
|
14
|
+
from zrb.util.match import fuzzy_match
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InputCompleter(Completer):
    """prompt_toolkit completer for the LLM chat input line.

    Provides:
    - command-name completion for all configured commands;
    - argument completion for some commands: shell history for exec commands,
      a timestamp for save commands, a ``.txt`` timestamp for redirect
      commands, history-manager search results for load commands and file
      paths for attach commands;
    - file completion for words starting with ``@``.
    """

    # Seconds a recursive file listing is reused before the tree is walked
    # again. Without this, every completion request re-walks the directory.
    _FILE_CACHE_TTL = 2.0
    # At or above this many entries, fuzzy matching gives way to PathCompleter.
    _FUZZY_FILE_LIMIT = 5000
    # Directory names never descended into when listing files.
    _EXCLUDED_DIRS = frozenset({"node_modules", "__pycache__", "venv", ".venv"})

    def __init__(
        self,
        history_manager: AnyHistoryManager,
        attach_commands: list[str] | None = None,
        exit_commands: list[str] | None = None,
        info_commands: list[str] | None = None,
        save_commands: list[str] | None = None,
        load_commands: list[str] | None = None,
        redirect_output_commands: list[str] | None = None,
        summarize_commands: list[str] | None = None,
        exec_commands: list[str] | None = None,
    ):
        self._history_manager = history_manager
        # ``None`` defaults avoid shared mutable ``[]`` defaults; passed lists
        # are kept by reference, as before.
        self._attach_commands = self._list_or_empty(attach_commands)
        self._exit_commands = self._list_or_empty(exit_commands)
        self._info_commands = self._list_or_empty(info_commands)
        self._save_commands = self._list_or_empty(save_commands)
        self._load_commands = self._list_or_empty(load_commands)
        self._redirect_output_commands = self._list_or_empty(redirect_output_commands)
        self._summarize_commands = self._list_or_empty(summarize_commands)
        self._exec_commands = self._list_or_empty(exec_commands)
        # expanduser=True allows ~/path
        self._path_completer = PathCompleter(expanduser=True)
        # Cached recursive file listing and the time.monotonic() it was taken.
        self._file_cache: list[str] | None = None
        self._file_cache_time: float = 0.0
        self._cmd_history = self._get_cmd_history()

    @staticmethod
    def _list_or_empty(value: list[str] | None) -> list[str]:
        """Return ``value``, or a fresh empty list when it is ``None``."""
        return [] if value is None else value

    def get_completions(
        self, document: Document, complete_event: CompleteEvent
    ) -> Iterable[Completion]:
        """Yield completions for the text before the cursor.

        Input whose first non-blank character is the first character of any
        configured command is handled as a command line. Otherwise, a word
        starting with ``@`` triggers file completion.
        """
        text_before_cursor = document.text_before_cursor.lstrip()
        word = document.get_word_before_cursor(WORD=True)

        all_commands = (
            self._exit_commands
            + self._attach_commands
            + self._summarize_commands
            + self._info_commands
            + self._save_commands
            + self._load_commands
            + self._redirect_output_commands
            + self._exec_commands
        )
        command_prefixes = {cmd[0] for cmd in all_commands if cmd}

        # 1. Command and Argument Completion
        if text_before_cursor and text_before_cursor[0] in command_prefixes:
            parts = text_before_cursor.split()
            ends_with_space = text_before_cursor.endswith(" ")
            # Typing the command itself vs. typing its arguments
            is_typing_command = len(parts) == 1 and not ends_with_space
            is_typing_arg = (len(parts) == 1 and ends_with_space) or len(parts) >= 2

            if is_typing_command:
                lower_word = word.lower()
                prefix = text_before_cursor[0]
                for cmd in all_commands:
                    if cmd.startswith(prefix) and cmd.lower().startswith(lower_word):
                        yield Completion(cmd, start_position=-len(word))
                return

            if is_typing_arg:
                cmd = parts[0]
                arg_prefix = text_before_cursor[len(cmd) :].lstrip()

                # Exec Command: suggest shell history. _get_cmd_history keeps
                # the latest occurrence last, so reverse to show it first.
                if self._is_command(cmd, self._exec_commands):
                    matches = [h for h in self._cmd_history if h.startswith(arg_prefix)]
                    for h in reversed(matches):
                        yield Completion(h, start_position=-len(arg_prefix))
                    return

                # The remaining commands take a single argument: only complete
                # right after "<cmd> " or while that first argument is typed.
                if not (
                    (len(parts) == 1 and ends_with_space)
                    or (len(parts) == 2 and not ends_with_space)
                ):
                    return

                arg_prefix = parts[1] if len(parts) == 2 else ""

                # Save Command: Suggest Timestamp
                if self._is_command(cmd, self._save_commands):
                    ts = datetime.now().strftime("%Y-%m-%d-%H-%M")
                    if ts.startswith(arg_prefix):
                        yield Completion(ts, start_position=-len(arg_prefix))
                    return

                # Redirect Command: Suggest Timestamp.txt
                if self._is_command(cmd, self._redirect_output_commands):
                    ts = datetime.now().strftime("%Y-%m-%d-%H-%M.txt")
                    if ts.startswith(arg_prefix):
                        yield Completion(ts, start_position=-len(arg_prefix))
                    return

                # Load Command: Search History
                if self._is_command(cmd, self._load_commands):
                    results = self._history_manager.search(arg_prefix)
                    for res in results[:10]:
                        yield Completion(res, start_position=-len(arg_prefix))
                    return

                # Attach Command: Suggest Files
                if self._is_command(cmd, self._attach_commands):
                    yield from self._get_file_completions(
                        arg_prefix, complete_event, only_files=True
                    )
                    return

                # Other commands (Exit, Info, Summarize) need no completion
                return

        # 2. File Completion (@)
        if word.startswith("@"):
            path_part = word[1:]
            yield from self._get_file_completions(
                path_part, complete_event, only_files=False
            )

    def _get_cmd_history(self) -> list[str]:
        """Read unique commands from ~/.bash_history and ~/.zsh_history.

        Commands are kept in file order (bash history first, then zsh). A
        repeated command keeps only its last position. Missing or unreadable
        files are skipped: the history only feeds suggestions.
        """
        history_files = [
            os.path.expanduser("~/.bash_history"),
            os.path.expanduser("~/.zsh_history"),
        ]
        unique_cmds: dict[str, None] = {}  # dict preserves insertion order

        for hist_file in history_files:
            if not os.path.exists(hist_file):
                continue
            try:
                with open(hist_file, "r", encoding="utf-8", errors="ignore") as f:
                    for raw_line in f:
                        line = raw_line.strip()
                        if not line:
                            continue
                        # zsh extended history: ": <timestamp>:<duration>;<command>"
                        if line.startswith(": ") and ";" in line:
                            line = line.split(";", 1)[1]
                        if line:
                            # Re-insert so a repeated command moves to the end.
                            unique_cmds.pop(line, None)
                            unique_cmds[line] = None
            except OSError:
                continue

        return list(unique_cmds)

    def _is_command(self, cmd: str, cmd_list: list[str]) -> bool:
        """Case-insensitive membership test of ``cmd`` in ``cmd_list``."""
        return cmd.lower() in [c.lower() for c in cmd_list]

    def _get_file_completions(
        self, text: str, complete_event: CompleteEvent, only_files: bool = False
    ) -> Iterable[Completion]:
        """Complete file paths for ``text``.

        Path-like input (starting with ``/``, ``.``, ``~`` or containing the
        OS separator) uses PathCompleter. Otherwise the cached recursive
        listing is fuzzy-matched, unless it hit the size limit, in which case
        PathCompleter is used as well.
        """
        if self._is_path_navigation(text):
            yield from self._get_path_completions(text, complete_event, only_files)
            return

        files = self._get_cached_files(limit=self._FUZZY_FILE_LIMIT)
        if len(files) < self._FUZZY_FILE_LIMIT:
            yield from self._get_fuzzy_completions(text, files, only_files)
        else:
            # Fallback to PathCompleter for large repos
            yield from self._get_path_completions(text, complete_event, only_files)

    def _get_cached_files(self, limit: int = 5000) -> list[str]:
        """Return the recursive listing, re-walking at most every TTL seconds."""
        import time

        now = time.monotonic()
        if (
            self._file_cache is None
            or now - self._file_cache_time > self._FILE_CACHE_TTL
        ):
            self._file_cache = self._get_recursive_files(limit=limit)
            self._file_cache_time = now
        return self._file_cache

    def _is_path_navigation(self, text: str) -> bool:
        """Return True when ``text`` looks like explicit path navigation."""
        return text.startswith(("/", ".", "~")) or os.sep in text

    def _get_path_completions(
        self, text: str, complete_event: CompleteEvent, only_files: bool
    ) -> Iterable[Completion]:
        """Delegate to PathCompleter using a document containing only ``text``."""
        fake_document = Document(text=text, cursor_position=len(text))
        for c in self._path_completer.get_completions(fake_document, complete_event):
            # Heuristic directory filter: skip completions whose text ends with
            # the path separator.
            # NOTE(review): PathCompleter may put the trailing slash in
            # ``display`` rather than ``text``; verify this filter ever fires.
            if only_files and c.text.endswith(os.sep):
                continue
            yield c

    def _get_fuzzy_completions(
        self, text: str, files: list[str], only_files: bool
    ) -> Iterable[Completion]:
        """Yield up to 20 entries of ``files`` fuzzy-matching ``text``, best first."""
        matches = []
        for f in files:
            # Directory entries end with os.sep (see _get_recursive_files).
            if only_files and f.endswith(os.sep):
                continue
            is_match, score = fuzzy_match(f, text)
            if is_match:
                matches.append((score, f))

        # Sort by score (lower is better)
        matches.sort(key=lambda x: x[0])

        for _, f in matches[:20]:
            yield Completion(f, start_position=-len(text))

    def _get_recursive_files(self, root: str = ".", limit: int = 5000) -> list[str]:
        """List directories (with trailing os.sep) and files under ``root``.

        Paths are relative to ``root``. Hidden entries are skipped unless
        ``root`` itself is hidden, and the names in ``_EXCLUDED_DIRS`` are
        never descended into. Stops as soon as ``limit`` entries are collected.
        """
        paths: list[str] = []
        cwd_is_hidden = os.path.basename(os.path.abspath(root)).startswith(".")

        try:
            for dirpath, dirnames, filenames in os.walk(root):
                # Pruning dirnames in place stops os.walk from descending.
                if not cwd_is_hidden:
                    dirnames[:] = [d for d in dirnames if not d.startswith(".")]
                dirnames[:] = [d for d in dirnames if d not in self._EXCLUDED_DIRS]

                rel_dir = os.path.relpath(dirpath, root)
                if rel_dir == ".":
                    rel_dir = ""

                for d in dirnames:
                    paths.append(os.path.join(rel_dir, d) + os.sep)
                    if len(paths) >= limit:
                        return paths

                for f in filenames:
                    if not cwd_is_hidden and f.startswith("."):
                        continue
                    paths.append(os.path.join(rel_dir, f))
                    if len(paths) >= limit:
                        return paths
        except OSError:
            # Best effort: return whatever was collected so far.
            pass
        return paths
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any, Awaitable, Callable, Dict, Optional
|
|
4
|
+
|
|
5
|
+
from zrb.llm.app.confirmation.handler import ConfirmationMiddleware, UIProtocol
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def allow_tool_usage(
    tool_name: str, kwargs: Optional[Dict[str, str]] = None
) -> ConfirmationMiddleware:
    """
    Create a confirmation middleware that auto-approves matching tool calls.

    :param tool_name: Name of the tool to auto-approve. Calls to any other
        tool are passed to the next handler.
    :param kwargs: Mapping of argument name to regex pattern.
        If None or empty, every call to ``tool_name`` is approved.
        Otherwise, each argument present in both the call and ``kwargs`` must
        match its pattern. Matching uses ``re.search`` on ``str(value)``, so a
        pattern may match anywhere in the value; anchor it with ``^...$`` for
        an exact match. Arguments named in ``kwargs`` but absent from the call
        are not checked. Calls whose arguments cannot be parsed into a dict
        are passed to the next handler.
    :return: A ConfirmationMiddleware function.
    :raises re.error: If a pattern in ``kwargs`` is not a valid regex.
    """
    from pydantic_ai import ToolApproved

    # Compile once up front so invalid patterns fail at registration time
    # instead of in the middle of a tool confirmation.
    compiled_patterns = {
        arg_name: re.compile(pattern) for arg_name, pattern in (kwargs or {}).items()
    }

    async def middleware(
        ui: UIProtocol,
        call: Any,
        response: str,
        next_handler: Callable[[UIProtocol, Any, str], Awaitable[Any]],
    ) -> Any:
        # Check if tool name matches
        if call.tool_name != tool_name:
            return await next_handler(ui, call, response)

        # No constraints: approve unconditionally
        if not compiled_patterns:
            ui.append_to_output(f"\n✅ Auto-approved tool: {tool_name}")
            return ToolApproved()

        # Parse arguments. Only json.loads is guarded, so an error raised by
        # next_handler is never mistaken for a parse failure (which previously
        # re-invoked next_handler).
        args = call.args
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:  # includes json.JSONDecodeError
                return await next_handler(ui, call, response)

        # Non-dict args cannot be checked per key: delegate.
        if not isinstance(args, dict):
            return await next_handler(ui, call, response)

        # Every constrained argument present in the call must match.
        for arg_name, arg_value in args.items():
            pattern = compiled_patterns.get(arg_name)
            if pattern is not None and not pattern.search(str(arg_value)):
                return await next_handler(ui, call, response)

        ui.append_to_output(f"\n✅ Auto-approved tool: {tool_name} with matching args")
        return ToolApproved()

    return middleware
|