zrb 1.21.37__py3-none-any.whl → 1.21.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/chat_completion.py +46 -0
- zrb/builtin/llm/chat_session.py +89 -29
- zrb/builtin/llm/chat_session_cmd.py +87 -11
- zrb/builtin/llm/chat_trigger.py +92 -5
- zrb/builtin/llm/history.py +14 -7
- zrb/builtin/llm/llm_ask.py +16 -7
- zrb/builtin/llm/tool/file.py +3 -2
- zrb/builtin/llm/tool/search/brave.py +2 -2
- zrb/builtin/llm/tool/search/searxng.py +2 -2
- zrb/builtin/llm/tool/search/serpapi.py +2 -2
- zrb/builtin/llm/xcom_names.py +3 -0
- zrb/callback/callback.py +8 -1
- zrb/config/config.py +1 -1
- zrb/context/context.py +11 -0
- zrb/task/base/context.py +25 -13
- zrb/task/base/execution.py +52 -47
- zrb/task/base/lifecycle.py +1 -1
- zrb/task/base_task.py +31 -45
- zrb/task/base_trigger.py +0 -1
- zrb/task/llm/agent.py +39 -31
- zrb/task/llm/agent_runner.py +59 -1
- zrb/task/llm/default_workflow/researching/workflow.md +2 -0
- zrb/task/llm/print_node.py +15 -2
- zrb/task/llm/prompt.py +70 -40
- zrb/task/llm/workflow.py +13 -1
- zrb/task/llm_task.py +83 -28
- zrb/util/run.py +3 -3
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/METADATA +1 -1
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/RECORD +31 -30
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/WHEEL +0 -0
- {zrb-1.21.37.dist-info → zrb-1.21.43.dist-info}/entry_points.txt +0 -0
zrb/task/llm/agent_runner.py
CHANGED
```diff
@@ -1,4 +1,9 @@
+import asyncio
 import json
+import os
+import sys
+import termios
+import tty
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any
 
@@ -93,11 +98,17 @@ async def _run_single_agent_iteration(
         message_history=ModelMessagesTypeAdapter.validate_python(history_list),
         usage_limits=UsageLimits(request_limit=None),  # We don't want limit
     ) as agent_run:
+        escape_task = asyncio.create_task(_wait_for_escape(ctx))
        async for node in agent_run:
            # Each node represents a step in the agent's execution
            try:
                await print_node(
-                    _get_plain_printer(ctx),
+                    _get_plain_printer(ctx),
+                    agent_run,
+                    node,
+                    ctx.is_tty,
+                    log_indent_level,
+                    lambda: escape_task.done(),
                )
            except APIError as e:
                # Extract detailed error information from the response
@@ -108,9 +119,56 @@ async def _run_single_agent_iteration(
                ctx.log_error(f"Error processing node: {str(e)}")
                ctx.log_error(f"Error type: {type(e).__name__}")
                raise
+            if escape_task.done():
+                break
+        # Clean escape_task
+        if not escape_task.done():
+            try:
+                escape_task.cancel()
+                await escape_task
+            except asyncio.CancelledError:
+                pass
        return agent_run
 
 
+async def _wait_for_escape(ctx: AnyContext) -> None:
+    if not ctx.is_tty:
+        # Wait forever
+        await asyncio.Future()
+        return
+    loop = asyncio.get_event_loop()
+    future = loop.create_future()
+    fd = sys.stdin.fileno()
+    old_settings = termios.tcgetattr(fd)
+    try:
+        tty.setcbreak(fd)
+        loop.add_reader(fd, _create_escape_detector(ctx, future, fd))
+        await future
+    except asyncio.CancelledError:
+        raise
+    finally:
+        loop.remove_reader(fd)
+        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+
+
+def _create_escape_detector(
+    ctx: AnyContext, future: asyncio.Future[Any], fd: int | Any
+) -> Callable[[], None]:
+    def on_stdin():
+        try:
+            # Read just one byte
+            ch = os.read(fd, 1)
+            if ch == b"\x1b":
+                ctx.print("\n🚫 Interrupted by user.", plain=True)
+                if not future.done():
+                    future.set_result(None)
+        except Exception as e:
+            if not future.done():
+                future.set_exception(e)
+
+    return on_stdin
+
+
 def _create_print_throttle_notif(ctx: AnyContext) -> Callable[[str], None]:
    def _print_throttle_notif(text: str, *args: Any, **kwargs: Any):
        new_line = kwargs.get("new_line", True)
```
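The new escape-key handling combines three pieces: `tty.setcbreak` so keypresses arrive without Enter, `loop.add_reader` so stdin is watched by the event loop rather than a blocking thread, and a future that resolves when the ESC byte (`\x1b`) arrives. A minimal standalone sketch of the same technique, assuming a POSIX terminal (names here are illustrative, not zrb's API):

```python
import asyncio
import os
import sys
import termios
import tty


async def wait_for_esc() -> None:
    """Resolve once the user presses ESC, without blocking the event loop."""
    loop = asyncio.get_running_loop()
    future: "asyncio.Future[None]" = loop.create_future()
    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)  # remember cooked-mode settings
    try:
        tty.setcbreak(fd)  # deliver keypresses immediately, no Enter required
        # add_reader invokes the callback whenever stdin has bytes to read
        loop.add_reader(fd, _on_stdin, fd, future)
        await future
    finally:
        loop.remove_reader(fd)
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)  # restore terminal


def _on_stdin(fd: int, future: "asyncio.Future[None]") -> None:
    if os.read(fd, 1) == b"\x1b" and not future.done():  # ESC byte
        future.set_result(None)


async def main() -> None:
    esc = asyncio.create_task(wait_for_esc())
    work = asyncio.create_task(asyncio.sleep(30))  # stand-in for agent work
    await asyncio.wait({esc, work}, return_when=asyncio.FIRST_COMPLETED)
    print("cancelled by ESC" if esc.done() else "work finished")
    for task in (esc, work):
        task.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```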
zrb/task/llm/default_workflow/researching/workflow.md
CHANGED
```diff
@@ -9,6 +9,7 @@ Follow this workflow to deliver accurate, well-sourced, and synthesized research
 - **Source Hierarchy:** Prioritize authoritative sources over secondary ones
 - **Synthesis Excellence:** Connect information into coherent narratives
 - **Comprehensive Attribution:** Cite all significant claims and data points
+- **Mandatory Citations:** ALWAYS include a "Sources" section at the end of the response with a list of all URLs used.
 
 # Tool Usage Guideline
 - Use `search_internet` for web research and information gathering
@@ -114,6 +115,7 @@ Follow this workflow to deliver accurate, well-sourced, and synthesized research
 - **Detailed Analysis:** Provide comprehensive information for deep dives
 - **Actionable Insights:** Highlight implications and recommended actions
 - **Further Reading:** Suggest additional resources for interested readers
+- **Sources:** ALWAYS end the response with a "Sources" section listing all URLs used.
 
 # Risk Assessment Guidelines
 
```
zrb/task/llm/print_node.py
CHANGED
```diff
@@ -1,3 +1,4 @@
+import asyncio
 import json
 from collections.abc import Callable
 from typing import Any
@@ -7,7 +8,12 @@ from zrb.util.cli.style import stylize_faint
 
 
 async def print_node(
-    print_func: Callable,
+    print_func: Callable,
+    agent_run: Any,
+    node: Any,
+    is_tty: bool,
+    log_indent_level: int = 0,
+    stop_check: Callable[[], bool] | None = None,
 ):
    """Prints the details of an agent execution node using a provided print function."""
    from pydantic_ai import Agent
@@ -31,13 +37,17 @@ async def print_node(
        return
    if Agent.is_model_request_node(node):
        # A model request node => We can stream tokens from the model's request
-
+        esc_notif = " (Press esc to cancel)" if is_tty else ""
+        print_func(_format_header(f"🧠 Processing{esc_notif}...", log_indent_level))
        # Reference: https://ai.pydantic.dev/agents/#streaming-all-events-and-output
        try:
            async with node.stream(agent_run.ctx) as request_stream:
                is_streaming = False
                is_tool_processing = False
                async for event in request_stream:
+                    if stop_check and stop_check():
+                        return
+                    await asyncio.sleep(0)
                    if isinstance(event, PartStartEvent) and event.part:
                        if is_streaming:
                            print_func("")
@@ -127,6 +137,9 @@ async def print_node(
        try:
            async with node.stream(agent_run.ctx) as handle_stream:
                async for event in handle_stream:
+                    if stop_check and stop_check():
+                        return
+                    await asyncio.sleep(0)
                    if isinstance(event, FunctionToolCallEvent):
                        args = _get_event_part_args(event)
                        call_id = event.part.tool_call_id
```
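`print_node` now polls the `stop_check` callback between streamed events, and `await asyncio.sleep(0)` yields control so the watcher task actually gets scheduled. A toy sketch of this cooperative-stop pattern (all names hypothetical):

```python
import asyncio
from collections.abc import AsyncIterator, Callable


async def fake_token_stream() -> AsyncIterator[str]:
    for token in ["Hello", " ", "world", "!"]:
        await asyncio.sleep(0.2)  # pretend network latency
        yield token


async def print_stream(stop_check: Callable[[], bool] | None = None) -> None:
    async for token in fake_token_stream():
        if stop_check and stop_check():
            return  # caller asked us to stop; exit between tokens
        await asyncio.sleep(0)  # yield so other tasks (e.g. a key watcher) can run
        print(token, end="", flush=True)


async def main() -> None:
    stopper = asyncio.Event()
    printer = asyncio.create_task(print_stream(stop_check=stopper.is_set))
    await asyncio.sleep(0.5)
    stopper.set()  # simulate the user pressing ESC mid-stream
    await printer
    print("\nstopped")


asyncio.run(main())
```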
zrb/task/llm/prompt.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from zrb.context.any_context import AnyContext
 from zrb.task.llm.conversation_history_model import ConversationHistory
 from zrb.task.llm.workflow import LLMWorkflow, get_available_workflows
 from zrb.util.attr import get_attr, get_str_attr, get_str_list_attr
-from zrb.util.file import read_dir, read_file_with_line_numbers
+from zrb.util.file import read_dir, read_file, read_file_with_line_numbers
 from zrb.util.markdown import make_markdown_section
 
 if TYPE_CHECKING:
@@ -32,11 +32,8 @@ def get_system_and_user_prompt(
 ) -> tuple[str, str]:
    if conversation_history is None:
        conversation_history = ConversationHistory()
-    new_user_message_prompt, apendixes = _get_user_message_prompt(user_message)
    new_system_prompt = _construct_system_prompt(
        ctx=ctx,
-        user_message=user_message,
-        apendixes=apendixes,
        persona_attr=persona_attr,
        render_persona=render_persona,
        system_prompt_attr=system_prompt_attr,
@@ -47,13 +44,12 @@ def get_system_and_user_prompt(
        render_workflows=render_workflows,
        conversation_history=conversation_history,
    )
+    new_user_message_prompt = _get_user_message_prompt(user_message)
    return new_system_prompt, new_user_message_prompt
 
 
 def _construct_system_prompt(
    ctx: AnyContext,
-    user_message: str,
-    apendixes: str,
    persona_attr: StrAttr | None = None,
    render_persona: bool = False,
    system_prompt_attr: StrAttr | None = None,
@@ -71,6 +67,7 @@ def _construct_system_prompt(
    special_instruction_prompt = _get_special_instruction_prompt(
        ctx, special_instruction_prompt_attr, render_special_instruction_prompt
    )
+    project_instructions = _get_project_instructions()
    available_workflows = get_available_workflows()
    active_workflow_names = set(
        _get_active_workflow_names(ctx, workflows_attr, render_workflows)
@@ -98,6 +95,7 @@ def _construct_system_prompt(
                ]
            ),
        ),
+        make_markdown_section("📜 PROJECT INSTRUCTIONS", project_instructions),
        make_markdown_section("🛠️ AVAILABLE WORKFLOWS", inactive_workflow_prompt),
        make_markdown_section(
            "📚 CONTEXT",
@@ -122,10 +120,6 @@ def _construct_system_prompt(
                    "📝 Contextual Note Content",
                    conversation_history.contextual_note,
                ),
-                make_markdown_section(
-                    "📄 Apendixes",
-                    apendixes,
-                ),
            ]
        ),
    ),
@@ -133,21 +127,63 @@ def _construct_system_prompt(
    )
 
 
+def _get_project_instructions() -> str:
+    instructions = []
+    cwd = os.path.abspath(os.getcwd())
+    home = os.path.abspath(os.path.expanduser("~"))
+    search_dirs = []
+    if cwd == home or cwd.startswith(os.path.join(home, "")):
+        current_dir = cwd
+        while True:
+            search_dirs.append(current_dir)
+            if current_dir == home:
+                break
+            parent_dir = os.path.dirname(current_dir)
+            if parent_dir == current_dir:
+                break
+            current_dir = parent_dir
+    else:
+        search_dirs.append(cwd)
+    for file_name in ["AGENTS.md", "CLAUDE.md"]:
+        for dir_path in search_dirs:
+            abs_file_name = os.path.join(dir_path, file_name)
+            if os.path.isfile(abs_file_name):
+                content = read_file(abs_file_name)
+                instructions.append(
+                    make_markdown_section(
+                        f"Instruction from `{abs_file_name}`", content
+                    )
+                )
+                break
+    return "\n".join(instructions)
+
+
+def _get_prompt_attr(
+    ctx: AnyContext,
+    attr: StrAttr | None,
+    render: bool,
+    default: str | None,
+) -> str:
+    """Generic helper to get a prompt attribute, prioritizing task-specific then default."""
+    value = get_attr(
+        ctx,
+        attr,
+        None,
+        auto_render=render,
+    )
+    if value is not None:
+        return value
+    return default or ""
+
+
 def _get_persona(
    ctx: AnyContext,
    persona_attr: StrAttr | None,
    render_persona: bool,
 ) -> str:
-
-
-        ctx,
-        persona_attr,
-        None,
-        auto_render=render_persona,
+    return _get_prompt_attr(
+        ctx, persona_attr, render_persona, llm_config.default_persona
    )
-    if persona is not None:
-        return persona
-    return llm_config.default_persona or ""
 
 
 def _get_base_system_prompt(
@@ -155,16 +191,9 @@ def _get_base_system_prompt(
    system_prompt_attr: StrAttr | None,
    render_system_prompt: bool,
 ) -> str:
-
-
-        ctx,
-        system_prompt_attr,
-        None,
-        auto_render=render_system_prompt,
+    return _get_prompt_attr(
+        ctx, system_prompt_attr, render_system_prompt, llm_config.default_system_prompt
    )
-    if system_prompt is not None:
-        return system_prompt
-    return llm_config.default_system_prompt or ""
 
 
 def _get_special_instruction_prompt(
@@ -172,16 +201,12 @@ def _get_special_instruction_prompt(
    special_instruction_prompt_attr: StrAttr | None,
    render_spcecial_instruction_prompt: bool,
 ) -> str:
-
-    special_instruction = get_attr(
+    return _get_prompt_attr(
        ctx,
        special_instruction_prompt_attr,
-
-
+        render_spcecial_instruction_prompt,
+        llm_config.default_special_instruction_prompt,
    )
-    if special_instruction is not None:
-        return special_instruction
-    return llm_config.default_special_instruction_prompt
 
 
 def _get_active_workflow_names(
@@ -229,7 +254,8 @@ def _get_workflow_prompt(
    )
 
 
-def _get_user_message_prompt(user_message: str) -> tuple[str, str]:
+def _get_user_message_prompt(user_message: str) -> str:
+    current_directory = os.getcwd()
    processed_user_message = user_message
    # Match “@” + any non-space/comma sequence that contains at least one “/”
    pattern = r"(?<!\w)@(?=[^,\s]*\/)([^,\?\!\s]+)"
@@ -247,19 +273,19 @@ def _get_user_message_prompt(user_message: str) -> tuple[str, str]:
            ref_type = "directory"
        if content != "":
            # Replace the @-reference in the user message with the placeholder
-
+            rel_path = os.path.relpath(resource_path, current_directory)
+            placeholder = f"`{rel_path}`"
            processed_user_message = processed_user_message.replace(
                f"@{ref}", placeholder, 1
            )
            apendix_list.append(
                make_markdown_section(
-                    f"
+                    f"{placeholder} {ref_type}",
                    "\n".join(content) if isinstance(content, list) else content,
                    as_code=True,
                )
            )
    apendixes = "\n".join(apendix_list)
-    current_directory = os.getcwd()
    iso_date = datetime.now(timezone.utc).astimezone().isoformat()
    modified_user_message = make_markdown_section(
        "User Request",
@@ -269,10 +295,14 @@ def _get_user_message_prompt(user_message: str) -> tuple[str, str]:
            f"- Current Time: {iso_date}",
            "---",
            processed_user_message,
+            make_markdown_section(
+                "📄 Apendixes",
+                apendixes,
+            ),
        ]
    ),
 )
-    return modified_user_message
+    return modified_user_message
 
 
 def get_user_message(
```
zrb/task/llm/workflow.py
CHANGED
```diff
@@ -3,9 +3,13 @@ import os
 from zrb.config.config import CFG
 from zrb.config.llm_context.config import llm_context_config
 from zrb.config.llm_context.workflow import LLMWorkflow
+from zrb.context.any_context import AnyContext
+from zrb.xcom.xcom import Xcom
 
+LLM_LOADED_WORKFLOW_XCOM_NAME = "_llm_loaded_workflow_name"
 
-def load_workflow(workflow_name: str | list[str]) -> str:
+
+def load_workflow(ctx: AnyContext, workflow_name: str | list[str]) -> str:
    """
    Loads and formats one or more workflow documents for LLM consumption.
 
@@ -36,9 +40,17 @@ def load_workflow(workflow_name: str | list[str]) -> str:
            ]
        )
    )
+    llm_loaded_workflow_xcom = get_llm_loaded_workflow_xcom(ctx)
+    llm_loaded_workflow_xcom.push(names)
    return "\n".join(contents)
 
 
+def get_llm_loaded_workflow_xcom(ctx: AnyContext) -> Xcom:
+    if LLM_LOADED_WORKFLOW_XCOM_NAME not in ctx.xcom:
+        ctx.xcom[LLM_LOADED_WORKFLOW_XCOM_NAME] = Xcom([])
+    return ctx.xcom[LLM_LOADED_WORKFLOW_XCOM_NAME]
+
+
 def get_available_workflows() -> dict[str, LLMWorkflow]:
    available_workflows = {
        workflow_name.strip().lower(): workflow
```
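`get_llm_loaded_workflow_xcom` is a get-or-create accessor: the Xcom channel is registered lazily under a reserved key on first access, so `load_workflow` can push loaded workflow names without worrying about initialization order. The same pattern in miniature, with a stand-in `Xcom` (zrb's real class lives in `zrb.xcom.xcom`; this sketch only mimics its shape):

```python
class Xcom(list):
    """Stand-in for zrb's Xcom: a list-like channel with a push method."""

    def push(self, value):
        self.append(value)


LOADED_WORKFLOW_KEY = "_llm_loaded_workflow_name"


def get_loaded_workflow_xcom(xcom: dict) -> Xcom:
    # Lazily create the channel on first access so callers never see a KeyError.
    if LOADED_WORKFLOW_KEY not in xcom:
        xcom[LOADED_WORKFLOW_KEY] = Xcom([])
    return xcom[LOADED_WORKFLOW_KEY]


ctx_xcom: dict = {}
get_loaded_workflow_xcom(ctx_xcom).push(["coding", "researching"])
print(ctx_xcom[LOADED_WORKFLOW_KEY])  # [['coding', 'researching']]
```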
zrb/task/llm_task.py
CHANGED
```diff
@@ -82,7 +82,7 @@ class LLMTask(BaseTask):
        render_system_prompt: bool = False,
        special_instruction_prompt: "Callable[[AnyContext], str | None] | str | None" = None,
        render_special_instruction_prompt: bool = False,
-        workflows:
+        workflows: "Callable[[AnyContext], list[str] | None] | list[str] | None" = None,
        render_workflows: bool = True,
        message: StrAttr | None = None,
        attachment: "UserContent | list[UserContent] | Callable[[AnyContext], UserContent | list[UserContent]] | None" = None,  # noqa
@@ -235,7 +235,58 @@ class LLMTask(BaseTask):
        self._yolo_mode = yolo_mode
 
    async def _exec_action(self, ctx: AnyContext) -> Any:
-        # Get dependent configurations
+        # 1. Get dependent configurations
+        (
+            model_settings,
+            model,
+            yolo_mode,
+            summarization_prompt,
+            user_message,
+            attachments,
+        ) = self._get_llm_config(ctx)
+
+        # 2. Prepare initial state (read history from previous session)
+        conversation_history = await read_conversation_history(
+            ctx=ctx,
+            conversation_history_reader=self._conversation_history_reader,
+            conversation_history_file_attr=self._conversation_history_file,
+            render_history_file=self._render_history_file,
+            conversation_history_attr=self._conversation_history,
+        )
+        inject_conversation_history_notes(conversation_history)
+        inject_subagent_conversation_history_into_ctx(ctx, conversation_history)
+
+        # 3. Get system prompt and user prompt
+        system_prompt, user_prompt = self._get_prompts(
+            ctx, user_message, conversation_history
+        )
+
+        # 4. Get the agent instance
+        ctx.log_info(f"SYSTEM PROMPT:\n{system_prompt}")
+        ctx.log_info(f"USER PROMPT:\n{user_prompt}")
+        agent = self._create_agent(
+            ctx,
+            model,
+            system_prompt,
+            model_settings,
+            yolo_mode,
+            summarization_prompt,
+        )
+
+        # 5. Run the agent iteration
+        result = await self._execute_agent(
+            ctx=ctx,
+            agent=agent,
+            user_prompt=user_prompt,
+            attachments=attachments,
+            conversation_history=conversation_history,
+        )
+
+        # 6. Save history and usage
+        await self._save_history_and_usage(ctx, conversation_history)
+        return result
+
+    def _get_llm_config(self, ctx: AnyContext):
        model_settings = get_model_settings(ctx, self._model_settings)
        model = get_model(
            ctx=ctx,
@@ -258,18 +309,22 @@ class LLMTask(BaseTask):
        )
        user_message = get_user_message(ctx, self._message, self._render_message)
        attachments = get_attachments(ctx, self._attachment)
-
-
-
-
-
-
-
+        return (
+            model_settings,
+            model,
+            yolo_mode,
+            summarization_prompt,
+            user_message,
+            attachments,
        )
-
-
-
-
+
+    def _get_prompts(
+        self,
+        ctx: AnyContext,
+        user_message: str,
+        conversation_history: ConversationHistory,
+    ):
+        return get_system_and_user_prompt(
            ctx=ctx,
            user_message=user_message,
            persona_attr=self._persona,
@@ -282,7 +337,16 @@ class LLMTask(BaseTask):
            render_workflows=self._render_workflows,
            conversation_history=conversation_history,
        )
-
+
+    def _create_agent(
+        self,
+        ctx: AnyContext,
+        model,
+        system_prompt,
+        model_settings,
+        yolo_mode,
+        summarization_prompt,
+    ):
        small_model = get_model(
            ctx=ctx,
            model_attr=self._small_model,
@@ -298,10 +362,7 @@ class LLMTask(BaseTask):
            self._history_summarization_token_threshold,
            self._render_history_summarization_token_threshold,
        )
-
-        ctx.log_info(f"SYSTEM PROMPT:\n{system_prompt}")
-        ctx.log_info(f"USER PROMPT:\n{user_prompt}")
-        agent = get_agent(
+        return get_agent(
            ctx=ctx,
            model=model,
            rate_limitter=self._rate_limitter,
@@ -319,15 +380,10 @@ class LLMTask(BaseTask):
            summarization_token_threshold=summarization_token_threshold,
            history_processors=[],  # TODO: make this a property
        )
-
-
-
-
-            user_prompt=user_prompt,
-            attachments=attachments,
-            conversation_history=conversation_history,
-        )
-        # 6. Write conversation history
+
+    async def _save_history_and_usage(
+        self, ctx: AnyContext, conversation_history: ConversationHistory
+    ):
        conversation_history.subagent_history = (
            extract_subagent_conversation_history_from_ctx(ctx)
        )
@@ -338,7 +394,6 @@ class LLMTask(BaseTask):
            conversation_history_file_attr=self._conversation_history_file,
            render_history_file=self._render_history_file,
        )
-        return result
 
    async def _execute_agent(
        self,
```
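The `llm_task.py` change is a classic extract-method refactor: the monolithic `_exec_action` becomes an orchestrator over six named steps (`_get_llm_config`, `_get_prompts`, `_create_agent`, `_execute_agent`, `_save_history_and_usage`). A toy model of the resulting control flow, with stubbed step bodies (everything below is illustrative, not zrb code):

```python
import asyncio
from typing import Any


class MiniTask:
    """Toy illustration of the _exec_action decomposition: each numbered
    step from the diff becomes a small, separately testable method."""

    async def exec_action(self, message: str) -> Any:
        config = self._get_config()                  # 1. gather configuration
        history = await self._read_history()         # 2. prepare initial state
        prompt = self._get_prompt(message, history)  # 3. build prompts
        agent = self._create_agent(config)           # 4. build the agent
        result = await self._execute(agent, prompt)  # 5. run it
        await self._save_history(history + [result]) # 6. persist history
        return result

    def _get_config(self) -> dict:
        return {"model": "stub-model"}

    async def _read_history(self) -> list[str]:
        return []

    def _get_prompt(self, message: str, history: list[str]) -> str:
        return f"history={history!r} message={message!r}"

    def _create_agent(self, config: dict):
        async def agent(prompt: str) -> str:
            return f"[{config['model']}] echo: {prompt}"
        return agent

    async def _execute(self, agent, prompt: str) -> str:
        return await agent(prompt)

    async def _save_history(self, history: list[str]) -> None:
        pass  # a real task would write this to disk


print(asyncio.run(MiniTask().exec_action("hello")))
```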
zrb/util/run.py
CHANGED
```diff
@@ -5,7 +5,7 @@ from typing import Any
 
 async def run_async(value: Any) -> Any:
    """
-    Run a value asynchronously, awaiting if it's awaitable or
+    Run a value asynchronously, awaiting if it's awaitable or returning it directly.
 
    Args:
        value (Any): The value to run. Can be awaitable or not.
@@ -14,7 +14,7 @@ async def run_async(value: Any) -> Any:
        Any: The result of the awaited value or the value itself if not awaitable.
    """
    if isinstance(value, asyncio.Task):
-        return value
+        return await value
    if inspect.isawaitable(value):
        return await value
-    return
+    return value
```