zrb 1.15.3__py3-none-any.whl → 2.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic; refer to the registry listing for more details.

Files changed (204)
  1. zrb/__init__.py +118 -133
  2. zrb/attr/type.py +10 -7
  3. zrb/builtin/__init__.py +55 -1
  4. zrb/builtin/git.py +12 -1
  5. zrb/builtin/group.py +31 -15
  6. zrb/builtin/llm/chat.py +147 -0
  7. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  9. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  10. zrb/builtin/searxng/config/settings.yml +5671 -0
  11. zrb/builtin/searxng/start.py +21 -0
  12. zrb/builtin/shell/autocomplete/bash.py +4 -3
  13. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  14. zrb/callback/callback.py +8 -1
  15. zrb/cmd/cmd_result.py +2 -1
  16. zrb/config/config.py +555 -169
  17. zrb/config/helper.py +84 -0
  18. zrb/config/web_auth_config.py +50 -35
  19. zrb/context/any_shared_context.py +20 -3
  20. zrb/context/context.py +39 -5
  21. zrb/context/print_fn.py +13 -0
  22. zrb/context/shared_context.py +17 -8
  23. zrb/group/any_group.py +3 -3
  24. zrb/group/group.py +3 -3
  25. zrb/input/any_input.py +5 -1
  26. zrb/input/base_input.py +18 -6
  27. zrb/input/option_input.py +41 -1
  28. zrb/input/text_input.py +7 -24
  29. zrb/llm/agent/__init__.py +9 -0
  30. zrb/llm/agent/agent.py +215 -0
  31. zrb/llm/agent/summarizer.py +20 -0
  32. zrb/llm/app/__init__.py +10 -0
  33. zrb/llm/app/completion.py +281 -0
  34. zrb/llm/app/confirmation/allow_tool.py +66 -0
  35. zrb/llm/app/confirmation/handler.py +178 -0
  36. zrb/llm/app/confirmation/replace_confirmation.py +77 -0
  37. zrb/llm/app/keybinding.py +34 -0
  38. zrb/llm/app/layout.py +117 -0
  39. zrb/llm/app/lexer.py +155 -0
  40. zrb/llm/app/redirection.py +28 -0
  41. zrb/llm/app/style.py +16 -0
  42. zrb/llm/app/ui.py +733 -0
  43. zrb/llm/config/__init__.py +4 -0
  44. zrb/llm/config/config.py +122 -0
  45. zrb/llm/config/limiter.py +247 -0
  46. zrb/llm/history_manager/__init__.py +4 -0
  47. zrb/llm/history_manager/any_history_manager.py +23 -0
  48. zrb/llm/history_manager/file_history_manager.py +91 -0
  49. zrb/llm/history_processor/summarizer.py +108 -0
  50. zrb/llm/note/__init__.py +3 -0
  51. zrb/llm/note/manager.py +122 -0
  52. zrb/llm/prompt/__init__.py +29 -0
  53. zrb/llm/prompt/claude_compatibility.py +92 -0
  54. zrb/llm/prompt/compose.py +55 -0
  55. zrb/llm/prompt/default.py +51 -0
  56. zrb/llm/prompt/markdown/file_extractor.md +112 -0
  57. zrb/llm/prompt/markdown/mandate.md +23 -0
  58. zrb/llm/prompt/markdown/persona.md +3 -0
  59. zrb/llm/prompt/markdown/repo_extractor.md +112 -0
  60. zrb/llm/prompt/markdown/repo_summarizer.md +29 -0
  61. zrb/llm/prompt/markdown/summarizer.md +21 -0
  62. zrb/llm/prompt/note.py +41 -0
  63. zrb/llm/prompt/system_context.py +46 -0
  64. zrb/llm/prompt/zrb.py +41 -0
  65. zrb/llm/skill/__init__.py +3 -0
  66. zrb/llm/skill/manager.py +86 -0
  67. zrb/llm/task/__init__.py +4 -0
  68. zrb/llm/task/llm_chat_task.py +316 -0
  69. zrb/llm/task/llm_task.py +245 -0
  70. zrb/llm/tool/__init__.py +39 -0
  71. zrb/llm/tool/bash.py +75 -0
  72. zrb/llm/tool/code.py +266 -0
  73. zrb/llm/tool/file.py +419 -0
  74. zrb/llm/tool/note.py +70 -0
  75. zrb/{builtin/llm → llm}/tool/rag.py +33 -37
  76. zrb/llm/tool/search/brave.py +53 -0
  77. zrb/llm/tool/search/searxng.py +47 -0
  78. zrb/llm/tool/search/serpapi.py +47 -0
  79. zrb/llm/tool/skill.py +19 -0
  80. zrb/llm/tool/sub_agent.py +70 -0
  81. zrb/llm/tool/web.py +97 -0
  82. zrb/llm/tool/zrb_task.py +66 -0
  83. zrb/llm/util/attachment.py +101 -0
  84. zrb/llm/util/prompt.py +104 -0
  85. zrb/llm/util/stream_response.py +178 -0
  86. zrb/runner/cli.py +21 -20
  87. zrb/runner/common_util.py +24 -19
  88. zrb/runner/web_route/task_input_api_route.py +5 -5
  89. zrb/runner/web_util/user.py +7 -3
  90. zrb/session/any_session.py +12 -9
  91. zrb/session/session.py +38 -17
  92. zrb/task/any_task.py +24 -3
  93. zrb/task/base/context.py +42 -22
  94. zrb/task/base/execution.py +67 -55
  95. zrb/task/base/lifecycle.py +14 -7
  96. zrb/task/base/monitoring.py +12 -7
  97. zrb/task/base_task.py +113 -50
  98. zrb/task/base_trigger.py +16 -6
  99. zrb/task/cmd_task.py +6 -0
  100. zrb/task/http_check.py +11 -5
  101. zrb/task/make_task.py +5 -3
  102. zrb/task/rsync_task.py +30 -10
  103. zrb/task/scaffolder.py +7 -4
  104. zrb/task/scheduler.py +7 -4
  105. zrb/task/tcp_check.py +6 -4
  106. zrb/util/ascii_art/art/bee.txt +17 -0
  107. zrb/util/ascii_art/art/cat.txt +9 -0
  108. zrb/util/ascii_art/art/ghost.txt +16 -0
  109. zrb/util/ascii_art/art/panda.txt +17 -0
  110. zrb/util/ascii_art/art/rose.txt +14 -0
  111. zrb/util/ascii_art/art/unicorn.txt +15 -0
  112. zrb/util/ascii_art/banner.py +92 -0
  113. zrb/util/attr.py +54 -39
  114. zrb/util/cli/markdown.py +32 -0
  115. zrb/util/cli/text.py +30 -0
  116. zrb/util/cmd/command.py +33 -10
  117. zrb/util/file.py +61 -33
  118. zrb/util/git.py +2 -2
  119. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  120. zrb/util/match.py +78 -0
  121. zrb/util/run.py +3 -3
  122. zrb/util/string/conversion.py +1 -1
  123. zrb/util/truncate.py +23 -0
  124. zrb/util/yaml.py +204 -0
  125. zrb/xcom/xcom.py +10 -0
  126. {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/METADATA +41 -27
  127. {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/RECORD +129 -131
  128. {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/WHEEL +1 -1
  129. zrb/attr/__init__.py +0 -0
  130. zrb/builtin/llm/chat_session.py +0 -311
  131. zrb/builtin/llm/history.py +0 -71
  132. zrb/builtin/llm/input.py +0 -27
  133. zrb/builtin/llm/llm_ask.py +0 -187
  134. zrb/builtin/llm/previous-session.js +0 -21
  135. zrb/builtin/llm/tool/__init__.py +0 -0
  136. zrb/builtin/llm/tool/api.py +0 -71
  137. zrb/builtin/llm/tool/cli.py +0 -38
  138. zrb/builtin/llm/tool/code.py +0 -254
  139. zrb/builtin/llm/tool/file.py +0 -626
  140. zrb/builtin/llm/tool/sub_agent.py +0 -137
  141. zrb/builtin/llm/tool/web.py +0 -195
  142. zrb/builtin/project/__init__.py +0 -0
  143. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/__init__.py +0 -0
  144. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/service/__init__.py +0 -0
  145. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/__init__.py +0 -0
  146. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/__init__.py +0 -0
  147. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/__init__.py +0 -0
  148. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
  149. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
  150. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/__init__.py +0 -0
  151. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/__init__.py +0 -0
  152. zrb/builtin/project/create/__init__.py +0 -0
  153. zrb/builtin/shell/__init__.py +0 -0
  154. zrb/builtin/shell/autocomplete/__init__.py +0 -0
  155. zrb/callback/__init__.py +0 -0
  156. zrb/cmd/__init__.py +0 -0
  157. zrb/config/default_prompt/file_extractor_system_prompt.md +0 -12
  158. zrb/config/default_prompt/interactive_system_prompt.md +0 -35
  159. zrb/config/default_prompt/persona.md +0 -1
  160. zrb/config/default_prompt/repo_extractor_system_prompt.md +0 -112
  161. zrb/config/default_prompt/repo_summarizer_system_prompt.md +0 -10
  162. zrb/config/default_prompt/summarization_prompt.md +0 -16
  163. zrb/config/default_prompt/system_prompt.md +0 -32
  164. zrb/config/llm_config.py +0 -243
  165. zrb/config/llm_context/config.py +0 -129
  166. zrb/config/llm_context/config_parser.py +0 -46
  167. zrb/config/llm_rate_limitter.py +0 -137
  168. zrb/content_transformer/__init__.py +0 -0
  169. zrb/context/__init__.py +0 -0
  170. zrb/dot_dict/__init__.py +0 -0
  171. zrb/env/__init__.py +0 -0
  172. zrb/group/__init__.py +0 -0
  173. zrb/input/__init__.py +0 -0
  174. zrb/runner/__init__.py +0 -0
  175. zrb/runner/web_route/__init__.py +0 -0
  176. zrb/runner/web_route/home_page/__init__.py +0 -0
  177. zrb/session/__init__.py +0 -0
  178. zrb/session_state_log/__init__.py +0 -0
  179. zrb/session_state_logger/__init__.py +0 -0
  180. zrb/task/__init__.py +0 -0
  181. zrb/task/base/__init__.py +0 -0
  182. zrb/task/llm/__init__.py +0 -0
  183. zrb/task/llm/agent.py +0 -243
  184. zrb/task/llm/config.py +0 -103
  185. zrb/task/llm/conversation_history.py +0 -128
  186. zrb/task/llm/conversation_history_model.py +0 -242
  187. zrb/task/llm/default_workflow/coding.md +0 -24
  188. zrb/task/llm/default_workflow/copywriting.md +0 -17
  189. zrb/task/llm/default_workflow/researching.md +0 -18
  190. zrb/task/llm/error.py +0 -95
  191. zrb/task/llm/history_summarization.py +0 -216
  192. zrb/task/llm/print_node.py +0 -101
  193. zrb/task/llm/prompt.py +0 -325
  194. zrb/task/llm/tool_wrapper.py +0 -220
  195. zrb/task/llm/typing.py +0 -3
  196. zrb/task/llm_task.py +0 -341
  197. zrb/task_status/__init__.py +0 -0
  198. zrb/util/__init__.py +0 -0
  199. zrb/util/cli/__init__.py +0 -0
  200. zrb/util/cmd/__init__.py +0 -0
  201. zrb/util/codemod/__init__.py +0 -0
  202. zrb/util/string/__init__.py +0 -0
  203. zrb/xcom/__init__.py +0 -0
  204. {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,178 @@
1
+ import json
2
+ from typing import TYPE_CHECKING, Any, Callable
3
+
4
+ from zrb.context.any_context import AnyContext
5
+ from zrb.util.cli.style import stylize_faint
6
+
7
+ if TYPE_CHECKING:
8
+ from pydantic_ai import AgentStreamEvent
9
+
10
+
11
def create_event_handler(
    print_event: Callable[..., None],
    indent_level: int = 1,
    show_tool_call_detail: bool = False,
    show_tool_result: bool = False,
):
    """Build an async callback that renders pydantic-ai agent stream events.

    Args:
        print_event: Output sink; called with one already-formatted string
            per write (no newline is appended by this handler).
        indent_level: Number of two-space indentation units applied to
            headers and to every wrapped line.
        show_tool_call_detail: When True, stream raw tool-argument deltas;
            otherwise draw a single-line spinner while arguments arrive.
        show_tool_result: When True, print each tool's return content;
            otherwise print a short "Executed" line per tool call.

    Returns:
        An ``async`` function taking one ``AgentStreamEvent``. Spinner position,
        the "previous event was a tool-call delta" flag and the header prefix
        live in closure variables shared by every call of that function.
    """
    from pydantic_ai import (
        AgentRunResultEvent,
        FinalResultEvent,
        FunctionToolCallEvent,
        FunctionToolResultEvent,
        PartDeltaEvent,
        PartStartEvent,
        TextPartDelta,
        ThinkingPartDelta,
        ToolCallPart,
        ToolCallPartDelta,
    )

    indentation = indent_level * 2 * " "
    progress_char_list = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
    progress_char_index = 0
    # True while consecutive ToolCallPartDelta events are drawing the spinner,
    # so the spinner's leading newline is printed only once per run of deltas.
    was_tool_call_delta = False
    # Header prefix: plain indentation until the first FinalResultEvent,
    # then a newline-led indentation so later headers start on a fresh line.
    event_prefix = indentation

    def fprint(content: str, preserve_leading_newline: bool = False):
        # Re-indent every embedded newline so wrapped output stays aligned.
        # With preserve_leading_newline, a leading "\n" is emitted bare so the
        # caller-supplied prefix controls the first line's indentation.
        if preserve_leading_newline and content.startswith("\n"):
            return print_event("\n" + content[1:].replace("\n", f"\n{indentation} "))
        return print_event(content.replace("\n", f"\n{indentation} "))

    async def handle_event(event: "AgentStreamEvent"):
        nonlocal progress_char_index, was_tool_call_delta, event_prefix
        if isinstance(event, PartStartEvent):
            # ToolCallPart starts are rendered via the deltas / FunctionToolCallEvent.
            if isinstance(event.part, ToolCallPart):
                return
            content = _get_event_part_content(event)
            fprint(f"{event_prefix}🧠 {content}", preserve_leading_newline=True)
            was_tool_call_delta = False
        elif isinstance(event, PartDeltaEvent):
            if isinstance(event.delta, (TextPartDelta, ThinkingPartDelta)):
                # Thinking deltas may carry only a signature (content_delta is
                # None); skip those instead of printing the text "None".
                if event.delta.content_delta is not None:
                    fprint(f"{event.delta.content_delta}")
                was_tool_call_delta = False
            elif isinstance(event.delta, ToolCallPartDelta):
                if show_tool_call_detail:
                    if event.delta.args_delta is not None:
                        fprint(f"{event.delta.args_delta}")
                else:
                    progress_char = progress_char_list[progress_char_index]
                    if not was_tool_call_delta:
                        # Start the spinner on its own line.
                        fprint("\n")
                    # Split "\r" from the text so UI._append_to_output does not
                    # strip the ANSI start code together with the line.
                    print_event("\r")
                    print_event(
                        f"{indentation}🔄 Prepare tool parameters {progress_char}"
                    )
                    progress_char_index = (progress_char_index + 1) % len(
                        progress_char_list
                    )
                was_tool_call_delta = True
        elif isinstance(event, FunctionToolCallEvent):
            args = _get_truncated_event_part_args(event)
            fprint(
                f"{event_prefix}🧰 {event.part.tool_call_id} | {event.part.tool_name} {args}",
                preserve_leading_newline=True,
            )
            was_tool_call_delta = False
        elif isinstance(event, FunctionToolResultEvent):
            if show_tool_result:
                message = (
                    f"{event_prefix}🔠 {event.tool_call_id} | Return {event.result.content}"
                )
            else:
                message = f"{event_prefix}🔠 {event.tool_call_id} Executed"
            fprint(message, preserve_leading_newline=True)
            was_tool_call_delta = False
        elif isinstance(event, AgentRunResultEvent):
            usage = event.result.usage()
            # The whole summary sits inside one pair of parentheses.
            usage_msg = " ".join(
                [
                    "💸",
                    f"(Requests: {usage.requests} |",
                    f"Tool Calls: {usage.tool_calls} |",
                    f"Total: {usage.total_tokens} |",
                    f"Input: {usage.input_tokens} |",
                    f"Audio Input: {usage.input_audio_tokens} |",
                    f"Output: {usage.output_tokens} |",
                    f"Audio Output: {usage.output_audio_tokens} |",
                    f"Cache Read: {usage.cache_read_tokens} |",
                    f"Cache Write: {usage.cache_write_tokens} |",
                    f"Details: {usage.details})",
                ]
            )
            fprint(
                f"{event_prefix}{stylize_faint(usage_msg)}\n",
                preserve_leading_newline=True,
            )
            was_tool_call_delta = False
        elif isinstance(event, FinalResultEvent):
            was_tool_call_delta = False
            event_prefix = f"\n{indentation}"

    return handle_event
125
+
126
+
127
def create_faint_printer(ctx: AnyContext):
    """Return a printer that writes faint-styled text through ``ctx.print``.

    The returned callable formats each positional argument with an f-string
    and joins the results with single spaces. It applies ``stylize_faint``
    to the joined text and forwards that to ``ctx.print`` with ``end=""``
    (no trailing newline) and ``plain=True``.
    """

    def faint_print(*values: object):
        joined = " ".join(f"{item}" for item in values)
        ctx.print(stylize_faint(joined), end="", plain=True)

    return faint_print
133
+
134
+
135
def _get_truncated_event_part_args(event: "AgentStreamEvent") -> Any:
    """Return ``event.part.args`` in a display-friendly form.

    The result is ``{}`` when the arguments are absent or empty. That covers
    a missing ``part`` or ``args`` attribute, ``None``, ``""``, and the strings
    ``"null"`` / ``"{}"`` once stripped; providers differ in how they encode
    "no arguments". A ``dict``, or a JSON string that decodes to an object,
    is returned with long string values shortened by ``_truncate_kwargs``.
    Anything else, including JSON that is not an object or invalid JSON, is
    returned unchanged.
    """
    part = getattr(event, "part") if hasattr(event, "part") else None
    if part is None or not hasattr(part, "args"):
        return {}
    raw = getattr(part, "args")
    if raw is None or raw == "":
        return {}
    # Dicts may include the dummy property added by schema sanitization.
    if isinstance(raw, dict):
        return _truncate_kwargs(raw)
    if not isinstance(raw, str):
        return raw
    if raw.strip() in ("null", "{}"):
        return {}
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    return _truncate_kwargs(decoded) if isinstance(decoded, dict) else raw
159
+
160
+
161
def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with the same keys, each value passed through ``_truncate_arg``."""
    shortened: dict[str, Any] = {}
    for name, value in kwargs.items():
        shortened[name] = _truncate_arg(value)
    return shortened
163
+
164
+
165
def _truncate_arg(arg: Any, length: int = 30) -> Any:
    """Shorten ``arg`` for display if it is a string longer than ``length``.

    A long string is cut so that the kept prefix plus the ``" ..."`` marker
    is exactly ``length`` characters. When ``length`` is below the marker
    size, only the marker is returned. Non-string values and strings of at
    most ``length`` characters are returned unchanged.
    """
    marker = " ..."
    if isinstance(arg, str) and len(arg) > length:
        # Clamp at 0: a negative slice end would keep nearly the whole string.
        keep = max(length - len(marker), 0)
        return f"{arg[:keep]}{marker}"
    return arg
169
+
170
+
171
def _get_event_part_content(event: "AgentStreamEvent") -> str:
    """Return ``event.part.content`` when both exist, otherwise ``""``.

    Parts without a ``content`` attribute, such as tool-call parts (which
    the event handler already skips), yield an empty string.
    """
    part = getattr(event, "part") if hasattr(event, "part") else None
    if part is not None and hasattr(part, "content"):
        return getattr(part, "content")
    return ""
zrb/runner/cli.py CHANGED
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
7
7
  from zrb.context.shared_context import SharedContext
8
8
  from zrb.group.any_group import AnyGroup
9
9
  from zrb.group.group import Group
10
- from zrb.runner.common_util import get_run_kwargs
10
+ from zrb.runner.common_util import get_task_str_kwargs
11
11
  from zrb.session.session import Session
12
12
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
13
13
  from zrb.task.any_task import AnyTask
@@ -38,23 +38,25 @@ class Cli(Group):
38
38
  def banner(self) -> str:
39
39
  return CFG.BANNER
40
40
 
41
- def run(self, args: list[str] = []):
42
- kwargs, args = self._extract_kwargs_from_args(args)
43
- node, node_path, args = extract_node_from_args(self, args)
41
+ def run(self, str_args: list[str] = []):
42
+ str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
43
+ node, node_path, str_args = extract_node_from_args(self, str_args)
44
44
  if isinstance(node, AnyGroup):
45
45
  self._show_group_info(node)
46
46
  return
47
- if "h" in kwargs or "help" in kwargs:
47
+ if "h" in str_kwargs or "help" in str_kwargs:
48
48
  self._show_task_info(node)
49
49
  return
50
- run_kwargs = get_run_kwargs(task=node, args=args, kwargs=kwargs, cli_mode=True)
50
+ task_str_kwargs = get_task_str_kwargs(
51
+ task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
52
+ )
51
53
  try:
52
- result = self._run_task(node, args, run_kwargs)
54
+ result = self._run_task(node, str_args, task_str_kwargs)
53
55
  if result is not None:
54
56
  print(result)
55
57
  return result
56
58
  finally:
57
- run_command = self._get_run_command(node_path, run_kwargs)
59
+ run_command = self._get_run_command(node_path, task_str_kwargs)
58
60
  self._print_run_command(run_command)
59
61
 
60
62
  def _print_run_command(self, run_command: str):
@@ -64,11 +66,14 @@ class Cli(Group):
64
66
  file=sys.stderr,
65
67
  )
66
68
 
67
- def _get_run_command(self, node_path: list[str], run_kwargs: dict[str, str]) -> str:
69
+ def _get_run_command(
70
+ self, node_path: list[str], task_str_kwargs: dict[str, str]
71
+ ) -> str:
68
72
  parts = [self.name] + node_path
69
- if len(run_kwargs) > 0:
73
+ if len(task_str_kwargs) > 0:
70
74
  parts += [
71
- self._get_run_command_param(key, val) for key, val in run_kwargs.items()
75
+ self._get_run_command_param(key, val)
76
+ for key, val in task_str_kwargs.items()
72
77
  ]
73
78
  return " ".join(parts)
74
79
 
@@ -81,13 +86,9 @@ class Cli(Group):
81
86
  self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
82
87
  ) -> tuple[Any]:
83
88
  shared_ctx = SharedContext(args=args)
84
- for task_input in task.inputs:
85
- if task_input.name in run_kwargs:
86
- task_input.update_shared_context(
87
- shared_ctx, run_kwargs[task_input.name]
88
- )
89
- continue
90
- return task.run(Session(shared_ctx=shared_ctx, root_group=self))
89
+ return task.run(
90
+ Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
91
+ )
91
92
 
92
93
  def _show_task_info(self, task: AnyTask):
93
94
  description = task.description
@@ -150,11 +151,11 @@ class Cli(Group):
150
151
  kwargs[key] = args[i + 1]
151
152
  i += 1 # Skip the next argument as it's a value
152
153
  else:
153
- kwargs[key] = True
154
+ kwargs[key] = "true"
154
155
  elif arg.startswith("-"):
155
156
  # Handle short flags like -t or -n
156
157
  key = arg[1:]
157
- kwargs[key] = True
158
+ kwargs[key] = "true"
158
159
  else:
159
160
  # Anything else is considered a positional argument
160
161
  residual_args.append(arg)
zrb/runner/common_util.py CHANGED
@@ -1,31 +1,36 @@
1
- from typing import Any
2
-
3
1
  from zrb.context.shared_context import SharedContext
4
2
  from zrb.task.any_task import AnyTask
5
3
 
6
4
 
7
- def get_run_kwargs(
8
- task: AnyTask, args: list[str], kwargs: dict[str, str], cli_mode: bool
5
+ def get_task_str_kwargs(
6
+ task: AnyTask, str_args: list[str], str_kwargs: dict[str, str], cli_mode: bool
9
7
  ) -> dict[str, str]:
10
8
  arg_index = 0
11
- str_kwargs = {key: f"{val}" for key, val in kwargs.items()}
12
- run_kwargs = {**str_kwargs}
13
- shared_ctx = SharedContext(args=args)
9
+ dummmy_shared_ctx = SharedContext()
10
+ task_str_kwargs = {}
14
11
  for task_input in task.inputs:
12
+ task_name = task_input.name
15
13
  if task_input.name in str_kwargs:
16
- # Update shared context for next input default value
17
- task_input.update_shared_context(shared_ctx, str_kwargs[task_input.name])
18
- elif arg_index < len(args) and task_input.allow_positional_parsing:
19
- run_kwargs[task_input.name] = args[arg_index]
20
- # Update shared context for next input default value
21
- task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
14
+ task_str_kwargs[task_input.name] = str_kwargs[task_name]
15
+ # Update dummy shared context for next input default value
16
+ task_input.update_shared_context(
17
+ dummmy_shared_ctx, str_value=str_kwargs[task_name]
18
+ )
19
+ elif arg_index < len(str_args) and task_input.allow_positional_parsing:
20
+ task_str_kwargs[task_name] = str_args[arg_index]
21
+ # Update dummy shared context for next input default value
22
+ task_input.update_shared_context(
23
+ dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
24
+ )
22
25
  arg_index += 1
23
26
  else:
24
27
  if cli_mode and task_input.always_prompt:
25
- str_value = task_input.prompt_cli_str(shared_ctx)
28
+ str_value = task_input.prompt_cli_str(dummmy_shared_ctx)
26
29
  else:
27
- str_value = task_input.get_default_str(shared_ctx)
28
- run_kwargs[task_input.name] = str_value
29
- # Update shared context for next input default value
30
- task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
31
- return run_kwargs
30
+ str_value = task_input.get_default_str(dummmy_shared_ctx)
31
+ task_str_kwargs[task_name] = str_value
32
+ # Update dummy shared context for next input default value
33
+ task_input.update_shared_context(
34
+ dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
35
+ )
36
+ return task_str_kwargs
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
3
3
 
4
4
  from zrb.config.web_auth_config import WebAuthConfig
5
5
  from zrb.group.any_group import AnyGroup
6
- from zrb.runner.common_util import get_run_kwargs
6
+ from zrb.runner.common_util import get_task_str_kwargs
7
7
  from zrb.runner.web_util.user import get_user_from_request
8
8
  from zrb.task.any_task import AnyTask
9
9
  from zrb.util.group import NodeNotFoundError, extract_node_from_args
@@ -39,9 +39,9 @@ def serve_task_input_api(
39
39
  if isinstance(task, AnyTask):
40
40
  if not user.can_access_task(task):
41
41
  return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
42
- query_dict = json.loads(query)
43
- run_kwargs = get_run_kwargs(
44
- task=task, args=[], kwargs=query_dict, cli_mode=False
42
+ str_kwargs = json.loads(query)
43
+ task_str_kwargs = get_task_str_kwargs(
44
+ task=task, str_args=[], str_kwargs=str_kwargs, cli_mode=False
45
45
  )
46
- return run_kwargs
46
+ return task_str_kwargs
47
47
  return JSONResponse(content={"detail": "Not found"}, status_code=404)
@@ -19,7 +19,7 @@ def get_user_by_credentials(
19
19
 
20
20
  async def get_user_from_request(
21
21
  web_auth_config: WebAuthConfig, request: "Request"
22
- ) -> User | None:
22
+ ) -> User:
23
23
  from fastapi.security import OAuth2PasswordBearer
24
24
 
25
25
  if not web_auth_config.enable_auth:
@@ -45,7 +45,11 @@ def _get_user_from_cookie(
45
45
  return None
46
46
 
47
47
 
48
- def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | None:
48
+ def _get_user_from_token(
49
+ web_auth_config: WebAuthConfig, token: str | None
50
+ ) -> User | None:
51
+ if token is None:
52
+ return None
49
53
  try:
50
54
  from jose import jwt
51
55
 
@@ -54,7 +58,7 @@ def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | N
54
58
  web_auth_config.secret_key,
55
59
  options={"require_sub": True, "require_exp": True},
56
60
  )
57
- username: str = payload.get("sub")
61
+ username: str | None = payload.get("sub")
58
62
  if username is None:
59
63
  return None
60
64
  user = web_auth_config.find_user_by_username(username)
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
+ import asyncio
3
4
  from abc import ABC, abstractmethod
4
5
  from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
5
6
 
@@ -14,9 +15,6 @@ if TYPE_CHECKING:
14
15
  from zrb.task.any_task import AnyTask
15
16
 
16
17
 
17
- TAnySession = TypeVar("TAnySession", bound="AnySession")
18
-
19
-
20
18
  class AnySession(ABC):
21
19
  """Abstract base class for managing task execution and context in a session.
22
20
 
@@ -62,12 +60,13 @@ class AnySession(ABC):
62
60
 
63
61
  @property
64
62
  @abstractmethod
65
- def parent(self) -> TAnySession | None:
63
+ def parent(self) -> "AnySession | None":
66
64
  """Parent session"""
67
65
  pass
68
66
 
67
+ @property
69
68
  @abstractmethod
70
- def task_path(self) -> str:
69
+ def task_path(self) -> list[str]:
71
70
  """Main task's path"""
72
71
  pass
73
72
 
@@ -105,7 +104,9 @@ class AnySession(ABC):
105
104
  pass
106
105
 
107
106
  @abstractmethod
108
- def defer_monitoring(self, task: "AnyTask", coro: Coroutine):
107
+ def defer_monitoring(
108
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
109
+ ):
109
110
  """Defers the execution of a task's monitoring coroutine for later processing.
110
111
 
111
112
  Args:
@@ -115,7 +116,9 @@ class AnySession(ABC):
115
116
  pass
116
117
 
117
118
  @abstractmethod
118
- def defer_action(self, task: "AnyTask", coro: Coroutine):
119
+ def defer_action(
120
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
121
+ ):
119
122
  """Defers the execution of a task's coroutine for later processing.
120
123
 
121
124
  Args:
@@ -125,7 +128,7 @@ class AnySession(ABC):
125
128
  pass
126
129
 
127
130
  @abstractmethod
128
- def defer_coro(self, coro: Coroutine):
131
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
129
132
  """Defers the execution of a coroutine for later processing.
130
133
 
131
134
  Args:
@@ -185,7 +188,7 @@ class AnySession(ABC):
185
188
  pass
186
189
 
187
190
  @abstractmethod
188
- def is_allowed_to_run(self, task: "AnyTask"):
191
+ def is_allowed_to_run(self, task: "AnyTask") -> bool:
189
192
  """Determines if the specified task is allowed to run based on its current state.
190
193
 
191
194
  Args:
zrb/session/session.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  from typing import TYPE_CHECKING, Any, Coroutine
3
5
 
@@ -48,10 +50,10 @@ class Session(AnySession):
48
50
  self._context: dict[AnyTask, Context] = {}
49
51
  self._shared_ctx = shared_ctx
50
52
  self._shared_ctx.set_session(self)
51
- self._parent = parent
52
- self._action_coros: dict[AnyTask, asyncio.Task] = {}
53
- self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
54
- self._coros: list[asyncio.Task] = []
53
+ self._parent: AnySession | None = parent
54
+ self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
55
+ self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
56
+ self._coros: list[asyncio.Task[Any]] = []
55
57
  self._colors = [
56
58
  GREEN,
57
59
  YELLOW,
@@ -114,11 +116,13 @@ class Session(AnySession):
114
116
  return self._parent
115
117
 
116
118
  @property
117
- def task_path(self) -> str:
119
+ def task_path(self) -> list[str]:
118
120
  return self._main_task_path
119
121
 
120
122
  @property
121
123
  def final_result(self) -> Any:
124
+ if self._main_task is None:
125
+ return None
122
126
  xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
123
127
  try:
124
128
  return xcom.peek()
@@ -134,7 +138,11 @@ class Session(AnySession):
134
138
  def set_main_task(self, main_task: AnyTask):
135
139
  self.register_task(main_task)
136
140
  self._main_task = main_task
137
- main_task_path = get_node_path(self._root_group, main_task)
141
+ main_task_path = (
142
+ None
143
+ if self._root_group is None
144
+ else get_node_path(self._root_group, main_task)
145
+ )
138
146
  self._main_task_path = [] if main_task_path is None else main_task_path
139
147
 
140
148
  def as_state_log(self) -> "SessionStateLog":
@@ -171,7 +179,7 @@ class Session(AnySession):
171
179
  return SessionStateLog(
172
180
  name=self.name,
173
181
  start_time=log_start_time,
174
- main_task_name=self._main_task.name,
182
+ main_task_name="" if self._main_task is None else self._main_task.name,
175
183
  path=self.task_path,
176
184
  final_result=(
177
185
  remove_style(f"{self.final_result}")
@@ -188,16 +196,29 @@ class Session(AnySession):
188
196
  self._register_single_task(task)
189
197
  return self._context[task]
190
198
 
191
- def defer_monitoring(self, task: AnyTask, coro: Coroutine):
199
+ def defer_monitoring(
200
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
201
+ ):
192
202
  self._register_single_task(task)
193
- self._monitoring_coros[task] = coro
203
+ if isinstance(coro, asyncio.Task):
204
+ self._monitoring_coros[task] = coro
205
+ else:
206
+ self._monitoring_coros[task] = asyncio.create_task(coro)
194
207
 
195
- def defer_action(self, task: AnyTask, coro: Coroutine):
208
+ def defer_action(
209
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
210
+ ):
196
211
  self._register_single_task(task)
197
- self._action_coros[task] = coro
212
+ if isinstance(coro, asyncio.Task):
213
+ self._action_coros[task] = coro
214
+ else:
215
+ self._action_coros[task] = asyncio.create_task(coro)
198
216
 
199
- def defer_coro(self, coro: Coroutine):
200
- self._coros.append(coro)
217
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
218
+ if isinstance(coro, asyncio.Task):
219
+ self._coros.append(coro)
220
+ else:
221
+ self._coros.append(asyncio.create_task(coro))
201
222
  self._coros = [
202
223
  existing_coro for existing_coro in self._coros if not existing_coro.done()
203
224
  ]
@@ -246,15 +267,15 @@ class Session(AnySession):
246
267
 
247
268
  def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
248
269
  self._register_single_task(task)
249
- return self._downstreams.get(task)
270
+ return self._downstreams.get(task, [])
250
271
 
251
272
  def get_task_status(self, task: AnyTask) -> TaskStatus:
252
273
  self._register_single_task(task)
253
274
  return self._task_status[task]
254
275
 
255
276
  def _register_single_task(self, task: AnyTask):
256
- if task.name not in self._shared_ctx._xcom:
257
- self._shared_ctx._xcom[task.name] = Xcom([])
277
+ if task.name not in self._shared_ctx.xcom:
278
+ self._shared_ctx.xcom[task.name] = Xcom([])
258
279
  if task not in self._context:
259
280
  self._context[task] = Context(
260
281
  shared_ctx=self._shared_ctx,
@@ -278,7 +299,7 @@ class Session(AnySession):
278
299
  self._color_index = 0
279
300
  return chosen
280
301
 
281
- def _get_icon(self, task: AnyTask) -> int:
302
+ def _get_icon(self, task: AnyTask) -> str:
282
303
  if task.icon is not None:
283
304
  return task.icon
284
305
  chosen = self._icons[self._icon_index]
zrb/task/any_task.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any
4
+ from typing import TYPE_CHECKING, Any, Callable
5
5
 
6
6
  from zrb.env.any_env import AnyEnv
7
7
  from zrb.input.any_input import AnyInput
@@ -36,6 +36,14 @@ class AnyTask(ABC):
36
36
  the actual implementation for these abstract members.
37
37
  """
38
38
 
39
+ @abstractmethod
40
+ def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
41
+ pass
42
+
43
+ @abstractmethod
44
+ def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
45
+ pass
46
+
39
47
  @property
40
48
  @abstractmethod
41
49
  def name(self) -> str:
@@ -148,13 +156,17 @@ class AnyTask(ABC):
148
156
 
149
157
  @abstractmethod
150
158
  def run(
151
- self, session: "AnySession | None" = None, str_kwargs: dict[str, str] = {}
159
+ self,
160
+ session: "AnySession | None" = None,
161
+ str_kwargs: dict[str, str] | None = None,
162
+ kwargs: dict[str, Any] | None = None,
152
163
  ) -> Any:
153
164
  """Runs the task synchronously.
154
165
 
155
166
  Args:
156
167
  session (AnySession): The shared session.
157
168
  str_kwargs(dict[str, str]): The input string values.
169
+ kwargs(dict[str, Any]): The input values.
158
170
 
159
171
  Returns:
160
172
  Any: The result of the task execution.
@@ -163,13 +175,17 @@ class AnyTask(ABC):
163
175
 
164
176
  @abstractmethod
165
177
  async def async_run(
166
- self, session: "AnySession | None" = None, str_kwargs: dict[str, str] = {}
178
+ self,
179
+ session: "AnySession | None" = None,
180
+ str_kwargs: dict[str, str] | None = None,
181
+ kwargs: dict[str, Any] | None = None,
167
182
  ) -> Any:
168
183
  """Runs the task asynchronously.
169
184
 
170
185
  Args:
171
186
  session (AnySession): The shared session.
172
187
  str_kwargs(dict[str, str]): The input string values.
188
+ kwargs(dict[str, Any]): The input values.
173
189
 
174
190
  Returns:
175
191
  Any: The result of the task execution.
@@ -203,3 +219,8 @@ class AnyTask(ABC):
203
219
  session (AnySession): The shared session.
204
220
  """
205
221
  pass
222
+
223
+ @abstractmethod
224
+ def to_function(self) -> Callable[..., Any]:
225
+ """Turn a task into a function"""
226
+ pass