zrb 1.21.29__py3-none-any.whl → 2.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic. Click here for more details.

Files changed (192) hide show
  1. zrb/__init__.py +118 -129
  2. zrb/builtin/__init__.py +54 -2
  3. zrb/builtin/llm/chat.py +147 -0
  4. zrb/callback/callback.py +8 -1
  5. zrb/cmd/cmd_result.py +2 -1
  6. zrb/config/config.py +491 -280
  7. zrb/config/helper.py +84 -0
  8. zrb/config/web_auth_config.py +50 -35
  9. zrb/context/any_shared_context.py +13 -2
  10. zrb/context/context.py +31 -3
  11. zrb/context/print_fn.py +13 -0
  12. zrb/context/shared_context.py +14 -1
  13. zrb/input/option_input.py +30 -2
  14. zrb/llm/agent/__init__.py +9 -0
  15. zrb/llm/agent/agent.py +215 -0
  16. zrb/llm/agent/summarizer.py +20 -0
  17. zrb/llm/app/__init__.py +10 -0
  18. zrb/llm/app/completion.py +281 -0
  19. zrb/llm/app/confirmation/allow_tool.py +66 -0
  20. zrb/llm/app/confirmation/handler.py +178 -0
  21. zrb/llm/app/confirmation/replace_confirmation.py +77 -0
  22. zrb/llm/app/keybinding.py +34 -0
  23. zrb/llm/app/layout.py +117 -0
  24. zrb/llm/app/lexer.py +155 -0
  25. zrb/llm/app/redirection.py +28 -0
  26. zrb/llm/app/style.py +16 -0
  27. zrb/llm/app/ui.py +733 -0
  28. zrb/llm/config/__init__.py +4 -0
  29. zrb/llm/config/config.py +122 -0
  30. zrb/llm/config/limiter.py +247 -0
  31. zrb/llm/history_manager/__init__.py +4 -0
  32. zrb/llm/history_manager/any_history_manager.py +23 -0
  33. zrb/llm/history_manager/file_history_manager.py +91 -0
  34. zrb/llm/history_processor/summarizer.py +108 -0
  35. zrb/llm/note/__init__.py +3 -0
  36. zrb/llm/note/manager.py +122 -0
  37. zrb/llm/prompt/__init__.py +29 -0
  38. zrb/llm/prompt/claude_compatibility.py +92 -0
  39. zrb/llm/prompt/compose.py +55 -0
  40. zrb/llm/prompt/default.py +51 -0
  41. zrb/llm/prompt/markdown/mandate.md +23 -0
  42. zrb/llm/prompt/markdown/persona.md +3 -0
  43. zrb/llm/prompt/markdown/summarizer.md +21 -0
  44. zrb/llm/prompt/note.py +41 -0
  45. zrb/llm/prompt/system_context.py +46 -0
  46. zrb/llm/prompt/zrb.py +41 -0
  47. zrb/llm/skill/__init__.py +3 -0
  48. zrb/llm/skill/manager.py +86 -0
  49. zrb/llm/task/__init__.py +4 -0
  50. zrb/llm/task/llm_chat_task.py +316 -0
  51. zrb/llm/task/llm_task.py +245 -0
  52. zrb/llm/tool/__init__.py +39 -0
  53. zrb/llm/tool/bash.py +75 -0
  54. zrb/llm/tool/code.py +266 -0
  55. zrb/llm/tool/file.py +419 -0
  56. zrb/llm/tool/note.py +70 -0
  57. zrb/{builtin/llm → llm}/tool/rag.py +8 -5
  58. zrb/llm/tool/search/brave.py +53 -0
  59. zrb/llm/tool/search/searxng.py +47 -0
  60. zrb/llm/tool/search/serpapi.py +47 -0
  61. zrb/llm/tool/skill.py +19 -0
  62. zrb/llm/tool/sub_agent.py +70 -0
  63. zrb/llm/tool/web.py +97 -0
  64. zrb/llm/tool/zrb_task.py +66 -0
  65. zrb/llm/util/attachment.py +101 -0
  66. zrb/llm/util/prompt.py +104 -0
  67. zrb/llm/util/stream_response.py +178 -0
  68. zrb/session/any_session.py +0 -3
  69. zrb/session/session.py +1 -1
  70. zrb/task/base/context.py +25 -13
  71. zrb/task/base/execution.py +52 -47
  72. zrb/task/base/lifecycle.py +7 -4
  73. zrb/task/base_task.py +48 -49
  74. zrb/task/base_trigger.py +4 -1
  75. zrb/task/cmd_task.py +6 -0
  76. zrb/task/http_check.py +11 -5
  77. zrb/task/make_task.py +3 -0
  78. zrb/task/rsync_task.py +5 -0
  79. zrb/task/scaffolder.py +7 -4
  80. zrb/task/scheduler.py +3 -0
  81. zrb/task/tcp_check.py +6 -4
  82. zrb/util/ascii_art/art/bee.txt +17 -0
  83. zrb/util/ascii_art/art/cat.txt +9 -0
  84. zrb/util/ascii_art/art/ghost.txt +16 -0
  85. zrb/util/ascii_art/art/panda.txt +17 -0
  86. zrb/util/ascii_art/art/rose.txt +14 -0
  87. zrb/util/ascii_art/art/unicorn.txt +15 -0
  88. zrb/util/ascii_art/banner.py +92 -0
  89. zrb/util/cli/markdown.py +22 -2
  90. zrb/util/cmd/command.py +33 -10
  91. zrb/util/file.py +51 -32
  92. zrb/util/match.py +78 -0
  93. zrb/util/run.py +3 -3
  94. {zrb-1.21.29.dist-info → zrb-2.0.0a4.dist-info}/METADATA +9 -15
  95. {zrb-1.21.29.dist-info → zrb-2.0.0a4.dist-info}/RECORD +100 -128
  96. zrb/attr/__init__.py +0 -0
  97. zrb/builtin/llm/attachment.py +0 -40
  98. zrb/builtin/llm/chat_completion.py +0 -274
  99. zrb/builtin/llm/chat_session.py +0 -270
  100. zrb/builtin/llm/chat_session_cmd.py +0 -288
  101. zrb/builtin/llm/chat_trigger.py +0 -79
  102. zrb/builtin/llm/history.py +0 -71
  103. zrb/builtin/llm/input.py +0 -27
  104. zrb/builtin/llm/llm_ask.py +0 -269
  105. zrb/builtin/llm/previous-session.js +0 -21
  106. zrb/builtin/llm/tool/__init__.py +0 -0
  107. zrb/builtin/llm/tool/api.py +0 -75
  108. zrb/builtin/llm/tool/cli.py +0 -52
  109. zrb/builtin/llm/tool/code.py +0 -236
  110. zrb/builtin/llm/tool/file.py +0 -560
  111. zrb/builtin/llm/tool/note.py +0 -84
  112. zrb/builtin/llm/tool/sub_agent.py +0 -150
  113. zrb/builtin/llm/tool/web.py +0 -171
  114. zrb/builtin/project/__init__.py +0 -0
  115. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/__init__.py +0 -0
  116. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/service/__init__.py +0 -0
  117. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/__init__.py +0 -0
  118. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/__init__.py +0 -0
  119. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/__init__.py +0 -0
  120. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
  121. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
  122. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/__init__.py +0 -0
  123. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/__init__.py +0 -0
  124. zrb/builtin/project/create/__init__.py +0 -0
  125. zrb/builtin/shell/__init__.py +0 -0
  126. zrb/builtin/shell/autocomplete/__init__.py +0 -0
  127. zrb/callback/__init__.py +0 -0
  128. zrb/cmd/__init__.py +0 -0
  129. zrb/config/default_prompt/interactive_system_prompt.md +0 -29
  130. zrb/config/default_prompt/persona.md +0 -1
  131. zrb/config/default_prompt/summarization_prompt.md +0 -57
  132. zrb/config/default_prompt/system_prompt.md +0 -38
  133. zrb/config/llm_config.py +0 -339
  134. zrb/config/llm_context/config.py +0 -166
  135. zrb/config/llm_context/config_parser.py +0 -40
  136. zrb/config/llm_context/workflow.py +0 -81
  137. zrb/config/llm_rate_limitter.py +0 -190
  138. zrb/content_transformer/__init__.py +0 -0
  139. zrb/context/__init__.py +0 -0
  140. zrb/dot_dict/__init__.py +0 -0
  141. zrb/env/__init__.py +0 -0
  142. zrb/group/__init__.py +0 -0
  143. zrb/input/__init__.py +0 -0
  144. zrb/runner/__init__.py +0 -0
  145. zrb/runner/web_route/__init__.py +0 -0
  146. zrb/runner/web_route/home_page/__init__.py +0 -0
  147. zrb/session/__init__.py +0 -0
  148. zrb/session_state_log/__init__.py +0 -0
  149. zrb/session_state_logger/__init__.py +0 -0
  150. zrb/task/__init__.py +0 -0
  151. zrb/task/base/__init__.py +0 -0
  152. zrb/task/llm/__init__.py +0 -0
  153. zrb/task/llm/agent.py +0 -204
  154. zrb/task/llm/agent_runner.py +0 -152
  155. zrb/task/llm/config.py +0 -122
  156. zrb/task/llm/conversation_history.py +0 -209
  157. zrb/task/llm/conversation_history_model.py +0 -67
  158. zrb/task/llm/default_workflow/coding/workflow.md +0 -41
  159. zrb/task/llm/default_workflow/copywriting/workflow.md +0 -68
  160. zrb/task/llm/default_workflow/git/workflow.md +0 -118
  161. zrb/task/llm/default_workflow/golang/workflow.md +0 -128
  162. zrb/task/llm/default_workflow/html-css/workflow.md +0 -135
  163. zrb/task/llm/default_workflow/java/workflow.md +0 -146
  164. zrb/task/llm/default_workflow/javascript/workflow.md +0 -158
  165. zrb/task/llm/default_workflow/python/workflow.md +0 -160
  166. zrb/task/llm/default_workflow/researching/workflow.md +0 -153
  167. zrb/task/llm/default_workflow/rust/workflow.md +0 -162
  168. zrb/task/llm/default_workflow/shell/workflow.md +0 -299
  169. zrb/task/llm/error.py +0 -95
  170. zrb/task/llm/file_replacement.py +0 -206
  171. zrb/task/llm/file_tool_model.py +0 -57
  172. zrb/task/llm/history_processor.py +0 -206
  173. zrb/task/llm/history_summarization.py +0 -25
  174. zrb/task/llm/print_node.py +0 -221
  175. zrb/task/llm/prompt.py +0 -321
  176. zrb/task/llm/subagent_conversation_history.py +0 -41
  177. zrb/task/llm/tool_wrapper.py +0 -361
  178. zrb/task/llm/typing.py +0 -3
  179. zrb/task/llm/workflow.py +0 -76
  180. zrb/task/llm_task.py +0 -379
  181. zrb/task_status/__init__.py +0 -0
  182. zrb/util/__init__.py +0 -0
  183. zrb/util/cli/__init__.py +0 -0
  184. zrb/util/cmd/__init__.py +0 -0
  185. zrb/util/codemod/__init__.py +0 -0
  186. zrb/util/string/__init__.py +0 -0
  187. zrb/xcom/__init__.py +0 -0
  188. /zrb/{config/default_prompt/file_extractor_system_prompt.md → llm/prompt/markdown/file_extractor.md} +0 -0
  189. /zrb/{config/default_prompt/repo_extractor_system_prompt.md → llm/prompt/markdown/repo_extractor.md} +0 -0
  190. /zrb/{config/default_prompt/repo_summarizer_system_prompt.md → llm/prompt/markdown/repo_summarizer.md} +0 -0
  191. {zrb-1.21.29.dist-info → zrb-2.0.0a4.dist-info}/WHEEL +0 -0
  192. {zrb-1.21.29.dist-info → zrb-2.0.0a4.dist-info}/entry_points.txt +0 -0
zrb/llm/util/prompt.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import re
3
+
4
+ from zrb.util.file import list_files, read_file
5
+
6
+
7
def expand_prompt(prompt: str) -> str:
    """Expand ``@path`` references in *prompt* into a reference + appendix form.

    Each ``@<path>`` token that resolves to a file or a directory is replaced
    in the text by "`<path>` (see Appendix)", and the file content (or a
    ``list_files(..., depth=2)`` listing for a directory) is added to an
    appendix section appended to the prompt. Tokens that do not resolve, or
    whose content cannot be loaded, are left exactly as typed.

    Details:
    - A reference that ends a sentence ("see @main.py.") is retried without
      its trailing period(s); the period(s) stay in the text.
    - A reference that appears several times gets a single appendix entry.
    - The code fence around the content is widened when the content itself
      contains backtick runs, so the entry cannot be closed early.

    Args:
        prompt: Raw user prompt, possibly containing ``@path`` references.

    Returns:
        The prompt with resolved references rewritten and the appendix
        appended. A prompt in which nothing resolves is returned unchanged.
    """
    if not prompt:
        return prompt

    # "@" followed by typical path characters: word characters, "-", ".",
    # "/", "\" and "~" (home directory).
    pattern = re.compile(r"@(?P<path>[\w~\-\./\\]+)")
    matches = list(pattern.finditer(prompt))
    if not matches:
        return prompt

    appendix_entries: list[str] = []
    loaded_refs: set[str] = set()  # references that already have an entry
    parts: list[str] = []
    last_idx = 0

    for match in matches:
        # Text between the previous reference and this one.
        parts.append(prompt[last_idx : match.start()])
        last_idx = match.end()

        raw_ref = match.group("path")
        # The literal token always wins; the period-stripped form is only a
        # fallback, so "@.." and "@." still resolve as directories.
        candidates = [raw_ref]
        trimmed_ref = raw_ref.rstrip(".")
        if trimmed_ref and trimmed_ref != raw_ref:
            candidates.append(trimmed_ref)

        resolved_ref = None
        for candidate in candidates:
            if candidate in loaded_refs:
                resolved_ref = candidate
                break
            reference = _load_prompt_reference(candidate)
            if reference is not None:
                header, content = reference
                loaded_refs.add(candidate)
                appendix_entries.append(
                    _format_appendix_entry(candidate, header, content)
                )
                resolved_ref = candidate
                break

        if resolved_ref is None:
            # Fallback: leave the original token if unreadable or not found.
            parts.append(match.group(0))
            continue

        parts.append(f"`{resolved_ref}` (see Appendix)")
        # Re-attach any trailing period(s) that were trimmed off the path.
        parts.append(raw_ref[len(resolved_ref) :])

    # Remaining text after the last reference.
    parts.append(prompt[last_idx:])
    new_prompt = "".join(parts)

    if appendix_entries:
        sep = "=" * 20
        header_lines = [
            f"\n\n{sep}APPENDIX: PRE-LOADED CONTEXT {sep}",
            "⚠️ **SYSTEM INSTRUCTION**: The user has attached the following content.",
            "You MUST use this provided content for your analysis.",
            "**DO NOT** consume resources by calling `read_file` or `list_files`",
            "or `run_shell_command` to read these specific paths again.\n",
        ]
        appendix_section = "\n".join(header_lines)
        appendix_section += "\n\n".join(appendix_entries)
        new_prompt += appendix_section

    return new_prompt


def _load_prompt_reference(path_ref: str) -> "tuple[str, str] | None":
    """Return ``(header, content)`` for *path_ref*, or ``None`` if it cannot be loaded."""
    abs_path = os.path.abspath(os.path.expanduser(path_ref))
    try:
        if os.path.isfile(abs_path):
            return f"File Content: `{path_ref}`", read_file(abs_path)
        if os.path.isdir(abs_path):
            listing = "\n".join(list_files(abs_path, depth=2))
            return (
                f"Directory Listing: `{path_ref}`",
                listing or "(Empty directory)",
            )
    except Exception:
        # Deliberate best effort: an unreadable reference must not abort the
        # prompt; the caller keeps the token as the user typed it.
        return None
    return None


def _format_appendix_entry(path_ref: str, header: str, content: str) -> str:
    """Render one appendix entry: heading, usage note, and fenced content."""
    text = f"{content}"
    # A fence must be longer than any backtick run inside the content,
    # otherwise the content would terminate the block prematurely.
    longest_run = max((len(run) for run in re.findall(r"`+", text)), default=0)
    fence = "`" * max(3, longest_run + 1)
    entry_lines = [
        f"### {header}",
        f"> **SYSTEM NOTE:** The content of `{path_ref}` is provided below.",
        "> **DO NOT** use tools like `read_file` or `list_files` to read this path again.",
        "> Use the content provided here directly.\n",
        fence,
        text,
        fence,
    ]
    return "\n".join(entry_lines)
zrb/llm/util/stream_response.py ADDED
@@ -0,0 +1,178 @@
1
+ import json
2
+ from typing import TYPE_CHECKING, Any, Callable
3
+
4
+ from zrb.context.any_context import AnyContext
5
+ from zrb.util.cli.style import stylize_faint
6
+
7
+ if TYPE_CHECKING:
8
+ from pydantic_ai import AgentStreamEvent
9
+
10
+
11
+ def create_event_handler(
12
+ print_event: Callable[..., None],
13
+ indent_level: int = 1,
14
+ show_tool_call_detail: bool = False,
15
+ show_tool_result: bool = False,
16
+ ):
17
+ from pydantic_ai import (
18
+ AgentRunResultEvent,
19
+ FinalResultEvent,
20
+ FunctionToolCallEvent,
21
+ FunctionToolResultEvent,
22
+ PartDeltaEvent,
23
+ PartStartEvent,
24
+ TextPartDelta,
25
+ ThinkingPartDelta,
26
+ ToolCallPartDelta,
27
+ )
28
+
29
+ indentation = indent_level * 2 * " "
30
+ progress_char_list = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
31
+ progress_char_index = 0
32
+ was_tool_call_delta = False
33
+ event_prefix = indentation
34
+
35
+ def fprint(content: str, preserve_leading_newline: bool = False):
36
+ if preserve_leading_newline and content.startswith("\n"):
37
+ return print_event("\n" + content[1:].replace("\n", f"\n{indentation} "))
38
+ return print_event(content.replace("\n", f"\n{indentation} "))
39
+
40
+ async def handle_event(event: "AgentStreamEvent"):
41
+ from pydantic_ai import ToolCallPart
42
+
43
+ nonlocal progress_char_index, was_tool_call_delta, event_prefix
44
+ if isinstance(event, PartStartEvent):
45
+ # Skip ToolCallPart start, we handle it in Deltas/CallEvent
46
+ if isinstance(event.part, ToolCallPart):
47
+ return
48
+ content = _get_event_part_content(event)
49
+ # Use preserve_leading_newline=True because event_prefix contains the correctly indented newline
50
+ fprint(f"{event_prefix}🧠 {content}", preserve_leading_newline=True)
51
+ was_tool_call_delta = False
52
+ elif isinstance(event, PartDeltaEvent):
53
+ if isinstance(event.delta, TextPartDelta):
54
+ # Standard fprint for deltas to ensure wrapping indentation
55
+ fprint(f"{event.delta.content_delta}")
56
+ was_tool_call_delta = False
57
+ elif isinstance(event.delta, ThinkingPartDelta):
58
+ fprint(f"{event.delta.content_delta}")
59
+ was_tool_call_delta = False
60
+ elif isinstance(event.delta, ToolCallPartDelta):
61
+ if show_tool_call_detail:
62
+ fprint(f"{event.delta.args_delta}")
63
+ else:
64
+ progress_char = progress_char_list[progress_char_index]
65
+ if not was_tool_call_delta:
66
+ # Print newline for tool param spinner
67
+ fprint("\n")
68
+
69
+ # Split \r to avoid UI._append_to_output stripping the ANSI start code along with the line
70
+ print_event("\r")
71
+ print_event(
72
+ f"{indentation}🔄 Prepare tool parameters {progress_char}"
73
+ )
74
+ progress_char_index += 1
75
+ if progress_char_index >= len(progress_char_list):
76
+ progress_char_index = 0
77
+ was_tool_call_delta = True
78
+ elif isinstance(event, FunctionToolCallEvent):
79
+ args = _get_truncated_event_part_args(event)
80
+ # Use preserve_leading_newline=True for the block header
81
+ fprint(
82
+ f"{event_prefix}🧰 {event.part.tool_call_id} | {event.part.tool_name} {args}",
83
+ preserve_leading_newline=True,
84
+ )
85
+ was_tool_call_delta = False
86
+ elif isinstance(event, FunctionToolResultEvent):
87
+ if show_tool_result:
88
+ fprint(
89
+ f"{event_prefix}🔠 {event.tool_call_id} | Return {event.result.content}",
90
+ preserve_leading_newline=True,
91
+ )
92
+ else:
93
+ fprint(
94
+ f"{event_prefix}🔠 {event.tool_call_id} Executed",
95
+ preserve_leading_newline=True,
96
+ )
97
+ was_tool_call_delta = False
98
+ elif isinstance(event, AgentRunResultEvent):
99
+ usage = event.result.usage()
100
+ usage_msg = " ".join(
101
+ [
102
+ "💸",
103
+ f"(Requests: {usage.requests} |",
104
+ f"Tool Calls: {usage.tool_calls} |",
105
+ f"Total: {usage.total_tokens})",
106
+ f"Input: {usage.input_tokens} |",
107
+ f"Audio Input: {usage.input_audio_tokens} |",
108
+ f"Output: {usage.output_tokens} |",
109
+ f"Audio Output: {usage.output_audio_tokens} |",
110
+ f"Cache Read: {usage.cache_read_tokens} |",
111
+ f"Cache Write: {usage.cache_write_tokens} |",
112
+ f"Details: {usage.details}",
113
+ ]
114
+ )
115
+ fprint(
116
+ f"{event_prefix}{stylize_faint(usage_msg)}\n",
117
+ preserve_leading_newline=True,
118
+ )
119
+ was_tool_call_delta = False
120
+ elif isinstance(event, FinalResultEvent):
121
+ was_tool_call_delta = False
122
+ event_prefix = f"\n{indentation}"
123
+
124
+ return handle_event
125
+
126
+
127
def create_faint_printer(ctx: AnyContext):
    """Build a print function that writes faint-styled text through *ctx*.

    The returned callable joins its positional values with single spaces,
    applies ``stylize_faint`` to the joined text, and passes the result to
    ``ctx.print`` with ``end=""`` and ``plain=True``.
    """

    def faint_print(*values: object):
        text = " ".join(f"{value}" for value in values)
        ctx.print(stylize_faint(text), end="", plain=True)

    return faint_print
133
+
134
+
135
def _get_truncated_event_part_args(event: "AgentStreamEvent") -> Any:
    """Return the tool-call arguments carried by *event*, shortened for display.

    Looks up ``event.part.args`` and normalises it:

    - a missing ``part`` or ``args`` attribute, ``None``, and blank, ``"null"``
      or ``"{}"`` strings all yield ``{}``;
    - a string holding a JSON object is parsed and its values truncated via
      ``_truncate_kwargs``;
    - a dict is truncated the same way;
    - anything else (a non-JSON string, a JSON value that is not an object,
      another type) is returned unchanged.
    """
    part = getattr(event, "part", None)
    args = getattr(part, "args", None)
    if args is None:
        return {}
    if isinstance(args, str):
        # Providers differ in how they spell "no arguments": empty or
        # whitespace-only text, "null", or "{}" sent as a string.
        if args.strip() in ("", "null", "{}"):
            return {}
        try:
            parsed = json.loads(args)
        except json.JSONDecodeError:
            # Not (yet) valid JSON: show the raw text as received.
            return args
        if isinstance(parsed, dict):
            return _truncate_kwargs(parsed)
        return args
    if isinstance(args, dict):
        return _truncate_kwargs(args)
    return args
159
+
160
+
161
def _truncate_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with every value of *kwargs* passed through ``_truncate_arg``."""
    return {key: _truncate_arg(val) for key, val in kwargs.items()}


def _truncate_arg(arg: Any, length: int = 30) -> Any:
    """Shorten a string argument to at most *length* characters.

    Non-string values and strings that already fit are returned unchanged.
    A longer string is cut and suffixed with `` ...`` so that the result is
    exactly *length* characters long. When *length* is too small to hold the
    suffix, the string is hard-cut to *length* characters instead, so the
    result never exceeds the limit.
    """
    if not isinstance(arg, str) or len(arg) <= length:
        return arg
    suffix = " ..."
    keep = length - len(suffix)
    if keep < 0:
        # No room for the marker; a negative slice index would otherwise
        # count from the end and return more than `length` characters.
        return arg[: max(length, 0)]
    return f"{arg[:keep]}{suffix}"
169
+
170
+
171
def _get_event_part_content(event: "AgentStreamEvent") -> str:
    """Return the text to display for ``event.part.content``.

    Yields ``""`` when the event has no ``part``, when the part has no
    ``content`` attribute (e.g. a tool-call part), or when the content is
    ``None``. Any other content is converted with ``str`` so the declared
    return type holds for non-text payloads as well.
    """
    part = getattr(event, "part", None)
    content = getattr(part, "content", None)
    if content is None:
        # Avoid rendering the literal text "None" in the output stream.
        return ""
    return content if isinstance(content, str) else str(content)
zrb/session/any_session.py CHANGED
@@ -15,9 +15,6 @@ if TYPE_CHECKING:
15
15
  from zrb.task.any_task import AnyTask
16
16
 
17
17
 
18
- TAnySession = TypeVar("TAnySession", bound="AnySession")
19
-
20
-
21
18
  class AnySession(ABC):
22
19
  """Abstract base class for managing task execution and context in a session.
23
20
 
zrb/session/session.py CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Coroutine
6
6
  from zrb.context.any_shared_context import AnySharedContext
7
7
  from zrb.context.context import AnyContext, Context
8
8
  from zrb.group.any_group import AnyGroup
9
- from zrb.session.any_session import AnySession, TAnySession
9
+ from zrb.session.any_session import AnySession
10
10
  from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
11
11
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
12
12
  from zrb.task.any_task import AnyTask
zrb/task/base/context.py CHANGED
@@ -79,24 +79,36 @@ def combine_inputs(
79
79
  input_names.append(task_input.name) # Update names list
80
80
 
81
81
 
82
def combine_envs(
    existing_envs: list[AnyEnv],
    new_envs: list[AnyEnv | None] | AnyEnv | None,
):
    """Append *new_envs* to *existing_envs*, mutating the list in place.

    *new_envs* may be a single env, a list of envs (``None`` items are
    skipped), or ``None`` (nothing is added). Nothing is returned.
    """
    if new_envs is None:
        return
    if isinstance(new_envs, AnyEnv):
        existing_envs.append(new_envs)
        return
    # Anything else is treated as a list of envs.
    existing_envs.extend(env for env in new_envs if env is not None)
99
+
100
+
82
101
  def get_combined_envs(task: "BaseTask") -> list[AnyEnv]:
83
102
  """
84
103
  Aggregates environment variables from the task and its upstreams.
85
104
  """
86
- envs = []
105
+ envs: list[AnyEnv] = []
87
106
  for upstream in task.upstreams:
88
- envs.extend(upstream.envs) # Use extend for list concatenation
89
-
90
- # Access _envs directly as task is BaseTask
91
- task_envs: list[AnyEnv | None] | AnyEnv | None = task._envs
92
- if isinstance(task_envs, AnyEnv):
93
- envs.append(task_envs)
94
- elif isinstance(task_envs, list):
95
- # Filter out None while extending
96
- envs.extend(env for env in task_envs if env is not None)
97
-
98
- # Filter out None values efficiently from the combined list
99
- return [env for env in envs if env is not None]
107
+ combine_envs(envs, upstream.envs)
108
+
109
+ combine_envs(envs, task._envs)
110
+
111
+ return envs
100
112
 
101
113
 
102
114
  def get_combined_inputs(task: "BaseTask") -> list[AnyInput]:
zrb/task/base/execution.py CHANGED
@@ -88,56 +88,61 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
88
88
  run_async(execute_action_with_retry(task, session))
89
89
  )
90
90
 
91
- await asyncio.sleep(readiness_check_delay)
92
-
93
- readiness_check_coros = [
94
- run_async(check.exec_chain(session)) for check in readiness_checks
95
- ]
96
-
97
- # Wait primarily for readiness checks to complete
98
- ctx.log_info("Waiting for readiness checks")
99
- readiness_passed = False
100
91
  try:
101
- # Gather results, but primarily interested in completion/errors
102
- await asyncio.gather(*readiness_check_coros)
103
- # Check if all readiness tasks actually completed successfully
104
- all_readiness_completed = all(
105
- session.get_task_status(check).is_completed for check in readiness_checks
106
- )
107
- if all_readiness_completed:
108
- ctx.log_info("Readiness checks completed successfully")
109
- readiness_passed = True
110
- # Mark task as ready only if checks passed and action didn't fail during checks
111
- if not session.get_task_status(task).is_failed:
112
- ctx.log_info("Marked as ready")
113
- session.get_task_status(task).mark_as_ready()
114
- else:
115
- ctx.log_warning(
116
- "One or more readiness checks did not complete successfully."
117
- )
92
+ await asyncio.sleep(readiness_check_delay)
118
93
 
119
- except Exception as e:
120
- ctx.log_error(f"Readiness check failed with exception: {e}")
121
- # If readiness checks fail with an exception, the task is not ready.
122
- # The action_coro might still be running or have failed.
123
- # execute_action_with_retry handles marking the main task status.
124
-
125
- # Defer the main action coroutine; it will be awaited later if needed
126
- session.defer_action(task, action_coro)
127
-
128
- # Start monitoring only if readiness passed and monitoring is enabled
129
- if readiness_passed and monitor_readiness:
130
- # Import dynamically to avoid circular dependency if monitoring imports execution
131
- from zrb.task.base.monitoring import monitor_task_readiness
94
+ readiness_check_coros = [
95
+ run_async(check.exec_chain(session)) for check in readiness_checks
96
+ ]
132
97
 
133
- monitor_coro = asyncio.create_task(
134
- run_async(monitor_task_readiness(task, session, action_coro))
135
- )
136
- session.defer_monitoring(task, monitor_coro)
98
+ # Wait primarily for readiness checks to complete
99
+ ctx.log_info("Waiting for readiness checks")
100
+ readiness_passed = False
101
+ try:
102
+ # Gather results, but primarily interested in completion/errors
103
+ await asyncio.gather(*readiness_check_coros)
104
+ # Check if all readiness tasks actually completed successfully
105
+ all_readiness_completed = all(
106
+ session.get_task_status(check).is_completed
107
+ for check in readiness_checks
108
+ )
109
+ if all_readiness_completed:
110
+ ctx.log_info("Readiness checks completed successfully")
111
+ readiness_passed = True
112
+ # Mark task as ready only if checks passed and action didn't fail during checks
113
+ if not session.get_task_status(task).is_failed:
114
+ ctx.log_info("Marked as ready")
115
+ session.get_task_status(task).mark_as_ready()
116
+ else:
117
+ ctx.log_warning(
118
+ "One or more readiness checks did not complete successfully."
119
+ )
120
+
121
+ except Exception as e:
122
+ ctx.log_error(f"Readiness check failed with exception: {e}")
123
+ # If readiness checks fail with an exception, the task is not ready.
124
+ # The action_coro might still be running or have failed.
125
+ # execute_action_with_retry handles marking the main task status.
126
+
127
+ # Defer the main action coroutine; it will be awaited later if needed
128
+ session.defer_action(task, action_coro)
129
+
130
+ # Start monitoring only if readiness passed and monitoring is enabled
131
+ if readiness_passed and monitor_readiness:
132
+ # Import dynamically to avoid circular dependency if monitoring imports execution
133
+ from zrb.task.base.monitoring import monitor_task_readiness
134
+
135
+ monitor_coro = asyncio.create_task(
136
+ run_async(monitor_task_readiness(task, session, action_coro))
137
+ )
138
+ session.defer_monitoring(task, monitor_coro)
137
139
 
138
- # The result here is primarily about readiness check completion.
139
- # The actual task result is handled by the deferred action_coro.
140
- return None
140
+ # The result here is primarily about readiness check completion.
141
+ # The actual task result is handled by the deferred action_coro.
142
+ return None
143
+ except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
144
+ action_coro.cancel()
145
+ raise
141
146
 
142
147
 
143
148
  async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> Any:
@@ -178,7 +183,7 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
178
183
  await run_async(execute_successors(task, session))
179
184
  return result
180
185
 
181
- except (asyncio.CancelledError, KeyboardInterrupt):
186
+ except (asyncio.CancelledError, KeyboardInterrupt, GeneratorExit):
182
187
  ctx.log_warning("Task cancelled or interrupted")
183
188
  session.get_task_status(task).mark_as_failed() # Mark as failed on cancel
184
189
  # Do not trigger fallbacks/successors on cancellation
zrb/task/base/lifecycle.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  from typing import Any
3
3
 
4
+ from zrb.context.print_fn import PrintFn
4
5
  from zrb.context.shared_context import SharedContext
5
6
  from zrb.session.any_session import AnySession
6
7
  from zrb.session.session import Session
@@ -12,6 +13,7 @@ from zrb.util.run import run_async
12
13
  async def run_and_cleanup(
13
14
  task: AnyTask,
14
15
  session: AnySession | None = None,
16
+ print_fn: PrintFn | None = None,
15
17
  str_kwargs: dict[str, str] | None = None,
16
18
  kwargs: dict[str, Any] | None = None,
17
19
  ) -> Any:
@@ -21,11 +23,11 @@ async def run_and_cleanup(
21
23
  """
22
24
  # Ensure a session exists
23
25
  if session is None:
24
- session = Session(shared_ctx=SharedContext())
26
+ session = Session(shared_ctx=SharedContext(print_fn=print_fn))
25
27
 
26
28
  # Create the main task execution coroutine
27
29
  main_task_coro = asyncio.create_task(
28
- run_task_async(task, session, str_kwargs, kwargs)
30
+ run_task_async(task, session, print_fn, str_kwargs, kwargs)
29
31
  )
30
32
 
31
33
  try:
@@ -70,6 +72,7 @@ async def run_and_cleanup(
70
72
  async def run_task_async(
71
73
  task: AnyTask,
72
74
  session: AnySession | None = None,
75
+ print_fn: PrintFn | None = None,
73
76
  str_kwargs: dict[str, str] | None = None,
74
77
  kwargs: dict[str, Any] | None = None,
75
78
  ) -> Any:
@@ -78,7 +81,7 @@ async def run_task_async(
78
81
  Sets up the session and initiates the root task execution chain.
79
82
  """
80
83
  if session is None:
81
- session = Session(shared_ctx=SharedContext())
84
+ session = Session(shared_ctx=SharedContext(print_fn=print_fn))
82
85
 
83
86
  # Populate shared context with inputs and environment variables
84
87
  fill_shared_context_inputs(session.shared_ctx, task, str_kwargs, kwargs)
@@ -176,7 +179,7 @@ async def log_session_state(task: AnyTask, session: AnySession):
176
179
  try:
177
180
  while not session.is_terminated:
178
181
  session.state_logger.write(session.as_state_log())
179
- await asyncio.sleep(0.1) # Log interval
182
+ await asyncio.sleep(0) # Log interval
180
183
  # Log one final time after termination signal
181
184
  session.state_logger.write(session.as_state_log())
182
185
  except asyncio.CancelledError: