zrb 1.21.17__py3-none-any.whl → 1.21.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/attr/type.py +10 -7
- zrb/builtin/git.py +12 -1
- zrb/builtin/llm/chat_completion.py +287 -0
- zrb/builtin/llm/chat_session_cmd.py +90 -28
- zrb/builtin/llm/chat_trigger.py +6 -1
- zrb/builtin/llm/tool/cli.py +29 -13
- zrb/builtin/llm/tool/code.py +9 -1
- zrb/builtin/llm/tool/file.py +32 -6
- zrb/builtin/llm/tool/note.py +9 -9
- zrb/builtin/llm/tool/search/__init__.py +1 -0
- zrb/builtin/llm/tool/search/brave.py +66 -0
- zrb/builtin/llm/tool/search/searxng.py +61 -0
- zrb/builtin/llm/tool/search/serpapi.py +61 -0
- zrb/builtin/llm/tool/sub_agent.py +30 -10
- zrb/builtin/llm/tool/web.py +17 -72
- zrb/config/config.py +67 -26
- zrb/config/default_prompt/interactive_system_prompt.md +16 -13
- zrb/config/default_prompt/summarization_prompt.md +54 -8
- zrb/config/default_prompt/system_prompt.md +16 -18
- zrb/config/llm_rate_limitter.py +15 -6
- zrb/input/option_input.py +13 -1
- zrb/task/llm/agent.py +42 -143
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/conversation_history.py +35 -24
- zrb/task/llm/conversation_history_model.py +4 -11
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +2 -179
- zrb/task/llm/print_node.py +14 -5
- zrb/task/llm/prompt.py +2 -17
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_confirmation_completer.py +41 -0
- zrb/task/llm/tool_wrapper.py +15 -11
- zrb/task/llm_task.py +41 -40
- zrb/util/attr.py +12 -7
- zrb/util/git.py +2 -2
- zrb/xcom/xcom.py +10 -0
- {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/METADATA +3 -3
- {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/RECORD +40 -32
- zrb/task/llm/history_summarization_tool.py +0 -24
- {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/WHEEL +0 -0
- {zrb-1.21.17.dist-info → zrb-1.21.33.dist-info}/entry_points.txt +0 -0
zrb/task/llm/history_processor.py
ADDED

```diff
@@ -0,0 +1,206 @@
+import json
+import sys
+import traceback
+from typing import TYPE_CHECKING, Any, Callable, Coroutine
+
+from zrb.config.llm_config import llm_config
+from zrb.config.llm_rate_limitter import LLMRateLimitter
+from zrb.config.llm_rate_limitter import llm_rate_limitter as default_llm_rate_limitter
+from zrb.context.any_context import AnyContext
+from zrb.task.llm.agent_runner import run_agent_iteration
+from zrb.util.cli.style import stylize_faint
+from zrb.util.markdown import make_markdown_section
+
+if sys.version_info >= (3, 12):
+    from typing import TypedDict
+else:
+    from typing_extensions import TypedDict
+
+
+if TYPE_CHECKING:
+    from pydantic_ai import ModelMessage
+    from pydantic_ai.models import Model
+    from pydantic_ai.settings import ModelSettings
+
+
+class SingleMessage(TypedDict):
+    """
+    SingleConversation
+
+    Attributes:
+        role: Either AI, User, Tool Call, or Tool Result
+        time: yyyy-mm-ddTHH:MM:SSZ:
+        content: The content of the message (summarize if too long)
+    """
+
+    role: str
+    time: str
+    content: str
+
+
+class ConversationSummary(TypedDict):
+    """
+    Conversation history
+
+    Attributes:
+        transcript: Several last transcript of the conversation
+        summary: Descriptive conversation summary
+    """
+
+    transcript: list[SingleMessage]
+    summary: str
+
+
+def save_conversation_summary(conversation_summary: ConversationSummary):
+    """
+    Write conversation summary for main assistant to continue conversation.
+    """
+    return conversation_summary
+
+
+def create_summarize_history_processor(
+    ctx: AnyContext,
+    system_prompt: str,
+    rate_limitter: LLMRateLimitter | None = None,
+    summarization_model: "Model | str | None" = None,
+    summarization_model_settings: "ModelSettings | None" = None,
+    summarization_system_prompt: str | None = None,
+    summarization_token_threshold: int | None = None,
+    summarization_retries: int = 2,
+) -> Callable[[list["ModelMessage"]], Coroutine[None, None, list["ModelMessage"]]]:
+    from pydantic_ai import Agent, ModelMessage, ModelRequest
+    from pydantic_ai.messages import ModelMessagesTypeAdapter, UserPromptPart
+
+    if rate_limitter is None:
+        rate_limitter = default_llm_rate_limitter
+    if summarization_model is None:
+        summarization_model = llm_config.default_small_model
+    if summarization_model_settings is None:
+        summarization_model_settings = llm_config.default_small_model_settings
+    if summarization_system_prompt is None:
+        summarization_system_prompt = llm_config.default_summarization_prompt
+    if summarization_token_threshold is None:
+        summarization_token_threshold = (
+            llm_config.default_history_summarization_token_threshold
+        )
+
+    async def maybe_summarize_history(
+        messages: list[ModelMessage],
+    ) -> list[ModelMessage]:
+        history_list = json.loads(ModelMessagesTypeAdapter.dump_json(messages))
+        history_json_str = json.dumps(history_list)
+        # Estimate token usage
+        # Note: Pydantic ai has run context parameter
+        # (https://ai.pydantic.dev/message-history/#runcontext-parameter)
+        # But we cannot use run_ctx.usage.total_tokens because total token keep increasing
+        # even after summariztion.
+        estimated_token_usage = rate_limitter.count_token(history_json_str)
+        _print_request_info(
+            ctx, estimated_token_usage, summarization_token_threshold, messages
+        )
+        if estimated_token_usage < summarization_token_threshold or len(messages) == 1:
+            return messages
+        history_list_without_instruction = [
+            {
+                key: obj[key]
+                for key in obj
+                if index == len(history_list) - 1 or key != "instructions"
+            }
+            for index, obj in enumerate(history_list)
+        ]
+        history_json_str_without_instruction = json.dumps(
+            history_list_without_instruction
+        )
+        summarization_message = f"Summarize the following conversation: {history_json_str_without_instruction}"
+        summarization_agent = Agent[None, ConversationSummary](
+            model=summarization_model,
+            output_type=save_conversation_summary,
+            instructions=summarization_system_prompt,
+            model_settings=summarization_model_settings,
+            retries=summarization_retries,
+        )
+        try:
+            _print_info(ctx, "📝 Rollup Conversation", 2)
+            summary_run = await run_agent_iteration(
+                ctx=ctx,
+                agent=summarization_agent,
+                user_prompt=summarization_message,
+                attachments=[],
+                history_list=[],
+                rate_limitter=rate_limitter,
+                log_indent_level=2,
+            )
+            if summary_run and summary_run.result and summary_run.result.output:
+                usage = summary_run.result.usage()
+                _print_info(ctx, f"📝 Rollup Conversation Token: {usage}", 2)
+                ctx.print(plain=True)
+                ctx.log_info("History summarized and updated.")
+                condensed_message = make_markdown_section(
+                    header="Past Conversation",
+                    content="\n".join(
+                        [
+                            make_markdown_section(
+                                "Summary", _extract_summary(summary_run.result.output)
+                            ),
+                            make_markdown_section(
+                                "Past Trancript",
+                                _extract_transcript(summary_run.result.output),
+                            ),
+                        ]
+                    ),
+                )
+                return [
+                    ModelRequest(
+                        instructions=system_prompt,
+                        parts=[UserPromptPart(condensed_message)],
+                    )
+                ]
+            ctx.log_warning("History summarization failed or returned no data.")
+        except BaseException as e:
+            ctx.log_warning(f"Error during history summarization: {e}")
+            traceback.print_exc()
+        return messages
+
+    return maybe_summarize_history
+
+
+def _print_request_info(
+    ctx: AnyContext,
+    estimated_token_usage: int,
+    summarization_token_threshold: int,
+    messages: list["ModelMessage"],
+):
+    _print_info(ctx, f"Current request token (estimated): {estimated_token_usage}")
+    _print_info(ctx, f"Summarization token threshold: {summarization_token_threshold}")
+    _print_info(ctx, f"History length: {len(messages)}")
+
+
+def _print_info(ctx: AnyContext, text: str, log_indent_level: int = 0):
+    log_prefix = (2 * (log_indent_level + 1)) * " "
+    ctx.print(stylize_faint(f"{log_prefix}{text}"), plain=True)
+
+
+def _extract_summary(summary_result_output: dict[str, Any] | str) -> str:
+    summary = (
+        summary_result_output.get("summary", "")
+        if isinstance(summary_result_output, dict)
+        else ""
+    )
+    return summary
+
+
+def _extract_transcript(summary_result_output: dict[str, Any] | str) -> str:
+    transcript_list = (
+        summary_result_output.get("transcript", [])
+        if isinstance(summary_result_output, dict)
+        else []
+    )
+    transcript_list = [] if not isinstance(transcript_list, list) else transcript_list
+    return "\n".join(_format_transcript_message(message) for message in transcript_list)
+
+
+def _format_transcript_message(message: dict[str, str]) -> str:
+    role = message.get("role", "Message")
+    time = message.get("time", "<unknown>")
+    content = message.get("content", "<empty>")
+    return f"{role} ({time}): {content}"
```
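The new module registers summarization as a pydantic-ai history processor instead of mutating `ConversationHistory` up front (compare the `history_summarization.py` diff below). A minimal sketch of how the factory's return value might be wired into an agent, assuming a zrb `AnyContext` named `ctx` and an illustrative token threshold; `history_processors` is pydantic-ai's documented hook for rewriting message history before each model request:

```python
# Sketch only; `ctx`, the prompt string, and the threshold value are assumptions.
from pydantic_ai import Agent

from zrb.task.llm.history_processor import create_summarize_history_processor


def build_agent(ctx, system_prompt: str) -> Agent:
    # The factory returns an async callable mapping
    # list[ModelMessage] -> list[ModelMessage]. Below the token threshold it
    # passes messages through unchanged; above it, it replaces the history
    # with one condensed ModelRequest carrying the summary and transcript.
    processor = create_summarize_history_processor(
        ctx=ctx,
        system_prompt=system_prompt,
        summarization_token_threshold=4000,  # illustrative value
    )
    return Agent(
        model="openai:gpt-4o-mini",  # any pydantic-ai model identifier
        instructions=system_prompt,
        history_processors=[processor],
    )
```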
zrb/task/llm/history_summarization.py
CHANGED

```diff
@@ -1,36 +1,7 @@
-import json
-import traceback
-from typing import TYPE_CHECKING
-
-from zrb.attr.type import BoolAttr, IntAttr
+from zrb.attr.type import IntAttr
 from zrb.config.llm_config import llm_config
-from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
 from zrb.context.any_context import AnyContext
-from zrb.task.llm.agent import run_agent_iteration
-from zrb.task.llm.conversation_history import (
-    count_part_in_history_list,
-    inject_conversation_history_notes,
-    replace_system_prompt_in_history,
-)
-from zrb.task.llm.conversation_history_model import ConversationHistory
-from zrb.task.llm.history_summarization_tool import (
-    create_history_summarization_tool,
-)
-from zrb.task.llm.typing import ListOfDict
-from zrb.util.attr import get_bool_attr, get_int_attr
-from zrb.util.cli.style import stylize_faint
-from zrb.util.markdown import make_markdown_section
-from zrb.util.truncate import truncate_str
-
-if TYPE_CHECKING:
-    from pydantic_ai.models import Model
-    from pydantic_ai.settings import ModelSettings
-
-
-def _count_token_in_history(history_list: ListOfDict) -> int:
-    """Counts the total number of tokens in a conversation history list."""
-    text_to_count = json.dumps(history_list)
-    return llm_rate_limitter.count_token(text_to_count)
+from zrb.util.attr import get_int_attr
 
 
 def get_history_summarization_token_threshold(
@@ -52,151 +23,3 @@ def get_history_summarization_token_threshold(
         "Defaulting to -1 (no threshold)."
     )
     return -1
-
-
-def should_summarize_history(
-    ctx: AnyContext,
-    history_list: ListOfDict,
-    should_summarize_history_attr: BoolAttr | None,
-    render_summarize_history: bool,
-    history_summarization_token_threshold_attr: IntAttr | None,
-    render_history_summarization_token_threshold: bool,
-) -> bool:
-    """Determines if history summarization should occur based on token length and config."""
-    history_part_count = count_part_in_history_list(history_list)
-    if history_part_count == 0:
-        return False
-    summarization_token_threshold = get_history_summarization_token_threshold(
-        ctx,
-        history_summarization_token_threshold_attr,
-        render_history_summarization_token_threshold,
-    )
-    history_token_count = _count_token_in_history(history_list)
-    if (
-        summarization_token_threshold == -1
-        or summarization_token_threshold > history_token_count
-    ):
-        return False
-    return get_bool_attr(
-        ctx,
-        should_summarize_history_attr,
-        llm_config.default_summarize_history,
-        auto_render=render_summarize_history,
-    )
-
-
-async def summarize_history(
-    ctx: AnyContext,
-    model: "Model | str | None",
-    settings: "ModelSettings | None",
-    system_prompt: str,
-    conversation_history: ConversationHistory,
-    rate_limitter: LLMRateLimiter | None = None,
-    retries: int = 3,
-) -> ConversationHistory:
-    """Runs an LLM call to update the conversation summary."""
-    from pydantic_ai import Agent
-
-    inject_conversation_history_notes(conversation_history)
-    ctx.log_info("Attempting to summarize conversation history...")
-    # Construct the user prompt for the summarization agent
-    user_prompt = "\n".join(
-        [
-            make_markdown_section(
-                "Past Conversation",
-                "\n".join(
-                    [
-                        make_markdown_section(
-                            "Summary",
-                            conversation_history.past_conversation_summary,
-                            as_code=True,
-                        ),
-                        make_markdown_section(
-                            "Last Transcript",
-                            conversation_history.past_conversation_transcript,
-                            as_code=True,
-                        ),
-                    ]
-                ),
-            ),
-            make_markdown_section(
-                "Recent Conversation (JSON)",
-                json.dumps(truncate_str(conversation_history.history, 1000)),
-                as_code=True,
-            ),
-        ]
-    )
-    summarize = create_history_summarization_tool(conversation_history)
-    summarization_agent = Agent[None, str](
-        model=model,
-        output_type=summarize,
-        system_prompt=system_prompt,
-        model_settings=settings,
-        retries=retries,
-    )
-    try:
-        ctx.print(stylize_faint(" 📝 Rollup Conversation"), plain=True)
-        summary_run = await run_agent_iteration(
-            ctx=ctx,
-            agent=summarization_agent,
-            user_prompt=user_prompt,
-            attachments=[],
-            history_list=[],
-            rate_limitter=rate_limitter,
-            log_indent_level=2,
-        )
-        if summary_run and summary_run.result and summary_run.result.output:
-            usage = summary_run.result.usage()
-            ctx.print(
-                stylize_faint(f" 📝 Rollup Conversation Token: {usage}"), plain=True
-            )
-            ctx.print(plain=True)
-            ctx.log_info("History summarized and updated.")
-        else:
-            ctx.log_warning("History summarization failed or returned no data.")
-    except BaseException as e:
-        ctx.log_warning(f"Error during history summarization: {e}")
-        traceback.print_exc()
-        # Return the original summary if summarization fails
-    return conversation_history
-
-
-async def maybe_summarize_history(
-    ctx: AnyContext,
-    conversation_history: ConversationHistory,
-    should_summarize_history_attr: BoolAttr | None,
-    render_summarize_history: bool,
-    history_summarization_token_threshold_attr: IntAttr | None,
-    render_history_summarization_token_threshold: bool,
-    model: "str | Model | None",
-    model_settings: "ModelSettings | None",
-    summarization_prompt: str,
-    rate_limitter: LLMRateLimiter | None = None,
-) -> ConversationHistory:
-    """Summarizes history and updates context if enabled and threshold met."""
-    shorten_history = replace_system_prompt_in_history(conversation_history.history)
-    if should_summarize_history(
-        ctx,
-        shorten_history,
-        should_summarize_history_attr,
-        render_summarize_history,
-        history_summarization_token_threshold_attr,
-        render_history_summarization_token_threshold,
-    ):
-        original_history = conversation_history.history
-        conversation_history.history = shorten_history
-        conversation_history = await summarize_history(
-            ctx=ctx,
-            model=model,
-            settings=model_settings,
-            system_prompt=summarization_prompt,
-            conversation_history=conversation_history,
-            rate_limitter=rate_limitter,
-        )
-        conversation_history.history = original_history
-        if (
-            conversation_history.past_conversation_summary != ""
-            and conversation_history.past_conversation_transcript != ""
-        ):
-            conversation_history.history = []
-    return conversation_history
```
zrb/task/llm/print_node.py
CHANGED

```diff
@@ -2,6 +2,7 @@ import json
 from collections.abc import Callable
 from typing import Any
 
+from zrb.config.config import CFG
 from zrb.util.cli.style import stylize_faint
 
 
@@ -104,12 +105,20 @@ async def print_node(
             and event.tool_call_id
         ):
             call_id = event.tool_call_id
-
-
-
-
+            if CFG.LLM_SHOW_TOOL_CALL_RESULT:
+                result_content = event.result.content
+                print_func(
+                    _format_content(
+                        f"{call_id} | Return {result_content}",
+                        log_indent_level,
+                    )
+                )
+            else:
+                print_func(
+                    _format_content(
+                        f"{call_id} | Executed", log_indent_level
+                    )
                 )
-            )
     except UnexpectedModelBehavior as e:
         print_func("") # ensure newline consistency
         print_func(
```
zrb/task/llm/prompt.py
CHANGED

```diff
@@ -115,11 +115,11 @@ def _construct_system_prompt(
                 ),
             ),
             make_markdown_section(
-                "🧠 Long Term Note",
+                "🧠 Long Term Note Content",
                 conversation_history.long_term_note,
             ),
             make_markdown_section(
-                "📝 Contextual Note",
+                "📝 Contextual Note Content",
                 conversation_history.contextual_note,
             ),
             make_markdown_section(
@@ -129,21 +129,6 @@ def _construct_system_prompt(
                     ]
                 ),
             ),
-            make_markdown_section(
-                "💬 PAST CONVERSATION",
-                "\n".join(
-                    [
-                        make_markdown_section(
-                            "Narrative Summary",
-                            conversation_history.past_conversation_summary,
-                        ),
-                        make_markdown_section(
-                            "Past Transcript",
-                            conversation_history.past_conversation_transcript,
-                        ),
-                    ]
-                ),
-            ),
         ]
     )
 
```
zrb/task/llm/subagent_conversation_history.py
ADDED

```diff
@@ -0,0 +1,41 @@
+from zrb.context.any_context import AnyContext
+from zrb.task.llm.conversation_history_model import ConversationHistory
+from zrb.task.llm.typing import ListOfDict
+from zrb.xcom.xcom import Xcom
+
+
+def inject_subagent_conversation_history_into_ctx(
+    ctx: AnyContext, conversation_history: ConversationHistory
+):
+    subagent_messages_xcom = _get_global_subagent_history_xcom(ctx)
+    existing_subagent_history = subagent_messages_xcom.get({})
+    subagent_messages_xcom.set(
+        {**existing_subagent_history, **conversation_history.subagent_history}
+    )
+
+
+def extract_subagent_conversation_history_from_ctx(
+    ctx: AnyContext,
+) -> dict[str, ListOfDict]:
+    subagent_messsages_xcom = _get_global_subagent_history_xcom(ctx)
+    return subagent_messsages_xcom.get({})
+
+
+def get_ctx_subagent_history(ctx: AnyContext, subagent_name: str) -> ListOfDict:
+    subagent_history = extract_subagent_conversation_history_from_ctx(ctx)
+    return subagent_history.get(subagent_name, [])
+
+
+def set_ctx_subagent_history(ctx: AnyContext, subagent_name: str, messages: ListOfDict):
+    subagent_history = extract_subagent_conversation_history_from_ctx(ctx)
+    subagent_history[subagent_name] = messages
+    subagent_messages_xcom = _get_global_subagent_history_xcom(ctx)
+    subagent_messages_xcom.set(subagent_history)
+
+
+def _get_global_subagent_history_xcom(ctx: AnyContext) -> Xcom:
+    if "_global_subagents" not in ctx.xcom:
+        ctx.xcom["_global_subagents"] = Xcom([{}])
+    if not isinstance(ctx.xcom["_global_subagents"], Xcom):
+        raise ValueError("ctx.xcom._global_subagents must be an Xcom")
+    return ctx.xcom["_global_subagents"]
```
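The module stores every sub-agent's transcript in a single `_global_subagents` Xcom entry, keyed by sub-agent name. A short round-trip sketch, assuming `ctx` is the `AnyContext` of a running task and "researcher" is a hypothetical sub-agent name:

```python
# Sketch only; `ctx` and the "researcher" name are assumptions.
from zrb.task.llm.subagent_conversation_history import (
    get_ctx_subagent_history,
    set_ctx_subagent_history,
)


def remember_subagent_run(ctx, messages: list[dict]):
    # Append this run's messages to whatever the "researcher" sub-agent
    # accumulated earlier in the same session.
    previous = get_ctx_subagent_history(ctx, "researcher")
    set_ctx_subagent_history(ctx, "researcher", previous + messages)
```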
zrb/task/llm/tool_confirmation_completer.py
ADDED

```diff
@@ -0,0 +1,41 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from prompt_toolkit.completion import Completer
+
+
+def get_tool_confirmation_completer(
+    options: list[str], meta_dict: dict[str, str]
+) -> "Completer":
+    from prompt_toolkit.completion import Completer, Completion
+
+    class ToolConfirmationCompleter(Completer):
+        """Custom completer for tool confirmation that doesn't auto-complete partial words."""
+
+        def __init__(self, options, meta_dict):
+            self.options = options
+            self.meta_dict = meta_dict
+
+        def get_completions(self, document, complete_event):
+            text = document.text.strip()
+            # 1. Input is empty, OR
+            # 2. Input exactly matches the beginning of an option
+            if text == "":
+                # Show all options when nothing is typed
+                for option in self.options:
+                    yield Completion(
+                        option,
+                        start_position=0,
+                        display_meta=self.meta_dict.get(option, ""),
+                    )
+                return
+            # Only complete if text exactly matches the beginning of an option
+            for option in self.options:
+                if option.startswith(text):
+                    yield Completion(
+                        option,
+                        start_position=-len(text),
+                        display_meta=self.meta_dict.get(option, ""),
+                    )
+
+    return ToolConfirmationCompleter(options, meta_dict)
```
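A usage sketch for the completer: the options and `meta_dict` mirror what `_read_line` builds in the `tool_wrapper.py` diff below, with "path" standing in for an arbitrary tool parameter. An empty input lists every option; typed input only completes exact prefixes:

```python
# Sketch only; the option list and the "path" parameter are illustrative.
from prompt_toolkit import PromptSession

from zrb.task.llm.tool_confirmation_completer import get_tool_confirmation_completer

completer = get_tool_confirmation_completer(
    options=["yes", "no", "edit", "edit path"],
    meta_dict={
        "yes": "Approve the execution",
        "no": "Disapprove the execution",
        "edit": "Edit tool execution parameters",
        "edit path": "Edit tool execution parameter: path",
    },
)
# Pressing tab on an empty prompt shows all four options with their meta text.
answer = PromptSession().prompt("Approve? ", completer=completer)
```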
zrb/task/llm/tool_wrapper.py
CHANGED

```diff
@@ -11,11 +11,11 @@ from zrb.config.llm_rate_limitter import llm_rate_limitter
 from zrb.context.any_context import AnyContext
 from zrb.task.llm.error import ToolExecutionError
 from zrb.task.llm.file_replacement import edit_replacement, is_single_path_replacement
+from zrb.task.llm.tool_confirmation_completer import get_tool_confirmation_completer
 from zrb.util.callable import get_callable_name
 from zrb.util.cli.markdown import render_markdown
 from zrb.util.cli.style import (
     stylize_blue,
-    stylize_error,
     stylize_faint,
     stylize_green,
     stylize_yellow,
@@ -185,7 +185,7 @@ async def _handle_user_response(
         ]
     )
     ctx.print(complete_confirmation_message, plain=True)
-    user_response = await _read_line()
+    user_response = await _read_line(args, kwargs)
     ctx.print("", plain=True)
     new_kwargs, is_edited = _get_edited_kwargs(ctx, user_response, kwargs)
     if is_edited:
@@ -250,13 +250,7 @@ def _get_user_approval_and_reason(
     try:
         approved = True if approval_str.strip() == "" else to_boolean(approval_str)
         if not approved and reason == "":
-            ctx.print(
-                stylize_error(
-                    f"You must specify rejection reason (i.e., No, <why>) for {func_call_str}" # noqa
-                ),
-                plain=True,
-            )
-            return None
+            reason = "User disapproving the tool execution"
         return approved, reason
     except Exception:
         return False, user_response
@@ -300,11 +294,21 @@ def _truncate_arg(arg: str, length: int = 19) -> str:
     return normalized_arg
 
 
-async def _read_line():
+async def _read_line(args: list[Any] | tuple[Any], kwargs: dict[str, Any]):
     from prompt_toolkit import PromptSession
 
+    options = ["yes", "no", "edit"]
+    meta_dict = {
+        "yes": "Approve the execution",
+        "no": "Disapprove the execution",
+        "edit": "Edit tool execution parameters",
+    }
+    for key in kwargs:
+        options.append(f"edit {key}")
+        meta_dict[f"edit {key}"] = f"Edit tool execution parameter: {key}"
+    completer = get_tool_confirmation_completer(options, meta_dict)
     reader = PromptSession()
-    return await reader.prompt_async()
+    return await reader.prompt_async(completer=completer)
 
 
 def _adjust_signature(wrapper: Callable, original_sig: inspect.Signature):
```
|