zrb 1.15.3__py3-none-any.whl → 2.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +118 -133
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +55 -1
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/chat.py +147 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/callback/callback.py +8 -1
- zrb/cmd/cmd_result.py +2 -1
- zrb/config/config.py +555 -169
- zrb/config/helper.py +84 -0
- zrb/config/web_auth_config.py +50 -35
- zrb/context/any_shared_context.py +20 -3
- zrb/context/context.py +39 -5
- zrb/context/print_fn.py +13 -0
- zrb/context/shared_context.py +17 -8
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +41 -1
- zrb/input/text_input.py +7 -24
- zrb/llm/agent/__init__.py +9 -0
- zrb/llm/agent/agent.py +215 -0
- zrb/llm/agent/summarizer.py +20 -0
- zrb/llm/app/__init__.py +10 -0
- zrb/llm/app/completion.py +281 -0
- zrb/llm/app/confirmation/allow_tool.py +66 -0
- zrb/llm/app/confirmation/handler.py +178 -0
- zrb/llm/app/confirmation/replace_confirmation.py +77 -0
- zrb/llm/app/keybinding.py +34 -0
- zrb/llm/app/layout.py +117 -0
- zrb/llm/app/lexer.py +155 -0
- zrb/llm/app/redirection.py +28 -0
- zrb/llm/app/style.py +16 -0
- zrb/llm/app/ui.py +733 -0
- zrb/llm/config/__init__.py +4 -0
- zrb/llm/config/config.py +122 -0
- zrb/llm/config/limiter.py +247 -0
- zrb/llm/history_manager/__init__.py +4 -0
- zrb/llm/history_manager/any_history_manager.py +23 -0
- zrb/llm/history_manager/file_history_manager.py +91 -0
- zrb/llm/history_processor/summarizer.py +108 -0
- zrb/llm/note/__init__.py +3 -0
- zrb/llm/note/manager.py +122 -0
- zrb/llm/prompt/__init__.py +29 -0
- zrb/llm/prompt/claude_compatibility.py +92 -0
- zrb/llm/prompt/compose.py +55 -0
- zrb/llm/prompt/default.py +51 -0
- zrb/llm/prompt/markdown/file_extractor.md +112 -0
- zrb/llm/prompt/markdown/mandate.md +23 -0
- zrb/llm/prompt/markdown/persona.md +3 -0
- zrb/llm/prompt/markdown/repo_extractor.md +112 -0
- zrb/llm/prompt/markdown/repo_summarizer.md +29 -0
- zrb/llm/prompt/markdown/summarizer.md +21 -0
- zrb/llm/prompt/note.py +41 -0
- zrb/llm/prompt/system_context.py +46 -0
- zrb/llm/prompt/zrb.py +41 -0
- zrb/llm/skill/__init__.py +3 -0
- zrb/llm/skill/manager.py +86 -0
- zrb/llm/task/__init__.py +4 -0
- zrb/llm/task/llm_chat_task.py +316 -0
- zrb/llm/task/llm_task.py +245 -0
- zrb/llm/tool/__init__.py +39 -0
- zrb/llm/tool/bash.py +75 -0
- zrb/llm/tool/code.py +266 -0
- zrb/llm/tool/file.py +419 -0
- zrb/llm/tool/note.py +70 -0
- zrb/{builtin/llm → llm}/tool/rag.py +33 -37
- zrb/llm/tool/search/brave.py +53 -0
- zrb/llm/tool/search/searxng.py +47 -0
- zrb/llm/tool/search/serpapi.py +47 -0
- zrb/llm/tool/skill.py +19 -0
- zrb/llm/tool/sub_agent.py +70 -0
- zrb/llm/tool/web.py +97 -0
- zrb/llm/tool/zrb_task.py +66 -0
- zrb/llm/util/attachment.py +101 -0
- zrb/llm/util/prompt.py +104 -0
- zrb/llm/util/stream_response.py +178 -0
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -9
- zrb/session/session.py +38 -17
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +42 -22
- zrb/task/base/execution.py +67 -55
- zrb/task/base/lifecycle.py +14 -7
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +113 -50
- zrb/task/base_trigger.py +16 -6
- zrb/task/cmd_task.py +6 -0
- zrb/task/http_check.py +11 -5
- zrb/task/make_task.py +5 -3
- zrb/task/rsync_task.py +30 -10
- zrb/task/scaffolder.py +7 -4
- zrb/task/scheduler.py +7 -4
- zrb/task/tcp_check.py +6 -4
- zrb/util/ascii_art/art/bee.txt +17 -0
- zrb/util/ascii_art/art/cat.txt +9 -0
- zrb/util/ascii_art/art/ghost.txt +16 -0
- zrb/util/ascii_art/art/panda.txt +17 -0
- zrb/util/ascii_art/art/rose.txt +14 -0
- zrb/util/ascii_art/art/unicorn.txt +15 -0
- zrb/util/ascii_art/banner.py +92 -0
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +32 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/cmd/command.py +33 -10
- zrb/util/file.py +61 -33
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/match.py +78 -0
- zrb/util/run.py +3 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/METADATA +41 -27
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/RECORD +129 -131
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/WHEEL +1 -1
- zrb/attr/__init__.py +0 -0
- zrb/builtin/llm/chat_session.py +0 -311
- zrb/builtin/llm/history.py +0 -71
- zrb/builtin/llm/input.py +0 -27
- zrb/builtin/llm/llm_ask.py +0 -187
- zrb/builtin/llm/previous-session.js +0 -21
- zrb/builtin/llm/tool/__init__.py +0 -0
- zrb/builtin/llm/tool/api.py +0 -71
- zrb/builtin/llm/tool/cli.py +0 -38
- zrb/builtin/llm/tool/code.py +0 -254
- zrb/builtin/llm/tool/file.py +0 -626
- zrb/builtin/llm/tool/sub_agent.py +0 -137
- zrb/builtin/llm/tool/web.py +0 -195
- zrb/builtin/project/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/__init__.py +0 -0
- zrb/builtin/project/create/__init__.py +0 -0
- zrb/builtin/shell/__init__.py +0 -0
- zrb/builtin/shell/autocomplete/__init__.py +0 -0
- zrb/callback/__init__.py +0 -0
- zrb/cmd/__init__.py +0 -0
- zrb/config/default_prompt/file_extractor_system_prompt.md +0 -12
- zrb/config/default_prompt/interactive_system_prompt.md +0 -35
- zrb/config/default_prompt/persona.md +0 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +0 -112
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +0 -10
- zrb/config/default_prompt/summarization_prompt.md +0 -16
- zrb/config/default_prompt/system_prompt.md +0 -32
- zrb/config/llm_config.py +0 -243
- zrb/config/llm_context/config.py +0 -129
- zrb/config/llm_context/config_parser.py +0 -46
- zrb/config/llm_rate_limitter.py +0 -137
- zrb/content_transformer/__init__.py +0 -0
- zrb/context/__init__.py +0 -0
- zrb/dot_dict/__init__.py +0 -0
- zrb/env/__init__.py +0 -0
- zrb/group/__init__.py +0 -0
- zrb/input/__init__.py +0 -0
- zrb/runner/__init__.py +0 -0
- zrb/runner/web_route/__init__.py +0 -0
- zrb/runner/web_route/home_page/__init__.py +0 -0
- zrb/session/__init__.py +0 -0
- zrb/session_state_log/__init__.py +0 -0
- zrb/session_state_logger/__init__.py +0 -0
- zrb/task/__init__.py +0 -0
- zrb/task/base/__init__.py +0 -0
- zrb/task/llm/__init__.py +0 -0
- zrb/task/llm/agent.py +0 -243
- zrb/task/llm/config.py +0 -103
- zrb/task/llm/conversation_history.py +0 -128
- zrb/task/llm/conversation_history_model.py +0 -242
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- zrb/task/llm/error.py +0 -95
- zrb/task/llm/history_summarization.py +0 -216
- zrb/task/llm/print_node.py +0 -101
- zrb/task/llm/prompt.py +0 -325
- zrb/task/llm/tool_wrapper.py +0 -220
- zrb/task/llm/typing.py +0 -3
- zrb/task/llm_task.py +0 -341
- zrb/task_status/__init__.py +0 -0
- zrb/util/__init__.py +0 -0
- zrb/util/cli/__init__.py +0 -0
- zrb/util/cmd/__init__.py +0 -0
- zrb/util/codemod/__init__.py +0 -0
- zrb/util/string/__init__.py +0 -0
- zrb/xcom/__init__.py +0 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
3
|
+
|
|
4
|
+
from zrb.context.any_context import AnyContext
|
|
5
|
+
from zrb.util.cli.style import stylize_faint
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from pydantic_ai import AgentStreamEvent
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_event_handler(
    print_event: Callable[..., None],
    indent_level: int = 1,
    show_tool_call_detail: bool = False,
    show_tool_result: bool = False,
):
    """Build an async callback that renders pydantic-ai stream events.

    The returned handler keeps state between calls: the spinner position,
    whether the previous event was a tool-call delta, and the prefix used to
    start new blocks. That state is shared by every event passed to the same
    handler.

    Args:
        print_event: Sink that receives every rendered text fragment.
        indent_level: Number of two-space units used to indent rendered
            blocks and their continuation lines.
        show_tool_call_detail: If True, stream tool-argument deltas
            verbatim; otherwise show a spinner while arguments arrive.
        show_tool_result: If True, print each tool result's content;
            otherwise only print an "Executed" marker.

    Returns:
        An async function taking a single ``AgentStreamEvent``.
    """
    from pydantic_ai import (
        AgentRunResultEvent,
        FinalResultEvent,
        FunctionToolCallEvent,
        FunctionToolResultEvent,
        PartDeltaEvent,
        PartStartEvent,
        TextPartDelta,
        ThinkingPartDelta,
        ToolCallPart,
        ToolCallPartDelta,
    )

    indentation = indent_level * 2 * " "
    progress_char_list = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
    progress_char_index = 0
    was_tool_call_delta = False
    # Blocks start on the current line until a FinalResultEvent arrives;
    # after that every new block is preceded by a newline.
    event_prefix = indentation

    def fprint(content: str, preserve_leading_newline: bool = False):
        # Indent continuation lines so multi-line text stays aligned. With
        # preserve_leading_newline, a leading "\n" is emitted bare because the
        # block prefix that follows already carries the indentation.
        if preserve_leading_newline and content.startswith("\n"):
            return print_event("\n" + content[1:].replace("\n", f"\n{indentation} "))
        return print_event(content.replace("\n", f"\n{indentation} "))

    async def handle_event(event: "AgentStreamEvent"):
        nonlocal progress_char_index, was_tool_call_delta, event_prefix
        if isinstance(event, PartStartEvent):
            # Skip ToolCallPart start, we handle it in Deltas/CallEvent
            if isinstance(event.part, ToolCallPart):
                return
            content = _get_event_part_content(event)
            # Use preserve_leading_newline=True because event_prefix contains the correctly indented newline
            fprint(f"{event_prefix}🧠 {content}", preserve_leading_newline=True)
            was_tool_call_delta = False
        elif isinstance(event, PartDeltaEvent):
            if isinstance(event.delta, TextPartDelta):
                # Standard fprint for deltas to ensure wrapping indentation
                fprint(f"{event.delta.content_delta}")
                was_tool_call_delta = False
            elif isinstance(event.delta, ThinkingPartDelta):
                # content_delta may be None (e.g. a signature-only delta);
                # formatting it would print the literal text "None".
                if event.delta.content_delta is not None:
                    fprint(f"{event.delta.content_delta}")
                was_tool_call_delta = False
            elif isinstance(event.delta, ToolCallPartDelta):
                if show_tool_call_detail:
                    # args_delta may be None when a delta only carries the tool
                    # name or call id; skip it instead of printing "None".
                    if event.delta.args_delta is not None:
                        fprint(f"{event.delta.args_delta}")
                else:
                    progress_char = progress_char_list[progress_char_index]
                    if not was_tool_call_delta:
                        # Print newline for tool param spinner
                        fprint("\n")
                    # Split \r to avoid UI._append_to_output stripping the ANSI start code along with the line
                    print_event("\r")
                    print_event(
                        f"{indentation}🔄 Prepare tool parameters {progress_char}"
                    )
                    # Advance the spinner, wrapping back to the first frame.
                    progress_char_index = (progress_char_index + 1) % len(
                        progress_char_list
                    )
                    was_tool_call_delta = True
        elif isinstance(event, FunctionToolCallEvent):
            args = _get_truncated_event_part_args(event)
            # Use preserve_leading_newline=True for the block header
            fprint(
                f"{event_prefix}🧰 {event.part.tool_call_id} | {event.part.tool_name} {args}",
                preserve_leading_newline=True,
            )
            was_tool_call_delta = False
        elif isinstance(event, FunctionToolResultEvent):
            if show_tool_result:
                fprint(
                    f"{event_prefix}🔠 {event.tool_call_id} | Return {event.result.content}",
                    preserve_leading_newline=True,
                )
            else:
                fprint(
                    f"{event_prefix}🔠 {event.tool_call_id} Executed",
                    preserve_leading_newline=True,
                )
            was_tool_call_delta = False
        elif isinstance(event, AgentRunResultEvent):
            usage = event.result.usage()
            usage_msg = " ".join(
                [
                    "💸",
                    f"(Requests: {usage.requests} |",
                    f"Tool Calls: {usage.tool_calls} |",
                    f"Total: {usage.total_tokens})",
                    f"Input: {usage.input_tokens} |",
                    f"Audio Input: {usage.input_audio_tokens} |",
                    f"Output: {usage.output_tokens} |",
                    f"Audio Output: {usage.output_audio_tokens} |",
                    f"Cache Read: {usage.cache_read_tokens} |",
                    f"Cache Write: {usage.cache_write_tokens} |",
                    f"Details: {usage.details}",
                ]
            )
            fprint(
                f"{event_prefix}{stylize_faint(usage_msg)}\n",
                preserve_leading_newline=True,
            )
            was_tool_call_delta = False
        elif isinstance(event, FinalResultEvent):
            was_tool_call_delta = False
            event_prefix = f"\n{indentation}"

    return handle_event
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def create_faint_printer(ctx: AnyContext):
    """Return a print-like function that writes faint-styled text via ``ctx``.

    The returned function formats each positional value, joins the pieces
    with single spaces, applies ``stylize_faint``, and passes the result to
    ``ctx.print`` with ``end=""`` and ``plain=True``.
    """

    def faint_print(*values: object):
        # format(v) is exactly what the f-string f"{v}" evaluates to.
        joined_text = " ".join(map(format, values))
        ctx.print(stylize_faint(joined_text), end="", plain=True)

    return faint_print
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _get_truncated_event_part_args(event: "AgentStreamEvent") -> Any:
    """Return a display-friendly, truncated form of ``event.part.args``.

    Absent or empty arguments normalise to ``{}``. This covers a missing
    ``part`` or ``args`` attribute, ``None``, and blank, ``"null"`` or
    ``"{}"`` strings. A dict, or a string that decodes to a JSON object,
    becomes a new dict whose values are passed through ``_truncate_arg``.
    Any other string is truncated as a whole. Other values are returned
    unchanged.
    """
    # Handle empty arguments across different providers
    part = getattr(event, "part", None)
    args = getattr(part, "args", None)
    if args is None:
        return {}
    if isinstance(args, str):
        # Some providers might send "", "null" or "{}" (possibly padded) as a string
        if args.strip() in ("", "null", "{}"):
            return {}
        try:
            obj = json.loads(args)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return _truncate_kwargs(obj)
        # Not a JSON object (malformed, partial, or a non-object JSON value):
        # truncate the raw text so an oversized argument cannot flood output.
        return _truncate_arg(args)
    # Handle dummy property if present (from our schema sanitization)
    if isinstance(args, dict):
        return _truncate_kwargs(args)
    return args
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict mapping each key of ``kwargs`` to ``_truncate_arg(value)``."""
    shortened: dict[str, Any] = {}
    for name, value in kwargs.items():
        shortened[name] = _truncate_arg(value)
    return shortened
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _truncate_arg(arg: str, length: int = 30) -> str:
|
|
166
|
+
if isinstance(arg, str) and len(arg) > length:
|
|
167
|
+
return f"{arg[:length-4]} ..."
|
|
168
|
+
return arg
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _get_event_part_content(event: "AgentStreamEvent") -> str:
|
|
172
|
+
if not hasattr(event, "part"):
|
|
173
|
+
return ""
|
|
174
|
+
part = getattr(event, "part")
|
|
175
|
+
if hasattr(part, "content"):
|
|
176
|
+
return getattr(part, "content")
|
|
177
|
+
# For parts without content (like ToolCallPart, though we skip it now), return empty or simple repr
|
|
178
|
+
return ""
|
zrb/runner/cli.py
CHANGED
|
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
|
|
|
7
7
|
from zrb.context.shared_context import SharedContext
|
|
8
8
|
from zrb.group.any_group import AnyGroup
|
|
9
9
|
from zrb.group.group import Group
|
|
10
|
-
from zrb.runner.common_util import
|
|
10
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
11
11
|
from zrb.session.session import Session
|
|
12
12
|
from zrb.session_state_logger.session_state_logger_factory import session_state_logger
|
|
13
13
|
from zrb.task.any_task import AnyTask
|
|
@@ -38,23 +38,25 @@ class Cli(Group):
|
|
|
38
38
|
def banner(self) -> str:
|
|
39
39
|
return CFG.BANNER
|
|
40
40
|
|
|
41
|
-
def run(self,
|
|
42
|
-
|
|
43
|
-
node, node_path,
|
|
41
|
+
def run(self, str_args: list[str] = []):
|
|
42
|
+
str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
|
|
43
|
+
node, node_path, str_args = extract_node_from_args(self, str_args)
|
|
44
44
|
if isinstance(node, AnyGroup):
|
|
45
45
|
self._show_group_info(node)
|
|
46
46
|
return
|
|
47
|
-
if "h" in
|
|
47
|
+
if "h" in str_kwargs or "help" in str_kwargs:
|
|
48
48
|
self._show_task_info(node)
|
|
49
49
|
return
|
|
50
|
-
|
|
50
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
51
|
+
task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
|
|
52
|
+
)
|
|
51
53
|
try:
|
|
52
|
-
result = self._run_task(node,
|
|
54
|
+
result = self._run_task(node, str_args, task_str_kwargs)
|
|
53
55
|
if result is not None:
|
|
54
56
|
print(result)
|
|
55
57
|
return result
|
|
56
58
|
finally:
|
|
57
|
-
run_command = self._get_run_command(node_path,
|
|
59
|
+
run_command = self._get_run_command(node_path, task_str_kwargs)
|
|
58
60
|
self._print_run_command(run_command)
|
|
59
61
|
|
|
60
62
|
def _print_run_command(self, run_command: str):
|
|
@@ -64,11 +66,14 @@ class Cli(Group):
|
|
|
64
66
|
file=sys.stderr,
|
|
65
67
|
)
|
|
66
68
|
|
|
67
|
-
def _get_run_command(
|
|
69
|
+
def _get_run_command(
|
|
70
|
+
self, node_path: list[str], task_str_kwargs: dict[str, str]
|
|
71
|
+
) -> str:
|
|
68
72
|
parts = [self.name] + node_path
|
|
69
|
-
if len(
|
|
73
|
+
if len(task_str_kwargs) > 0:
|
|
70
74
|
parts += [
|
|
71
|
-
self._get_run_command_param(key, val)
|
|
75
|
+
self._get_run_command_param(key, val)
|
|
76
|
+
for key, val in task_str_kwargs.items()
|
|
72
77
|
]
|
|
73
78
|
return " ".join(parts)
|
|
74
79
|
|
|
@@ -81,13 +86,9 @@ class Cli(Group):
|
|
|
81
86
|
self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
|
|
82
87
|
) -> tuple[Any]:
|
|
83
88
|
shared_ctx = SharedContext(args=args)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
shared_ctx, run_kwargs[task_input.name]
|
|
88
|
-
)
|
|
89
|
-
continue
|
|
90
|
-
return task.run(Session(shared_ctx=shared_ctx, root_group=self))
|
|
89
|
+
return task.run(
|
|
90
|
+
Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
|
|
91
|
+
)
|
|
91
92
|
|
|
92
93
|
def _show_task_info(self, task: AnyTask):
|
|
93
94
|
description = task.description
|
|
@@ -150,11 +151,11 @@ class Cli(Group):
|
|
|
150
151
|
kwargs[key] = args[i + 1]
|
|
151
152
|
i += 1 # Skip the next argument as it's a value
|
|
152
153
|
else:
|
|
153
|
-
kwargs[key] =
|
|
154
|
+
kwargs[key] = "true"
|
|
154
155
|
elif arg.startswith("-"):
|
|
155
156
|
# Handle short flags like -t or -n
|
|
156
157
|
key = arg[1:]
|
|
157
|
-
kwargs[key] =
|
|
158
|
+
kwargs[key] = "true"
|
|
158
159
|
else:
|
|
159
160
|
# Anything else is considered a positional argument
|
|
160
161
|
residual_args.append(arg)
|
zrb/runner/common_util.py
CHANGED
|
@@ -1,31 +1,36 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
1
|
from zrb.context.shared_context import SharedContext
|
|
4
2
|
from zrb.task.any_task import AnyTask
|
|
5
3
|
|
|
6
4
|
|
|
7
|
-
def
|
|
8
|
-
task: AnyTask,
|
|
5
|
+
def get_task_str_kwargs(
|
|
6
|
+
task: AnyTask, str_args: list[str], str_kwargs: dict[str, str], cli_mode: bool
|
|
9
7
|
) -> dict[str, str]:
|
|
10
8
|
arg_index = 0
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
shared_ctx = SharedContext(args=args)
|
|
9
|
+
dummmy_shared_ctx = SharedContext()
|
|
10
|
+
task_str_kwargs = {}
|
|
14
11
|
for task_input in task.inputs:
|
|
12
|
+
task_name = task_input.name
|
|
15
13
|
if task_input.name in str_kwargs:
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
14
|
+
task_str_kwargs[task_input.name] = str_kwargs[task_name]
|
|
15
|
+
# Update dummy shared context for next input default value
|
|
16
|
+
task_input.update_shared_context(
|
|
17
|
+
dummmy_shared_ctx, str_value=str_kwargs[task_name]
|
|
18
|
+
)
|
|
19
|
+
elif arg_index < len(str_args) and task_input.allow_positional_parsing:
|
|
20
|
+
task_str_kwargs[task_name] = str_args[arg_index]
|
|
21
|
+
# Update dummy shared context for next input default value
|
|
22
|
+
task_input.update_shared_context(
|
|
23
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
24
|
+
)
|
|
22
25
|
arg_index += 1
|
|
23
26
|
else:
|
|
24
27
|
if cli_mode and task_input.always_prompt:
|
|
25
|
-
str_value = task_input.prompt_cli_str(
|
|
28
|
+
str_value = task_input.prompt_cli_str(dummmy_shared_ctx)
|
|
26
29
|
else:
|
|
27
|
-
str_value = task_input.get_default_str(
|
|
28
|
-
|
|
29
|
-
# Update shared context for next input default value
|
|
30
|
-
task_input.update_shared_context(
|
|
31
|
-
|
|
30
|
+
str_value = task_input.get_default_str(dummmy_shared_ctx)
|
|
31
|
+
task_str_kwargs[task_name] = str_value
|
|
32
|
+
# Update dummy shared context for next input default value
|
|
33
|
+
task_input.update_shared_context(
|
|
34
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
35
|
+
)
|
|
36
|
+
return task_str_kwargs
|
|
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
|
|
3
3
|
|
|
4
4
|
from zrb.config.web_auth_config import WebAuthConfig
|
|
5
5
|
from zrb.group.any_group import AnyGroup
|
|
6
|
-
from zrb.runner.common_util import
|
|
6
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
7
7
|
from zrb.runner.web_util.user import get_user_from_request
|
|
8
8
|
from zrb.task.any_task import AnyTask
|
|
9
9
|
from zrb.util.group import NodeNotFoundError, extract_node_from_args
|
|
@@ -39,9 +39,9 @@ def serve_task_input_api(
|
|
|
39
39
|
if isinstance(task, AnyTask):
|
|
40
40
|
if not user.can_access_task(task):
|
|
41
41
|
return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
task=task,
|
|
42
|
+
str_kwargs = json.loads(query)
|
|
43
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
44
|
+
task=task, str_args=[], str_kwargs=str_kwargs, cli_mode=False
|
|
45
45
|
)
|
|
46
|
-
return
|
|
46
|
+
return task_str_kwargs
|
|
47
47
|
return JSONResponse(content={"detail": "Not found"}, status_code=404)
|
zrb/runner/web_util/user.py
CHANGED
|
@@ -19,7 +19,7 @@ def get_user_by_credentials(
|
|
|
19
19
|
|
|
20
20
|
async def get_user_from_request(
|
|
21
21
|
web_auth_config: WebAuthConfig, request: "Request"
|
|
22
|
-
) -> User
|
|
22
|
+
) -> User:
|
|
23
23
|
from fastapi.security import OAuth2PasswordBearer
|
|
24
24
|
|
|
25
25
|
if not web_auth_config.enable_auth:
|
|
@@ -45,7 +45,11 @@ def _get_user_from_cookie(
|
|
|
45
45
|
return None
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
def _get_user_from_token(
|
|
48
|
+
def _get_user_from_token(
|
|
49
|
+
web_auth_config: WebAuthConfig, token: str | None
|
|
50
|
+
) -> User | None:
|
|
51
|
+
if token is None:
|
|
52
|
+
return None
|
|
49
53
|
try:
|
|
50
54
|
from jose import jwt
|
|
51
55
|
|
|
@@ -54,7 +58,7 @@ def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | N
|
|
|
54
58
|
web_auth_config.secret_key,
|
|
55
59
|
options={"require_sub": True, "require_exp": True},
|
|
56
60
|
)
|
|
57
|
-
username: str = payload.get("sub")
|
|
61
|
+
username: str | None = payload.get("sub")
|
|
58
62
|
if username is None:
|
|
59
63
|
return None
|
|
60
64
|
user = web_auth_config.find_user_by_username(username)
|
zrb/session/any_session.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
from abc import ABC, abstractmethod
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
|
|
5
6
|
|
|
@@ -14,9 +15,6 @@ if TYPE_CHECKING:
|
|
|
14
15
|
from zrb.task.any_task import AnyTask
|
|
15
16
|
|
|
16
17
|
|
|
17
|
-
TAnySession = TypeVar("TAnySession", bound="AnySession")
|
|
18
|
-
|
|
19
|
-
|
|
20
18
|
class AnySession(ABC):
|
|
21
19
|
"""Abstract base class for managing task execution and context in a session.
|
|
22
20
|
|
|
@@ -62,12 +60,13 @@ class AnySession(ABC):
|
|
|
62
60
|
|
|
63
61
|
@property
|
|
64
62
|
@abstractmethod
|
|
65
|
-
def parent(self) ->
|
|
63
|
+
def parent(self) -> "AnySession | None":
|
|
66
64
|
"""Parent session"""
|
|
67
65
|
pass
|
|
68
66
|
|
|
67
|
+
@property
|
|
69
68
|
@abstractmethod
|
|
70
|
-
def task_path(self) -> str:
|
|
69
|
+
def task_path(self) -> list[str]:
|
|
71
70
|
"""Main task's path"""
|
|
72
71
|
pass
|
|
73
72
|
|
|
@@ -105,7 +104,9 @@ class AnySession(ABC):
|
|
|
105
104
|
pass
|
|
106
105
|
|
|
107
106
|
@abstractmethod
|
|
108
|
-
def defer_monitoring(
|
|
107
|
+
def defer_monitoring(
|
|
108
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
109
|
+
):
|
|
109
110
|
"""Defers the execution of a task's monitoring coroutine for later processing.
|
|
110
111
|
|
|
111
112
|
Args:
|
|
@@ -115,7 +116,9 @@ class AnySession(ABC):
|
|
|
115
116
|
pass
|
|
116
117
|
|
|
117
118
|
@abstractmethod
|
|
118
|
-
def defer_action(
|
|
119
|
+
def defer_action(
|
|
120
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
121
|
+
):
|
|
119
122
|
"""Defers the execution of a task's coroutine for later processing.
|
|
120
123
|
|
|
121
124
|
Args:
|
|
@@ -125,7 +128,7 @@ class AnySession(ABC):
|
|
|
125
128
|
pass
|
|
126
129
|
|
|
127
130
|
@abstractmethod
|
|
128
|
-
def defer_coro(self, coro: Coroutine):
|
|
131
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
129
132
|
"""Defers the execution of a coroutine for later processing.
|
|
130
133
|
|
|
131
134
|
Args:
|
|
@@ -185,7 +188,7 @@ class AnySession(ABC):
|
|
|
185
188
|
pass
|
|
186
189
|
|
|
187
190
|
@abstractmethod
|
|
188
|
-
def is_allowed_to_run(self, task: "AnyTask"):
|
|
191
|
+
def is_allowed_to_run(self, task: "AnyTask") -> bool:
|
|
189
192
|
"""Determines if the specified task is allowed to run based on its current state.
|
|
190
193
|
|
|
191
194
|
Args:
|
zrb/session/session.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import asyncio
|
|
2
4
|
from typing import TYPE_CHECKING, Any, Coroutine
|
|
3
5
|
|
|
@@ -48,10 +50,10 @@ class Session(AnySession):
|
|
|
48
50
|
self._context: dict[AnyTask, Context] = {}
|
|
49
51
|
self._shared_ctx = shared_ctx
|
|
50
52
|
self._shared_ctx.set_session(self)
|
|
51
|
-
self._parent = parent
|
|
52
|
-
self._action_coros: dict[AnyTask, asyncio.Task] = {}
|
|
53
|
-
self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
|
|
54
|
-
self._coros: list[asyncio.Task] = []
|
|
53
|
+
self._parent: AnySession | None = parent
|
|
54
|
+
self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
55
|
+
self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
56
|
+
self._coros: list[asyncio.Task[Any]] = []
|
|
55
57
|
self._colors = [
|
|
56
58
|
GREEN,
|
|
57
59
|
YELLOW,
|
|
@@ -114,11 +116,13 @@ class Session(AnySession):
|
|
|
114
116
|
return self._parent
|
|
115
117
|
|
|
116
118
|
@property
|
|
117
|
-
def task_path(self) -> str:
|
|
119
|
+
def task_path(self) -> list[str]:
|
|
118
120
|
return self._main_task_path
|
|
119
121
|
|
|
120
122
|
@property
|
|
121
123
|
def final_result(self) -> Any:
|
|
124
|
+
if self._main_task is None:
|
|
125
|
+
return None
|
|
122
126
|
xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
|
|
123
127
|
try:
|
|
124
128
|
return xcom.peek()
|
|
@@ -134,7 +138,11 @@ class Session(AnySession):
|
|
|
134
138
|
def set_main_task(self, main_task: AnyTask):
|
|
135
139
|
self.register_task(main_task)
|
|
136
140
|
self._main_task = main_task
|
|
137
|
-
main_task_path =
|
|
141
|
+
main_task_path = (
|
|
142
|
+
None
|
|
143
|
+
if self._root_group is None
|
|
144
|
+
else get_node_path(self._root_group, main_task)
|
|
145
|
+
)
|
|
138
146
|
self._main_task_path = [] if main_task_path is None else main_task_path
|
|
139
147
|
|
|
140
148
|
def as_state_log(self) -> "SessionStateLog":
|
|
@@ -171,7 +179,7 @@ class Session(AnySession):
|
|
|
171
179
|
return SessionStateLog(
|
|
172
180
|
name=self.name,
|
|
173
181
|
start_time=log_start_time,
|
|
174
|
-
main_task_name=self._main_task.name,
|
|
182
|
+
main_task_name="" if self._main_task is None else self._main_task.name,
|
|
175
183
|
path=self.task_path,
|
|
176
184
|
final_result=(
|
|
177
185
|
remove_style(f"{self.final_result}")
|
|
@@ -188,16 +196,29 @@ class Session(AnySession):
|
|
|
188
196
|
self._register_single_task(task)
|
|
189
197
|
return self._context[task]
|
|
190
198
|
|
|
191
|
-
def defer_monitoring(
|
|
199
|
+
def defer_monitoring(
|
|
200
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
201
|
+
):
|
|
192
202
|
self._register_single_task(task)
|
|
193
|
-
|
|
203
|
+
if isinstance(coro, asyncio.Task):
|
|
204
|
+
self._monitoring_coros[task] = coro
|
|
205
|
+
else:
|
|
206
|
+
self._monitoring_coros[task] = asyncio.create_task(coro)
|
|
194
207
|
|
|
195
|
-
def defer_action(
|
|
208
|
+
def defer_action(
|
|
209
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
210
|
+
):
|
|
196
211
|
self._register_single_task(task)
|
|
197
|
-
|
|
212
|
+
if isinstance(coro, asyncio.Task):
|
|
213
|
+
self._action_coros[task] = coro
|
|
214
|
+
else:
|
|
215
|
+
self._action_coros[task] = asyncio.create_task(coro)
|
|
198
216
|
|
|
199
|
-
def defer_coro(self, coro: Coroutine):
|
|
200
|
-
|
|
217
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
218
|
+
if isinstance(coro, asyncio.Task):
|
|
219
|
+
self._coros.append(coro)
|
|
220
|
+
else:
|
|
221
|
+
self._coros.append(asyncio.create_task(coro))
|
|
201
222
|
self._coros = [
|
|
202
223
|
existing_coro for existing_coro in self._coros if not existing_coro.done()
|
|
203
224
|
]
|
|
@@ -246,15 +267,15 @@ class Session(AnySession):
|
|
|
246
267
|
|
|
247
268
|
def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
|
|
248
269
|
self._register_single_task(task)
|
|
249
|
-
return self._downstreams.get(task)
|
|
270
|
+
return self._downstreams.get(task, [])
|
|
250
271
|
|
|
251
272
|
def get_task_status(self, task: AnyTask) -> TaskStatus:
|
|
252
273
|
self._register_single_task(task)
|
|
253
274
|
return self._task_status[task]
|
|
254
275
|
|
|
255
276
|
def _register_single_task(self, task: AnyTask):
|
|
256
|
-
if task.name not in self._shared_ctx.
|
|
257
|
-
self._shared_ctx.
|
|
277
|
+
if task.name not in self._shared_ctx.xcom:
|
|
278
|
+
self._shared_ctx.xcom[task.name] = Xcom([])
|
|
258
279
|
if task not in self._context:
|
|
259
280
|
self._context[task] = Context(
|
|
260
281
|
shared_ctx=self._shared_ctx,
|
|
@@ -278,7 +299,7 @@ class Session(AnySession):
|
|
|
278
299
|
self._color_index = 0
|
|
279
300
|
return chosen
|
|
280
301
|
|
|
281
|
-
def _get_icon(self, task: AnyTask) ->
|
|
302
|
+
def _get_icon(self, task: AnyTask) -> str:
|
|
282
303
|
if task.icon is not None:
|
|
283
304
|
return task.icon
|
|
284
305
|
chosen = self._icons[self._icon_index]
|
zrb/task/any_task.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, Any
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
5
5
|
|
|
6
6
|
from zrb.env.any_env import AnyEnv
|
|
7
7
|
from zrb.input.any_input import AnyInput
|
|
@@ -36,6 +36,14 @@ class AnyTask(ABC):
|
|
|
36
36
|
the actual implementation for these abstract members.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
|
|
45
|
+
pass
|
|
46
|
+
|
|
39
47
|
@property
|
|
40
48
|
@abstractmethod
|
|
41
49
|
def name(self) -> str:
|
|
@@ -148,13 +156,17 @@ class AnyTask(ABC):
|
|
|
148
156
|
|
|
149
157
|
@abstractmethod
|
|
150
158
|
def run(
|
|
151
|
-
self,
|
|
159
|
+
self,
|
|
160
|
+
session: "AnySession | None" = None,
|
|
161
|
+
str_kwargs: dict[str, str] | None = None,
|
|
162
|
+
kwargs: dict[str, Any] | None = None,
|
|
152
163
|
) -> Any:
|
|
153
164
|
"""Runs the task synchronously.
|
|
154
165
|
|
|
155
166
|
Args:
|
|
156
167
|
session (AnySession): The shared session.
|
|
157
168
|
str_kwargs(dict[str, str]): The input string values.
|
|
169
|
+
kwargs(dict[str, Any]): The input values.
|
|
158
170
|
|
|
159
171
|
Returns:
|
|
160
172
|
Any: The result of the task execution.
|
|
@@ -163,13 +175,17 @@ class AnyTask(ABC):
|
|
|
163
175
|
|
|
164
176
|
@abstractmethod
|
|
165
177
|
async def async_run(
|
|
166
|
-
self,
|
|
178
|
+
self,
|
|
179
|
+
session: "AnySession | None" = None,
|
|
180
|
+
str_kwargs: dict[str, str] | None = None,
|
|
181
|
+
kwargs: dict[str, Any] | None = None,
|
|
167
182
|
) -> Any:
|
|
168
183
|
"""Runs the task asynchronously.
|
|
169
184
|
|
|
170
185
|
Args:
|
|
171
186
|
session (AnySession): The shared session.
|
|
172
187
|
str_kwargs(dict[str, str]): The input string values.
|
|
188
|
+
kwargs(dict[str, Any]): The input values.
|
|
173
189
|
|
|
174
190
|
Returns:
|
|
175
191
|
Any: The result of the task execution.
|
|
@@ -203,3 +219,8 @@ class AnyTask(ABC):
|
|
|
203
219
|
session (AnySession): The shared session.
|
|
204
220
|
"""
|
|
205
221
|
pass
|
|
222
|
+
|
|
223
|
+
@abstractmethod
|
|
224
|
+
def to_function(self) -> Callable[..., Any]:
|
|
225
|
+
"""Turn a task into a function"""
|
|
226
|
+
pass
|