zrb 1.21.9__py3-none-any.whl → 1.21.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of zrb might be problematic.
- zrb/attr/type.py +10 -7
- zrb/builtin/git.py +12 -1
- zrb/builtin/llm/chat_completion.py +274 -0
- zrb/builtin/llm/chat_session_cmd.py +90 -28
- zrb/builtin/llm/chat_trigger.py +7 -1
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/tool/code.py +4 -1
- zrb/builtin/llm/tool/file.py +36 -81
- zrb/builtin/llm/tool/note.py +36 -16
- zrb/builtin/llm/tool/sub_agent.py +30 -10
- zrb/config/config.py +108 -13
- zrb/config/default_prompt/interactive_system_prompt.md +1 -1
- zrb/config/default_prompt/summarization_prompt.md +54 -8
- zrb/config/default_prompt/system_prompt.md +1 -1
- zrb/config/llm_rate_limitter.py +24 -5
- zrb/input/option_input.py +13 -1
- zrb/task/llm/agent.py +42 -144
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/config.py +7 -5
- zrb/task/llm/conversation_history.py +35 -24
- zrb/task/llm/conversation_history_model.py +4 -11
- zrb/task/llm/default_workflow/coding/workflow.md +2 -3
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +2 -179
- zrb/task/llm/print_node.py +14 -5
- zrb/task/llm/prompt.py +7 -18
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_wrapper.py +27 -12
- zrb/task/llm_task.py +55 -47
- zrb/util/attr.py +17 -10
- zrb/util/cli/text.py +6 -4
- zrb/util/git.py +2 -2
- zrb/util/yaml.py +1 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.21.9.dist-info → zrb-1.21.28.dist-info}/METADATA +5 -5
- {zrb-1.21.9.dist-info → zrb-1.21.28.dist-info}/RECORD +40 -35
- zrb/task/llm/history_summarization_tool.py +0 -24
- {zrb-1.21.9.dist-info → zrb-1.21.28.dist-info}/WHEEL +0 -0
- {zrb-1.21.9.dist-info → zrb-1.21.28.dist-info}/entry_points.txt +0 -0
zrb/config/default_prompt/summarization_prompt.md CHANGED

@@ -1,11 +1,57 @@
-You are a memory management AI. Your
+You are a smart memory management AI. Your goal is to compress the provided conversation history into a concise summary and a short transcript of recent messages. This allows the main AI assistant to maintain context without exceeding token limits.
 
-
+You will receive a JSON string representing the full conversation history. This JSON contains a list of message objects.
 
-
-2. **Transcript:** Extract ONLY the last 4 (four) turns of the `Recent Conversation` to serve as the new transcript.
-   * **Do not change or shorten the content of these turns, with one exception:** If a tool call returns a very long output, do not include the full output. Instead, briefly summarize the result of the tool call.
-   * Ensure the timestamp format is `[YYYY-MM-DD HH:MM:SS UTC+Z] Role: Message/Tool name being called`.
-3. **Update Memory:** Call the `final_result` tool with all the information you consolidated.
+Your task is to call the `save_conversation_summary` tool **once** with the following data. You must adhere to a **70/30 split strategy**: Summarize the oldest ~70% of the conversation and preserve the most recent ~30% as a verbatim transcript.
 
-
+1. **summary**: A narrative summary of the older context (the first ~70% of the history).
+   * **Length:** Comprehensive but concise.
+   * **Content - YOU MUST USE THESE SECTIONS:**
+      * **[Completed Actions]:** detailed list of files created, modified, or bugs fixed. **Do not omit file paths.**
+      * **[Active Context]:** What is the current high-level goal?
+      * **[Pending Steps]:** What specifically remains to be done?
+      * **[Constraints]:** Key user preferences or technical constraints.
+   * **Critical Logic:**
+      * **Anti-Looping:** If a task is listed in **[Completed Actions]**, do NOT list it in **[Pending Steps]**.
+      * **Context Merging:** If the input history already contains a summary, merge it intelligently. Updates to files supersede older descriptions.
+
+2. **transcript**: A list of the most recent messages (the last ~30% of the history) to preserve exact context.
+   * **Format:** A list of objects with `role`, `time`, and `content`.
+   * **Time Format:** Use "yyyy-mm-ddTHH:MM:SSZ" (e.g., "2023-10-27T10:00:00Z").
+   * **Content Rules:**
+      * **Preserve Verbatim:** Do not summarize user instructions or code in this section. The main AI needs the exact recent commands to function correctly.
+      * **Tool Outputs:** If a tool output in this recent section is huge (e.g., > 100 lines of file content), you may summarize it (e.g., "File content of X read successfully... "), but preserve any error messages or short confirmations exactly.
+
+**Input Structure Hint:**
+The input JSON is a list of Pydantic AI messages.
+- `kind="request"` -> usually User.
+- `kind="response"` -> usually Model.
+- Tool Results -> `part_kind="tool-return"`.
+
+**Example:**
+
+**Input (Abstract Representation of ~6 turns):**
+```json
+[
+  { "role": "user", "content": "Previous Summary: \n[Completed Actions]: Created `src/app.py`.\n[Active Context]: Fixing login bug.\n[Pending Steps]: Verify fix." },
+  { "role": "model", "content": "I see the bug. I will fix `src/app.py` now." },
+  { "role": "tool_call", "content": "write_file('src/app.py', '...fixed code...')" },
+  { "role": "tool_result", "content": "Success" },
+  { "role": "user", "content": "Great. Now add a test for it." },
+  { "role": "model", "content": "Okay, I will create `tests/test_login.py`." }
+]
+```
+
+**Output (Tool Call `save_conversation_summary`):**
+```json
+{
+  "summary": "[Completed Actions]: Created `src/app.py` and fixed login bug in `src/app.py`.\n[Active Context]: Adding tests for login functionality.\n[Pending Steps]: Create `tests/test_login.py`.\n[Constraints]: None.",
+  "transcript": [
+    { "role": "user", "time": "2023-10-27T10:05:00Z", "content": "Great. Now add a test for it." },
+    { "role": "model", "time": "2023-10-27T10:05:05Z", "content": "Okay, I will create `tests/test_login.py`." }
+  ]
+}
+```
+
+**Final Note:**
+The `summary` + `transcript` is the ONLY memory the main AI will have. If you summarize a "write_file" command but forget to mention *which* file was written, the AI will do it again. **Be specific.**
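For reference, the tool payload this prompt asks the model to produce maps onto two small Pydantic models. This is a minimal sketch based only on the prompt text above; the class names `TranscriptMessage` and `ConversationSummary` are illustrative, not part of zrb:

```python
# Sketch of the `save_conversation_summary` payload shape described by the
# prompt above; the model classes here are hypothetical names.
from pydantic import BaseModel


class TranscriptMessage(BaseModel):
    role: str  # e.g. "user" or "model"
    time: str  # "yyyy-mm-ddTHH:MM:SSZ", e.g. "2023-10-27T10:00:00Z"
    content: str  # recent messages kept verbatim; huge tool outputs summarized


class ConversationSummary(BaseModel):
    # Narrative summary of the oldest ~70%, using the [Completed Actions],
    # [Active Context], [Pending Steps], and [Constraints] sections.
    summary: str
    # The most recent ~30% of the history, preserved as a transcript.
    transcript: list[TranscriptMessage]


payload = ConversationSummary(
    summary="[Completed Actions]: Fixed login bug in `src/app.py`.",
    transcript=[
        TranscriptMessage(
            role="user",
            time="2023-10-27T10:05:00Z",
            content="Great. Now add a test for it.",
        )
    ],
)
print(payload.model_dump_json(indent=2))
```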
zrb/config/default_prompt/system_prompt.md CHANGED

@@ -1,4 +1,4 @@
-
+This is a single request session. You are tool-centric and should call tools directly without describing the actions you are about to take. Only communicate to report the final result.
 
 # Core Principles
 
zrb/config/llm_rate_limitter.py CHANGED

@@ -7,7 +7,7 @@ from typing import Any, Callable
 from zrb.config.config import CFG
 
 
-class
+class LLMRateLimitter:
     """
     Helper class to enforce LLM API rate limits and throttling.
     Tracks requests and tokens in a rolling 60-second window.

@@ -129,7 +129,7 @@ class LLMRateLimiter:
     async def throttle(
         self,
         prompt: Any,
-        throttle_notif_callback: Callable | None = None,
+        throttle_notif_callback: Callable[[str], Any] | None = None,
     ):
         now = time.time()
         str_prompt = self._prompt_to_str(prompt)

@@ -142,7 +142,17 @@ class LLMRateLimiter:
         # Check per-request token limit
         if tokens > self.max_tokens_per_request:
             raise ValueError(
-
+                (
+                    "Request exceeds max_tokens_per_request "
+                    f"({tokens} > {self.max_tokens_per_request})."
+                )
+            )
+        if tokens > self.max_tokens_per_minute:
+            raise ValueError(
+                (
+                    "Request exceeds max_tokens_per_minute "
+                    f"({tokens} > {self.max_tokens_per_minute})."
+                )
             )
         # Wait if over per-minute request or token limit
         while (

@@ -150,7 +160,16 @@ class LLMRateLimiter:
             or sum(t for _, t in self.token_times) + tokens > self.max_tokens_per_minute
         ):
             if throttle_notif_callback is not None:
-
+                if len(self.request_times) >= self.max_requests_per_minute:
+                    rpm = len(self.request_times)
+                    throttle_notif_callback(
+                        f"Max request per minute exceeded: {rpm} of {self.max_requests_per_minute}"
+                    )
+                else:
+                    tpm = sum(t for _, t in self.token_times) + tokens
+                    throttle_notif_callback(
+                        f"Max token per minute exceeded: {tpm} of {self.max_tokens_per_minute}"
+                    )
             await asyncio.sleep(self.throttle_sleep)
             now = time.time()
             while self.request_times and now - self.request_times[0] > 60:

@@ -168,4 +187,4 @@ class LLMRateLimiter:
         return f"{prompt}"
 
 
-llm_rate_limitter =
+llm_rate_limitter = LLMRateLimitter()
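The practical effect of the `Callable[[str], Any]` change is that throttle notifications now carry a reason string. A minimal usage sketch (the callback body and the prompt value are illustrative; `llm_rate_limitter` and `throttle` come from the module above):

```python
import asyncio

from zrb.config.llm_rate_limitter import llm_rate_limitter


def on_throttle(reason: str) -> None:
    # reason now distinguishes the two cases, e.g.
    # "Max request per minute exceeded: 60 of 60" or
    # "Max token per minute exceeded: 120000 of 100000"
    print(f"Throttled: {reason}")


async def main() -> None:
    # Raises ValueError up front if the prompt alone exceeds
    # max_tokens_per_request or max_tokens_per_minute; otherwise sleeps
    # (invoking on_throttle with the reason) until capacity frees up.
    await llm_rate_limitter.throttle("some prompt", on_throttle)


asyncio.run(main())
```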
zrb/input/option_input.py CHANGED

@@ -47,9 +47,21 @@ class OptionInput(BaseInput):
         option_str = ", ".join(options)
         if default_value != "":
             prompt_message = f"{prompt_message} ({option_str}) [{default_value}]"
-        value =
+        value = self._get_value_from_user_input(shared_ctx, prompt_message, options)
         if value.strip() != "" and value.strip() not in options:
             value = self._prompt_cli_str(shared_ctx)
         if value.strip() == "":
             value = default_value
         return value
+
+    def _get_value_from_user_input(
+        self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
+    ) -> str:
+        from prompt_toolkit import PromptSession
+        from prompt_toolkit.completion import WordCompleter
+
+        if shared_ctx.is_tty:
+            reader = PromptSession()
+            option_completer = WordCompleter(options, ignore_case=True)
+            return reader.prompt(f"{prompt_message}: ", completer=option_completer)
+        return input(f"{prompt_message}: ")
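The new `_get_value_from_user_input` helper gives option inputs tab completion when a terminal is attached. A standalone sketch of the same behavior, with `sys.stdin.isatty()` standing in for `shared_ctx.is_tty`:

```python
import sys

from prompt_toolkit import PromptSession
from prompt_toolkit.completion import WordCompleter


def ask(prompt_message: str, options: list[str]) -> str:
    if sys.stdin.isatty():
        # Interactive terminal: offer case-insensitive tab completion.
        reader = PromptSession()
        completer = WordCompleter(options, ignore_case=True)
        return reader.prompt(f"{prompt_message}: ", completer=completer)
    # Piped / non-interactive input: plain read, same as the fallback above.
    return input(f"{prompt_message}: ")


if __name__ == "__main__":
    print(ask("Continue", ["yes", "no"]))
```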
zrb/task/llm/agent.py CHANGED

@@ -1,22 +1,16 @@
 import inspect
-import json
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
-from zrb.config.llm_rate_limitter import
+from zrb.config.llm_rate_limitter import LLMRateLimitter
 from zrb.context.any_context import AnyContext
-from zrb.
-from zrb.task.llm.error import extract_api_error_details
-from zrb.task.llm.print_node import print_node
+from zrb.task.llm.history_processor import create_summarize_history_processor
 from zrb.task.llm.tool_wrapper import wrap_func, wrap_tool
-from zrb.task.llm.typing import ListOfDict
-from zrb.util.cli.style import stylize_faint
 
 if TYPE_CHECKING:
     from pydantic_ai import Agent, Tool
-    from pydantic_ai.
-    from pydantic_ai.messages import UserContent
+    from pydantic_ai._agent_graph import HistoryProcessor
     from pydantic_ai.models import Model
     from pydantic_ai.output import OutputDataT, OutputSpec
     from pydantic_ai.settings import ModelSettings

@@ -28,13 +22,21 @@ if TYPE_CHECKING:
 def create_agent_instance(
     ctx: AnyContext,
     model: "str | Model",
+    rate_limitter: LLMRateLimitter | None = None,
     output_type: "OutputSpec[OutputDataT]" = str,
     system_prompt: str = "",
     model_settings: "ModelSettings | None" = None,
-    tools:
+    tools: list["ToolOrCallable"] = [],
     toolsets: list["AbstractToolset[None]"] = [],
     retries: int = 3,
     yolo_mode: bool | list[str] | None = None,
+    summarization_model: "Model | str | None" = None,
+    summarization_model_settings: "ModelSettings | None" = None,
+    summarization_system_prompt: str | None = None,
+    summarization_retries: int = 2,
+    summarization_token_threshold: int | None = None,
+    history_processors: list["HistoryProcessor"] | None = None,
+    auto_summarize: bool = True,
 ) -> "Agent[None, Any]":
     """Creates a new Agent instance with configured tools and servers."""
     from pydantic_ai import Agent, RunContext, Tool

@@ -102,6 +104,21 @@ def create_agent_instance(
         ConfirmationWrapperToolset(wrapped=toolset, ctx=ctx, yolo_mode=yolo_mode)
         for toolset in toolsets
     ]
+    # Create History processor with summarizer
+    history_processors = [] if history_processors is None else history_processors
+    if auto_summarize:
+        history_processors += [
+            create_summarize_history_processor(
+                ctx=ctx,
+                system_prompt=system_prompt,
+                rate_limitter=rate_limitter,
+                summarization_model=summarization_model,
+                summarization_model_settings=summarization_model_settings,
+                summarization_system_prompt=summarization_system_prompt,
+                summarization_token_threshold=summarization_token_threshold,
+                summarization_retries=summarization_retries,
+            )
+        ]
     # Return Agent
     return Agent[None, Any](
         model=model,

@@ -111,12 +128,14 @@
         toolsets=wrapped_toolsets,
         model_settings=model_settings,
         retries=retries,
+        history_processors=history_processors,
     )
 
 
 def get_agent(
     ctx: AnyContext,
     model: "str | Model",
+    rate_limitter: LLMRateLimitter | None = None,
     output_type: "OutputSpec[OutputDataT]" = str,
     system_prompt: str = "",
     model_settings: "ModelSettings | None" = None,

@@ -128,6 +147,12 @@ def get_agent(
     additional_toolsets: "list[AbstractToolset[None] | str]" = [],
     retries: int = 3,
     yolo_mode: bool | list[str] | None = None,
+    summarization_model: "Model | str | None" = None,
+    summarization_model_settings: "ModelSettings | None" = None,
+    summarization_system_prompt: str | None = None,
+    summarization_retries: int = 2,
+    summarization_token_threshold: int | None = None,
+    history_processors: list["HistoryProcessor"] | None = None,
 ) -> "Agent":
     """Retrieves the configured Agent instance or creates one if necessary."""
     # Get tools for agent

@@ -143,6 +168,7 @@
     return create_agent_instance(
         ctx=ctx,
         model=model,
+        rate_limitter=rate_limitter,
         output_type=output_type,
         system_prompt=system_prompt,
         tools=tools,

@@ -150,6 +176,12 @@
         model_settings=model_settings,
         retries=retries,
         yolo_mode=yolo_mode,
+        summarization_model=summarization_model,
+        summarization_model_settings=summarization_model_settings,
+        summarization_system_prompt=summarization_system_prompt,
+        summarization_retries=summarization_retries,
+        summarization_token_threshold=summarization_token_threshold,
+        history_processors=history_processors,
     )
 
 

@@ -170,137 +202,3 @@ def _render_toolset_or_str_list(
             continue
         toolsets.append(toolset_or_str)
     return toolsets
-
-
-async def run_agent_iteration(
-    ctx: AnyContext,
-    agent: "Agent[None, Any]",
-    user_prompt: str,
-    attachments: "list[UserContent] | None" = None,
-    history_list: ListOfDict | None = None,
-    rate_limitter: LLMRateLimiter | None = None,
-    max_retry: int = 2,
-    log_indent_level: int = 0,
-) -> "AgentRun":
-    """
-    Runs a single iteration of the agent execution loop.
-
-    Args:
-        ctx: The task context.
-        agent: The Pydantic AI agent instance.
-        user_prompt: The user's input prompt.
-        history_list: The current conversation history.
-
-    Returns:
-        The agent run result object.
-
-    Raises:
-        Exception: If any error occurs during agent execution.
-    """
-    if max_retry < 0:
-        raise ValueError("Max retry cannot be less than 0")
-    attempt = 0
-    while attempt < max_retry:
-        try:
-            return await _run_single_agent_iteration(
-                ctx=ctx,
-                agent=agent,
-                user_prompt=user_prompt,
-                attachments=[] if attachments is None else attachments,
-                history_list=[] if history_list is None else history_list,
-                rate_limitter=(
-                    llm_rate_limitter if rate_limitter is None else rate_limitter
-                ),
-                log_indent_level=log_indent_level,
-            )
-        except BaseException:
-            attempt += 1
-            if attempt == max_retry:
-                raise
-    raise Exception("Max retry exceeded")
-
-
-async def _run_single_agent_iteration(
-    ctx: AnyContext,
-    agent: "Agent",
-    user_prompt: str,
-    attachments: "list[UserContent]",
-    history_list: ListOfDict,
-    rate_limitter: LLMRateLimiter,
-    log_indent_level: int,
-) -> "AgentRun":
-    from openai import APIError
-    from pydantic_ai.messages import ModelMessagesTypeAdapter
-
-    agent_payload = _estimate_request_payload(
-        agent, user_prompt, attachments, history_list
-    )
-    callback = _create_print_throttle_notif(ctx)
-    if rate_limitter:
-        await rate_limitter.throttle(agent_payload, callback)
-    else:
-        await llm_rate_limitter.throttle(agent_payload, callback)
-
-    user_prompt_with_attachments = [user_prompt] + attachments
-    async with agent:
-        async with agent.iter(
-            user_prompt=user_prompt_with_attachments,
-            message_history=ModelMessagesTypeAdapter.validate_python(history_list),
-        ) as agent_run:
-            async for node in agent_run:
-                # Each node represents a step in the agent's execution
-                try:
-                    await print_node(
-                        _get_plain_printer(ctx), agent_run, node, log_indent_level
-                    )
-                except APIError as e:
-                    # Extract detailed error information from the response
-                    error_details = extract_api_error_details(e)
-                    ctx.log_error(f"API Error: {error_details}")
-                    raise
-                except Exception as e:
-                    ctx.log_error(f"Error processing node: {str(e)}")
-                    ctx.log_error(f"Error type: {type(e).__name__}")
-                    raise
-    return agent_run
-
-
-def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[], None]:
-    def _print_throttle_notif():
-        ctx.print(stylize_faint(" ⌛>> Request Throttled"), plain=True)
-
-    return _print_throttle_notif
-
-
-def _estimate_request_payload(
-    agent: "Agent",
-    user_prompt: str,
-    attachments: "list[UserContent]",
-    history_list: ListOfDict,
-) -> str:
-    system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
-    return json.dumps(
-        [
-            {"role": "system", "content": "\n".join(system_prompts)},
-            *history_list,
-            {"role": "user", "content": user_prompt},
-            *[_estimate_attachment_payload(attachment) for attachment in attachments],
-        ]
-    )
-
-
-def _estimate_attachment_payload(attachment: "UserContent") -> Any:
-    if hasattr(attachment, "url"):
-        return {"role": "user", "content": attachment.url}
-    if hasattr(attachment, "data"):
-        return {"role": "user", "content": "x" * len(attachment.data)}
-    return ""
-
-
-def _get_plain_printer(ctx: AnyContext):
-    def printer(*args, **kwargs):
-        if "plain" not in kwargs:
-            kwargs["plain"] = True
-        return ctx.print(*args, **kwargs)
-
-    return printer
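Taken together, `create_agent_instance` now owns the summarization wiring instead of the removed inline run loop. An illustrative call using the new keyword arguments (the parameter names come from the diff above; the model strings and threshold are made-up values):

```python
from zrb.context.any_context import AnyContext
from zrb.task.llm.agent import create_agent_instance


def build_agent(ctx: AnyContext):
    return create_agent_instance(
        ctx=ctx,
        model="openai:gpt-4o",
        system_prompt="You are a helpful assistant.",
        # auto_summarize appends the summarizing history processor by default.
        auto_summarize=True,
        summarization_model="openai:gpt-4o-mini",  # cheaper model for compaction
        summarization_token_threshold=32_000,  # compact once history grows past this
        summarization_retries=2,
    )
```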
zrb/task/llm/agent_runner.py ADDED

@@ -0,0 +1,152 @@
+import json
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+from zrb.config.llm_rate_limitter import LLMRateLimitter, llm_rate_limitter
+from zrb.context.any_context import AnyContext
+from zrb.task.llm.error import extract_api_error_details
+from zrb.task.llm.print_node import print_node
+from zrb.task.llm.typing import ListOfDict
+from zrb.util.cli.style import stylize_faint
+
+if TYPE_CHECKING:
+    from pydantic_ai import Agent, Tool
+    from pydantic_ai.agent import AgentRun
+    from pydantic_ai.messages import UserContent
+
+    ToolOrCallable = Tool | Callable
+
+
+async def run_agent_iteration(
+    ctx: AnyContext,
+    agent: "Agent[None, Any]",
+    user_prompt: str,
+    attachments: "list[UserContent] | None" = None,
+    history_list: ListOfDict | None = None,
+    rate_limitter: LLMRateLimitter | None = None,
+    max_retry: int = 2,
+    log_indent_level: int = 0,
+) -> "AgentRun":
+    """
+    Runs a single iteration of the agent execution loop.
+
+    Args:
+        ctx: The task context.
+        agent: The Pydantic AI agent instance.
+        user_prompt: The user's input prompt.
+        history_list: The current conversation history.
+
+    Returns:
+        The agent run result object.
+
+    Raises:
+        Exception: If any error occurs during agent execution.
+    """
+    if max_retry < 0:
+        raise ValueError("Max retry cannot be less than 0")
+    attempt = 0
+    while attempt < max_retry:
+        try:
+            return await _run_single_agent_iteration(
+                ctx=ctx,
+                agent=agent,
+                user_prompt=user_prompt,
+                attachments=[] if attachments is None else attachments,
+                history_list=[] if history_list is None else history_list,
+                rate_limitter=(
+                    llm_rate_limitter if rate_limitter is None else rate_limitter
+                ),
+                log_indent_level=log_indent_level,
+            )
+        except BaseException:
+            attempt += 1
+            if attempt == max_retry:
+                raise
+    raise Exception("Max retry exceeded")
+
+
+async def _run_single_agent_iteration(
+    ctx: AnyContext,
+    agent: "Agent",
+    user_prompt: str,
+    attachments: "list[UserContent]",
+    history_list: ListOfDict,
+    rate_limitter: LLMRateLimitter,
+    log_indent_level: int,
+) -> "AgentRun":
+    from openai import APIError
+    from pydantic_ai import UsageLimits
+    from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+    agent_payload = _estimate_request_payload(
+        agent, user_prompt, attachments, history_list
+    )
+    callback = _create_print_throttle_notif(ctx)
+    if rate_limitter:
+        await rate_limitter.throttle(agent_payload, callback)
+    else:
+        await llm_rate_limitter.throttle(agent_payload, callback)
+    user_prompt_with_attachments = [user_prompt] + attachments
+    async with agent:
+        async with agent.iter(
+            user_prompt=user_prompt_with_attachments,
+            message_history=ModelMessagesTypeAdapter.validate_python(history_list),
+            usage_limits=UsageLimits(request_limit=None),  # We don't want limit
+        ) as agent_run:
+            async for node in agent_run:
+                # Each node represents a step in the agent's execution
+                try:
+                    await print_node(
+                        _get_plain_printer(ctx), agent_run, node, log_indent_level
+                    )
+                except APIError as e:
+                    # Extract detailed error information from the response
+                    error_details = extract_api_error_details(e)
+                    ctx.log_error(f"API Error: {error_details}")
+                    raise
+                except Exception as e:
+                    ctx.log_error(f"Error processing node: {str(e)}")
+                    ctx.log_error(f"Error type: {type(e).__name__}")
+                    raise
+    return agent_run
+
+
+def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[str], None]:
+    def _print_throttle_notif(reason: str):
+        ctx.print(stylize_faint(f" ⌛>> Request Throttled: {reason}"), plain=True)
+
+    return _print_throttle_notif
+
+
+def _estimate_request_payload(
+    agent: "Agent",
+    user_prompt: str,
+    attachments: "list[UserContent]",
+    history_list: ListOfDict,
+) -> str:
+    system_prompts = agent._system_prompts if hasattr(agent, "_system_prompts") else ()
+    return json.dumps(
+        [
+            {"role": "system", "content": "\n".join(system_prompts)},
+            *history_list,
+            {"role": "user", "content": user_prompt},
+            *[_estimate_attachment_payload(attachment) for attachment in attachments],
+        ]
+    )
+
+
+def _estimate_attachment_payload(attachment: "UserContent") -> Any:
+    if hasattr(attachment, "url"):
+        return {"role": "user", "content": attachment.url}
+    if hasattr(attachment, "data"):
+        return {"role": "user", "content": "x" * len(attachment.data)}
+    return ""
+
+
+def _get_plain_printer(ctx: AnyContext):
+    def printer(*args, **kwargs):
+        if "plain" not in kwargs:
+            kwargs["plain"] = True
+        return ctx.print(*args, **kwargs)
+
+    return printer
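Callers that previously imported `run_agent_iteration` from `zrb.task.llm.agent` now import it from `zrb.task.llm.agent_runner`. A sketch of driving one turn through the relocated helper (the `ctx`, `agent`, and `history` values are assumed to come from a running LLM task):

```python
from zrb.context.any_context import AnyContext
from zrb.task.llm.agent_runner import run_agent_iteration


async def one_turn(ctx: AnyContext, agent, history: list[dict]):
    # Retries up to max_retry times, throttles through the global rate
    # limiter, and streams node-by-node progress via ctx.print.
    agent_run = await run_agent_iteration(
        ctx=ctx,
        agent=agent,
        user_prompt="Summarize the repository layout.",
        history_list=history,
        max_retry=2,
    )
    return agent_run  # agent_run.result carries the final output
```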
zrb/task/llm/config.py CHANGED

@@ -4,7 +4,7 @@ if TYPE_CHECKING:
     from pydantic_ai.models import Model
     from pydantic_ai.settings import ModelSettings
 
-from zrb.attr.type import BoolAttr, StrAttr, StrListAttr
+from zrb.attr.type import BoolAttr, StrAttr, StrListAttr
 from zrb.config.llm_config import LLMConfig, llm_config
 from zrb.context.any_context import AnyContext
 from zrb.util.attr import get_attr, get_bool_attr, get_str_list_attr

@@ -12,7 +12,9 @@ from zrb.util.attr import get_attr, get_bool_attr, get_str_list_attr
 
 def get_yolo_mode(
     ctx: AnyContext,
-    yolo_mode_attr:
+    yolo_mode_attr: (
+        Callable[[AnyContext], list[str] | bool | None] | StrListAttr | BoolAttr | None
+    ) = None,
     render_yolo_mode: bool = True,
 ) -> bool | list[str]:
     if yolo_mode_attr is None:

@@ -77,11 +79,11 @@ def get_model_api_key(
 
 def get_model(
     ctx: AnyContext,
-    model_attr: "Callable[[AnyContext], Model | str |
+    model_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
    render_model: bool,
-    model_base_url_attr:
+    model_base_url_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None",
     render_model_base_url: bool = True,
-    model_api_key_attr:
+    model_api_key_attr: "Callable[[AnyContext], Model | str | None] | Model | str | None" = None,
     render_model_api_key: bool = True,
     is_small_model: bool = False,
 ) -> "str | Model":