zrb 1.8.10__py3-none-any.whl → 1.21.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic; the original page linked to further details.

Files changed (147)
  1. zrb/__init__.py +126 -113
  2. zrb/__main__.py +1 -1
  3. zrb/attr/type.py +10 -7
  4. zrb/builtin/__init__.py +2 -50
  5. zrb/builtin/git.py +12 -1
  6. zrb/builtin/group.py +31 -15
  7. zrb/builtin/http.py +7 -8
  8. zrb/builtin/llm/attachment.py +40 -0
  9. zrb/builtin/llm/chat_completion.py +274 -0
  10. zrb/builtin/llm/chat_session.py +152 -85
  11. zrb/builtin/llm/chat_session_cmd.py +288 -0
  12. zrb/builtin/llm/chat_trigger.py +79 -0
  13. zrb/builtin/llm/history.py +7 -9
  14. zrb/builtin/llm/llm_ask.py +221 -98
  15. zrb/builtin/llm/tool/api.py +74 -52
  16. zrb/builtin/llm/tool/cli.py +46 -17
  17. zrb/builtin/llm/tool/code.py +71 -90
  18. zrb/builtin/llm/tool/file.py +301 -241
  19. zrb/builtin/llm/tool/note.py +84 -0
  20. zrb/builtin/llm/tool/rag.py +38 -8
  21. zrb/builtin/llm/tool/sub_agent.py +67 -50
  22. zrb/builtin/llm/tool/web.py +146 -122
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  24. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  25. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  26. zrb/builtin/searxng/config/settings.yml +5671 -0
  27. zrb/builtin/searxng/start.py +21 -0
  28. zrb/builtin/setup/latex/ubuntu.py +1 -0
  29. zrb/builtin/setup/ubuntu.py +1 -1
  30. zrb/builtin/shell/autocomplete/bash.py +4 -3
  31. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  32. zrb/builtin/todo.py +13 -2
  33. zrb/config/config.py +614 -0
  34. zrb/config/default_prompt/file_extractor_system_prompt.md +112 -0
  35. zrb/config/default_prompt/interactive_system_prompt.md +29 -0
  36. zrb/config/default_prompt/persona.md +1 -0
  37. zrb/config/default_prompt/repo_extractor_system_prompt.md +112 -0
  38. zrb/config/default_prompt/repo_summarizer_system_prompt.md +29 -0
  39. zrb/config/default_prompt/summarization_prompt.md +57 -0
  40. zrb/config/default_prompt/system_prompt.md +38 -0
  41. zrb/config/llm_config.py +339 -0
  42. zrb/config/llm_context/config.py +166 -0
  43. zrb/config/llm_context/config_parser.py +40 -0
  44. zrb/config/llm_context/workflow.py +81 -0
  45. zrb/config/llm_rate_limitter.py +190 -0
  46. zrb/{runner → config}/web_auth_config.py +17 -22
  47. zrb/context/any_shared_context.py +17 -1
  48. zrb/context/context.py +16 -2
  49. zrb/context/shared_context.py +18 -8
  50. zrb/group/any_group.py +12 -5
  51. zrb/group/group.py +67 -3
  52. zrb/input/any_input.py +5 -1
  53. zrb/input/base_input.py +18 -6
  54. zrb/input/option_input.py +13 -1
  55. zrb/input/text_input.py +8 -25
  56. zrb/runner/cli.py +25 -23
  57. zrb/runner/common_util.py +24 -19
  58. zrb/runner/web_app.py +3 -3
  59. zrb/runner/web_route/docs_route.py +1 -1
  60. zrb/runner/web_route/error_page/serve_default_404.py +1 -1
  61. zrb/runner/web_route/error_page/show_error_page.py +1 -1
  62. zrb/runner/web_route/home_page/home_page_route.py +2 -2
  63. zrb/runner/web_route/login_api_route.py +1 -1
  64. zrb/runner/web_route/login_page/login_page_route.py +2 -2
  65. zrb/runner/web_route/logout_api_route.py +1 -1
  66. zrb/runner/web_route/logout_page/logout_page_route.py +2 -2
  67. zrb/runner/web_route/node_page/group/show_group_page.py +1 -1
  68. zrb/runner/web_route/node_page/node_page_route.py +1 -1
  69. zrb/runner/web_route/node_page/task/show_task_page.py +1 -1
  70. zrb/runner/web_route/refresh_token_api_route.py +1 -1
  71. zrb/runner/web_route/static/static_route.py +1 -1
  72. zrb/runner/web_route/task_input_api_route.py +6 -6
  73. zrb/runner/web_route/task_session_api_route.py +20 -12
  74. zrb/runner/web_util/cookie.py +1 -1
  75. zrb/runner/web_util/token.py +1 -1
  76. zrb/runner/web_util/user.py +8 -4
  77. zrb/session/any_session.py +24 -17
  78. zrb/session/session.py +50 -25
  79. zrb/session_state_logger/any_session_state_logger.py +9 -4
  80. zrb/session_state_logger/file_session_state_logger.py +16 -6
  81. zrb/session_state_logger/session_state_logger_factory.py +1 -1
  82. zrb/task/any_task.py +30 -9
  83. zrb/task/base/context.py +17 -9
  84. zrb/task/base/execution.py +15 -8
  85. zrb/task/base/lifecycle.py +8 -4
  86. zrb/task/base/monitoring.py +12 -7
  87. zrb/task/base_task.py +69 -5
  88. zrb/task/base_trigger.py +12 -5
  89. zrb/task/cmd_task.py +1 -1
  90. zrb/task/llm/agent.py +154 -161
  91. zrb/task/llm/agent_runner.py +152 -0
  92. zrb/task/llm/config.py +47 -18
  93. zrb/task/llm/conversation_history.py +209 -0
  94. zrb/task/llm/conversation_history_model.py +67 -0
  95. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  96. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  97. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  98. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  99. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  100. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  101. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  102. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  103. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  104. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  105. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  106. zrb/task/llm/error.py +24 -10
  107. zrb/task/llm/file_replacement.py +206 -0
  108. zrb/task/llm/file_tool_model.py +57 -0
  109. zrb/task/llm/history_processor.py +206 -0
  110. zrb/task/llm/history_summarization.py +11 -166
  111. zrb/task/llm/print_node.py +193 -69
  112. zrb/task/llm/prompt.py +242 -45
  113. zrb/task/llm/subagent_conversation_history.py +41 -0
  114. zrb/task/llm/tool_wrapper.py +260 -57
  115. zrb/task/llm/workflow.py +76 -0
  116. zrb/task/llm_task.py +182 -171
  117. zrb/task/make_task.py +2 -3
  118. zrb/task/rsync_task.py +26 -11
  119. zrb/task/scheduler.py +4 -4
  120. zrb/util/attr.py +54 -39
  121. zrb/util/callable.py +23 -0
  122. zrb/util/cli/markdown.py +12 -0
  123. zrb/util/cli/text.py +30 -0
  124. zrb/util/file.py +29 -11
  125. zrb/util/git.py +8 -11
  126. zrb/util/git_diff_model.py +10 -0
  127. zrb/util/git_subtree.py +9 -14
  128. zrb/util/git_subtree_model.py +32 -0
  129. zrb/util/init_path.py +1 -1
  130. zrb/util/markdown.py +62 -0
  131. zrb/util/string/conversion.py +2 -2
  132. zrb/util/todo.py +17 -50
  133. zrb/util/todo_model.py +46 -0
  134. zrb/util/truncate.py +23 -0
  135. zrb/util/yaml.py +204 -0
  136. zrb/xcom/xcom.py +10 -0
  137. zrb-1.21.29.dist-info/METADATA +270 -0
  138. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/RECORD +140 -98
  139. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
  140. zrb/config.py +0 -335
  141. zrb/llm_config.py +0 -411
  142. zrb/llm_rate_limitter.py +0 -125
  143. zrb/task/llm/context.py +0 -102
  144. zrb/task/llm/context_enrichment.py +0 -199
  145. zrb/task/llm/history.py +0 -211
  146. zrb-1.8.10.dist-info/METADATA +0 -264
  147. {zrb-1.8.10.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,19 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
+ import asyncio
3
4
  from abc import ABC, abstractmethod
4
5
  from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
5
6
 
6
7
  from zrb.context.any_context import AnyContext
7
8
  from zrb.group.any_group import AnyGroup
8
- from zrb.session_state_log.session_state_log import SessionStateLog
9
9
  from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
10
10
  from zrb.task_status.task_status import TaskStatus
11
11
 
12
12
  if TYPE_CHECKING:
13
- from zrb.context import any_shared_context
14
- from zrb.task import any_task
13
+ from zrb.context.any_shared_context import AnySharedContext
14
+ from zrb.session_state_log.session_state_log import SessionStateLog
15
+ from zrb.task.any_task import AnyTask
16
+
15
17
 
16
18
  TAnySession = TypeVar("TAnySession", bound="AnySession")
17
19
 
@@ -44,7 +46,7 @@ class AnySession(ABC):
44
46
 
45
47
  @property
46
48
  @abstractmethod
47
- def shared_ctx(self) -> any_shared_context.AnySharedContext:
49
+ def shared_ctx(self) -> "AnySharedContext":
48
50
  """Shared context for this session"""
49
51
  pass
50
52
 
@@ -61,12 +63,13 @@ class AnySession(ABC):
61
63
 
62
64
  @property
63
65
  @abstractmethod
64
- def parent(self) -> TAnySession | None:
66
+ def parent(self) -> "AnySession | None":
65
67
  """Parent session"""
66
68
  pass
67
69
 
70
+ @property
68
71
  @abstractmethod
69
- def task_path(self) -> str:
72
+ def task_path(self) -> list[str]:
70
73
  """Main task's path"""
71
74
  pass
72
75
 
@@ -83,16 +86,16 @@ class AnySession(ABC):
83
86
  pass
84
87
 
85
88
  @abstractmethod
86
- def set_main_task(self, main_task: any_task.AnyTask):
89
+ def set_main_task(self, main_task: "AnyTask"):
87
90
  """Set main task"""
88
91
  pass
89
92
 
90
93
  @abstractmethod
91
- def as_state_log(self) -> SessionStateLog:
94
+ def as_state_log(self) -> "SessionStateLog":
92
95
  pass
93
96
 
94
97
  @abstractmethod
95
- def get_ctx(self, task: any_task.AnyTask) -> AnyContext:
98
+ def get_ctx(self, task: "AnyTask") -> AnyContext:
96
99
  """Retrieves the context for a specific task.
97
100
 
98
101
  Args:
@@ -104,7 +107,9 @@ class AnySession(ABC):
104
107
  pass
105
108
 
106
109
  @abstractmethod
107
- def defer_monitoring(self, task: any_task.AnyTask, coro: Coroutine):
110
+ def defer_monitoring(
111
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
112
+ ):
108
113
  """Defers the execution of a task's monitoring coroutine for later processing.
109
114
 
110
115
  Args:
@@ -114,7 +119,9 @@ class AnySession(ABC):
114
119
  pass
115
120
 
116
121
  @abstractmethod
117
- def defer_action(self, task: any_task.AnyTask, coro: Coroutine):
122
+ def defer_action(
123
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
124
+ ):
118
125
  """Defers the execution of a task's coroutine for later processing.
119
126
 
120
127
  Args:
@@ -124,7 +131,7 @@ class AnySession(ABC):
124
131
  pass
125
132
 
126
133
  @abstractmethod
127
- def defer_coro(self, coro: Coroutine):
134
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
128
135
  """Defers the execution of a coroutine for later processing.
129
136
 
130
137
  Args:
@@ -138,7 +145,7 @@ class AnySession(ABC):
138
145
  pass
139
146
 
140
147
  @abstractmethod
141
- def register_task(self, task: any_task.AnyTask):
148
+ def register_task(self, task: "AnyTask"):
142
149
  """Registers a new task in the session.
143
150
 
144
151
  Args:
@@ -147,7 +154,7 @@ class AnySession(ABC):
147
154
  pass
148
155
 
149
156
  @abstractmethod
150
- def get_root_tasks(self, task: any_task.AnyTask) -> list[any_task.AnyTask]:
157
+ def get_root_tasks(self, task: "AnyTask") -> list["AnyTask"]:
151
158
  """Retrieves the list of root tasks that should be executed first
152
159
  to run the given task.
153
160
 
@@ -160,7 +167,7 @@ class AnySession(ABC):
160
167
  pass
161
168
 
162
169
  @abstractmethod
163
- def get_next_tasks(self, task: any_task.AnyTask) -> list[any_task.AnyTask]:
170
+ def get_next_tasks(self, task: "AnyTask") -> list["AnyTask"]:
164
171
  """Retrieves the list of tasks that should be executed after the given task.
165
172
 
166
173
  Args:
@@ -172,7 +179,7 @@ class AnySession(ABC):
172
179
  pass
173
180
 
174
181
  @abstractmethod
175
- def get_task_status(self, task: any_task.AnyTask) -> TaskStatus:
182
+ def get_task_status(self, task: "AnyTask") -> TaskStatus:
176
183
  """Get tasks' status.
177
184
 
178
185
  Args:
@@ -184,7 +191,7 @@ class AnySession(ABC):
184
191
  pass
185
192
 
186
193
  @abstractmethod
187
- def is_allowed_to_run(self, task: any_task.AnyTask):
194
+ def is_allowed_to_run(self, task: "AnyTask") -> bool:
188
195
  """Determines if the specified task is allowed to run based on its current state.
189
196
 
190
197
  Args:
zrb/session/session.py CHANGED
@@ -1,15 +1,12 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
- from typing import Any, Coroutine
4
+ from typing import TYPE_CHECKING, Any, Coroutine
3
5
 
4
6
  from zrb.context.any_shared_context import AnySharedContext
5
7
  from zrb.context.context import AnyContext, Context
6
8
  from zrb.group.any_group import AnyGroup
7
- from zrb.session.any_session import AnySession
8
- from zrb.session_state_log.session_state_log import (
9
- SessionStateLog,
10
- TaskStatusHistoryStateLog,
11
- TaskStatusStateLog,
12
- )
9
+ from zrb.session.any_session import AnySession, TAnySession
13
10
  from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
14
11
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
15
12
  from zrb.task.any_task import AnyTask
@@ -32,6 +29,9 @@ from zrb.util.group import get_node_path
32
29
  from zrb.util.string.name import get_random_name
33
30
  from zrb.xcom.xcom import Xcom
34
31
 
32
+ if TYPE_CHECKING:
33
+ from zrb.session_state_log.session_state_log import SessionStateLog
34
+
35
35
 
36
36
  class Session(AnySession):
37
37
  def __init__(
@@ -50,10 +50,10 @@ class Session(AnySession):
50
50
  self._context: dict[AnyTask, Context] = {}
51
51
  self._shared_ctx = shared_ctx
52
52
  self._shared_ctx.set_session(self)
53
- self._parent = parent
54
- self._action_coros: dict[AnyTask, asyncio.Task] = {}
55
- self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
56
- self._coros: list[asyncio.Task] = []
53
+ self._parent: AnySession | None = parent
54
+ self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
55
+ self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
56
+ self._coros: list[asyncio.Task[Any]] = []
57
57
  self._colors = [
58
58
  GREEN,
59
59
  YELLOW,
@@ -116,11 +116,13 @@ class Session(AnySession):
116
116
  return self._parent
117
117
 
118
118
  @property
119
- def task_path(self) -> str:
119
+ def task_path(self) -> list[str]:
120
120
  return self._main_task_path
121
121
 
122
122
  @property
123
123
  def final_result(self) -> Any:
124
+ if self._main_task is None:
125
+ return None
124
126
  xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
125
127
  try:
126
128
  return xcom.peek()
@@ -136,10 +138,20 @@ class Session(AnySession):
136
138
  def set_main_task(self, main_task: AnyTask):
137
139
  self.register_task(main_task)
138
140
  self._main_task = main_task
139
- main_task_path = get_node_path(self._root_group, main_task)
141
+ main_task_path = (
142
+ None
143
+ if self._root_group is None
144
+ else get_node_path(self._root_group, main_task)
145
+ )
140
146
  self._main_task_path = [] if main_task_path is None else main_task_path
141
147
 
142
- def as_state_log(self) -> SessionStateLog:
148
+ def as_state_log(self) -> "SessionStateLog":
149
+ from zrb.session_state_log.session_state_log import (
150
+ SessionStateLog,
151
+ TaskStatusHistoryStateLog,
152
+ TaskStatusStateLog,
153
+ )
154
+
143
155
  task_status_log: dict[str, TaskStatusStateLog] = {}
144
156
  log_start_time = ""
145
157
  for task, task_status in self._task_status.items():
@@ -167,7 +179,7 @@ class Session(AnySession):
167
179
  return SessionStateLog(
168
180
  name=self.name,
169
181
  start_time=log_start_time,
170
- main_task_name=self._main_task.name,
182
+ main_task_name="" if self._main_task is None else self._main_task.name,
171
183
  path=self.task_path,
172
184
  final_result=(
173
185
  remove_style(f"{self.final_result}")
@@ -184,16 +196,29 @@ class Session(AnySession):
184
196
  self._register_single_task(task)
185
197
  return self._context[task]
186
198
 
187
- def defer_monitoring(self, task: AnyTask, coro: Coroutine):
199
+ def defer_monitoring(
200
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
201
+ ):
188
202
  self._register_single_task(task)
189
- self._monitoring_coros[task] = coro
203
+ if isinstance(coro, asyncio.Task):
204
+ self._monitoring_coros[task] = coro
205
+ else:
206
+ self._monitoring_coros[task] = asyncio.create_task(coro)
190
207
 
191
- def defer_action(self, task: AnyTask, coro: Coroutine):
208
+ def defer_action(
209
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
210
+ ):
192
211
  self._register_single_task(task)
193
- self._action_coros[task] = coro
212
+ if isinstance(coro, asyncio.Task):
213
+ self._action_coros[task] = coro
214
+ else:
215
+ self._action_coros[task] = asyncio.create_task(coro)
194
216
 
195
- def defer_coro(self, coro: Coroutine):
196
- self._coros.append(coro)
217
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
218
+ if isinstance(coro, asyncio.Task):
219
+ self._coros.append(coro)
220
+ else:
221
+ self._coros.append(asyncio.create_task(coro))
197
222
  self._coros = [
198
223
  existing_coro for existing_coro in self._coros if not existing_coro.done()
199
224
  ]
@@ -242,15 +267,15 @@ class Session(AnySession):
242
267
 
243
268
  def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
244
269
  self._register_single_task(task)
245
- return self._downstreams.get(task)
270
+ return self._downstreams.get(task, [])
246
271
 
247
272
  def get_task_status(self, task: AnyTask) -> TaskStatus:
248
273
  self._register_single_task(task)
249
274
  return self._task_status[task]
250
275
 
251
276
  def _register_single_task(self, task: AnyTask):
252
- if task.name not in self._shared_ctx._xcom:
253
- self._shared_ctx._xcom[task.name] = Xcom([])
277
+ if task.name not in self._shared_ctx.xcom:
278
+ self._shared_ctx.xcom[task.name] = Xcom([])
254
279
  if task not in self._context:
255
280
  self._context[task] = Context(
256
281
  shared_ctx=self._shared_ctx,
@@ -274,7 +299,7 @@ class Session(AnySession):
274
299
  self._color_index = 0
275
300
  return chosen
276
301
 
277
- def _get_icon(self, task: AnyTask) -> int:
302
+ def _get_icon(self, task: AnyTask) -> str:
278
303
  if task.icon is not None:
279
304
  return task.icon
280
305
  chosen = self._icons[self._icon_index]
@@ -1,16 +1,21 @@
1
1
  import datetime
2
2
  from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING
3
4
 
4
- from zrb.session_state_log.session_state_log import SessionStateLog, SessionStateLogList
5
+ if TYPE_CHECKING:
6
+ from zrb.session_state_log.session_state_log import (
7
+ SessionStateLog,
8
+ SessionStateLogList,
9
+ )
5
10
 
6
11
 
7
12
  class AnySessionStateLogger(ABC):
8
13
  @abstractmethod
9
- def write(self, session_log: SessionStateLog):
14
+ def write(self, session_log: "SessionStateLog"):
10
15
  pass
11
16
 
12
17
  @abstractmethod
13
- def read(self, session_name: str) -> SessionStateLog:
18
+ def read(self, session_name: str) -> "SessionStateLog":
14
19
  pass
15
20
 
16
21
  @abstractmethod
@@ -21,5 +26,5 @@ class AnySessionStateLogger(ABC):
21
26
  max_start_time: datetime.datetime,
22
27
  page: int = 0,
23
28
  limit: int = 10,
24
- ) -> SessionStateLogList:
29
+ ) -> "SessionStateLogList":
25
30
  pass
@@ -1,16 +1,22 @@
1
1
  import datetime
2
2
  import os
3
+ from typing import TYPE_CHECKING
3
4
 
4
- from zrb.session_state_log.session_state_log import SessionStateLog, SessionStateLogList
5
5
  from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
6
6
  from zrb.util.file import read_file, write_file
7
7
 
8
+ if TYPE_CHECKING:
9
+ from zrb.session_state_log.session_state_log import (
10
+ SessionStateLog,
11
+ SessionStateLogList,
12
+ )
13
+
8
14
 
9
15
  class FileSessionStateLogger(AnySessionStateLogger):
10
16
  def __init__(self, session_log_dir: str):
11
17
  self._session_log_dir = session_log_dir
12
18
 
13
- def write(self, session_log: SessionStateLog):
19
+ def write(self, session_log: "SessionStateLog"):
14
20
  session_file_path = self._get_session_file_path(session_log.name)
15
21
  session_dir_path = os.path.dirname(session_file_path)
16
22
  if not os.path.isdir(session_dir_path):
@@ -22,7 +28,9 @@ class FileSessionStateLogger(AnySessionStateLogger):
22
28
  timeline_dir_path = self._get_timeline_dir_path(session_log)
23
29
  write_file(os.path.join(timeline_dir_path, session_log.name), "")
24
30
 
25
- def read(self, session_name: str) -> SessionStateLog:
31
+ def read(self, session_name: str) -> "SessionStateLog":
32
+ from zrb.session_state_log.session_state_log import SessionStateLog
33
+
26
34
  session_file_path = self._get_session_file_path(session_name)
27
35
  return SessionStateLog.model_validate_json(read_file(session_file_path))
28
36
 
@@ -33,7 +41,9 @@ class FileSessionStateLogger(AnySessionStateLogger):
33
41
  max_start_time: datetime.datetime,
34
42
  page: int = 0,
35
43
  limit: int = 10,
36
- ) -> SessionStateLogList:
44
+ ) -> "SessionStateLogList":
45
+ from zrb.session_state_log.session_state_log import SessionStateLogList
46
+
37
47
  matching_sessions = []
38
48
  # Traverse the timeline directory and filter sessions
39
49
  timeline_dir = os.path.join(self._session_log_dir, "_timeline", *task_path)
@@ -62,7 +72,7 @@ class FileSessionStateLogger(AnySessionStateLogger):
62
72
  def _get_session_file_path(self, session_name: str) -> str:
63
73
  return os.path.join(self._session_log_dir, f"{session_name}.json")
64
74
 
65
- def _get_timeline_dir_path(self, session_log: SessionStateLog) -> str:
75
+ def _get_timeline_dir_path(self, session_log: "SessionStateLog") -> str:
66
76
  start_time = self._get_start_time(session_log)
67
77
  year = start_time.year
68
78
  month = start_time.month
@@ -80,7 +90,7 @@ class FileSessionStateLogger(AnySessionStateLogger):
80
90
  ]
81
91
  return os.path.join(self._session_log_dir, "_timeline", *paths)
82
92
 
83
- def _get_start_time(self, session_log: SessionStateLog) -> datetime.datetime:
93
+ def _get_start_time(self, session_log: "SessionStateLog") -> datetime.datetime:
84
94
  return datetime.datetime.strptime(
85
95
  session_log.start_time, "%Y-%m-%d %H:%M:%S.%f"
86
96
  )
@@ -1,4 +1,4 @@
1
- from zrb.config import CFG
1
+ from zrb.config.config import CFG
2
2
  from zrb.session_state_logger.file_session_state_logger import FileSessionStateLogger
3
3
 
4
4
  session_state_logger = FileSessionStateLogger(CFG.SESSION_LOG_DIR)
zrb/task/any_task.py CHANGED
@@ -1,14 +1,14 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any
4
+ from typing import TYPE_CHECKING, Any, Callable
5
5
 
6
6
  from zrb.env.any_env import AnyEnv
7
7
  from zrb.input.any_input import AnyInput
8
8
 
9
9
  if TYPE_CHECKING:
10
- from zrb.context import any_context
11
- from zrb.session import session
10
+ from zrb.context.any_context import AnyContext
11
+ from zrb.session.any_session import AnySession
12
12
 
13
13
 
14
14
  class AnyTask(ABC):
@@ -36,6 +36,14 @@ class AnyTask(ABC):
36
36
  the actual implementation for these abstract members.
37
37
  """
38
38
 
39
+ @abstractmethod
40
+ def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
41
+ pass
42
+
43
+ @abstractmethod
44
+ def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
45
+ pass
46
+
39
47
  @property
40
48
  @abstractmethod
41
49
  def name(self) -> str:
@@ -143,18 +151,22 @@ class AnyTask(ABC):
143
151
  pass
144
152
 
145
153
  @abstractmethod
146
- def get_ctx(self, session: session.AnySession) -> any_context.AnyContext:
154
+ def get_ctx(self, session: "AnySession") -> "AnyContext":
147
155
  pass
148
156
 
149
157
  @abstractmethod
150
158
  def run(
151
- self, session: session.AnySession | None = None, str_kwargs: dict[str, str] = {}
159
+ self,
160
+ session: "AnySession | None" = None,
161
+ str_kwargs: dict[str, str] | None = None,
162
+ kwargs: dict[str, Any] | None = None,
152
163
  ) -> Any:
153
164
  """Runs the task synchronously.
154
165
 
155
166
  Args:
156
167
  session (AnySession): The shared session.
157
168
  str_kwargs(dict[str, str]): The input string values.
169
+ kwargs(dict[str, Any]): The input values.
158
170
 
159
171
  Returns:
160
172
  Any: The result of the task execution.
@@ -163,13 +175,17 @@ class AnyTask(ABC):
163
175
 
164
176
  @abstractmethod
165
177
  async def async_run(
166
- self, session: session.AnySession | None = None, str_kwargs: dict[str, str] = {}
178
+ self,
179
+ session: "AnySession | None" = None,
180
+ str_kwargs: dict[str, str] | None = None,
181
+ kwargs: dict[str, Any] | None = None,
167
182
  ) -> Any:
168
183
  """Runs the task asynchronously.
169
184
 
170
185
  Args:
171
186
  session (AnySession): The shared session.
172
187
  str_kwargs(dict[str, str]): The input string values.
188
+ kwargs(dict[str, Any]): The input values.
173
189
 
174
190
  Returns:
175
191
  Any: The result of the task execution.
@@ -177,7 +193,7 @@ class AnyTask(ABC):
177
193
  pass
178
194
 
179
195
  @abstractmethod
180
- async def exec_root_tasks(self, session: session.AnySession):
196
+ async def exec_root_tasks(self, session: "AnySession"):
181
197
  """Execute the root tasks along with the downstreams until the current task
182
198
  is ready.
183
199
 
@@ -187,7 +203,7 @@ class AnyTask(ABC):
187
203
  pass
188
204
 
189
205
  @abstractmethod
190
- async def exec_chain(self, session: session.AnySession):
206
+ async def exec_chain(self, session: "AnySession"):
191
207
  """Execute the task along with the downstreams.
192
208
 
193
209
  Args:
@@ -196,10 +212,15 @@ class AnyTask(ABC):
196
212
  pass
197
213
 
198
214
  @abstractmethod
199
- async def exec(self, session: session.AnySession):
215
+ async def exec(self, session: "AnySession"):
200
216
  """Execute the task (without upstream or downstream).
201
217
 
202
218
  Args:
203
219
  session (AnySession): The shared session.
204
220
  """
205
221
  pass
222
+
223
+ @abstractmethod
224
+ def to_function(self) -> Callable[..., Any]:
225
+ """Turn a task into a function"""
226
+ pass
zrb/task/base/context.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING
2
+ from typing import TYPE_CHECKING, Any
3
3
 
4
4
  from zrb.context.any_context import AnyContext
5
5
  from zrb.context.any_shared_context import AnySharedContext
@@ -26,25 +26,33 @@ def build_task_context(task: AnyTask, session: AnySession) -> AnyContext:
26
26
 
27
27
 
28
28
  def fill_shared_context_inputs(
29
- task: AnyTask, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
29
+ shared_ctx: AnySharedContext,
30
+ task: AnyTask,
31
+ str_kwargs: dict[str, str] | None = None,
32
+ kwargs: dict[str, Any] | None = None,
30
33
  ):
31
34
  """
32
- Populates the shared context with input values provided via kwargs.
35
+ Populates the shared context with input values provided via str_kwargs.
33
36
  """
37
+ str_kwarg_dict = str_kwargs if str_kwargs is not None else {}
38
+ kwarg_dict = kwargs if kwargs is not None else {}
34
39
  for task_input in task.inputs:
35
- if task_input.name not in shared_context.input:
36
- str_value = str_kwargs.get(task_input.name, None)
37
- task_input.update_shared_context(shared_context, str_value)
40
+ if task_input.name not in shared_ctx.input:
41
+ task_input.update_shared_context(
42
+ shared_ctx,
43
+ value=kwarg_dict.get(task_input.name, None),
44
+ str_value=str_kwarg_dict.get(task_input.name, None),
45
+ )
38
46
 
39
47
 
40
- def fill_shared_context_envs(shared_context: AnySharedContext):
48
+ def fill_shared_context_envs(shared_ctx: AnySharedContext):
41
49
  """
42
50
  Injects OS environment variables into the shared context if they don't already exist.
43
51
  """
44
52
  os_env_map = {
45
- key: val for key, val in os.environ.items() if key not in shared_context.env
53
+ key: val for key, val in os.environ.items() if key not in shared_ctx.env
46
54
  }
47
- shared_context.env.update(os_env_map)
55
+ shared_ctx.env.update(os_env_map)
48
56
 
49
57
 
50
58
  def combine_inputs(
@@ -53,7 +53,9 @@ def check_execute_condition(task: "BaseTask", session: AnySession) -> bool:
53
53
  Evaluates the task's execute_condition attribute.
54
54
  """
55
55
  ctx = task.get_ctx(session)
56
- execute_condition_attr = getattr(task, "_execute_condition", True)
56
+ execute_condition_attr = (
57
+ task._execute_condition if task._execute_condition is not None else True
58
+ )
57
59
  return get_bool_attr(ctx, execute_condition_attr, True, auto_render=True)
58
60
 
59
61
 
@@ -63,8 +65,12 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
63
65
  """
64
66
  ctx = task.get_ctx(session)
65
67
  readiness_checks = task.readiness_checks
66
- readiness_check_delay = getattr(task, "_readiness_check_delay", 0.5)
67
- monitor_readiness = getattr(task, "_monitor_readiness", False)
68
+ readiness_check_delay = (
69
+ task._readiness_check_delay if task._readiness_check_delay is not None else 0.5
70
+ )
71
+ monitor_readiness = (
72
+ task._monitor_readiness if task._monitor_readiness is not None else False
73
+ )
68
74
 
69
75
  if not readiness_checks: # Simplified check for empty list
70
76
  ctx.log_info("No readiness checks")
@@ -140,8 +146,8 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
140
146
  handling success (triggering successors) and failure (triggering fallbacks).
141
147
  """
142
148
  ctx = task.get_ctx(session)
143
- retries = getattr(task, "_retries", 2)
144
- retry_period = getattr(task, "_retry_period", 0)
149
+ retries = task._retries if task._retries is not None else 2
150
+ retry_period = task._retry_period if task._retry_period is not None else 0
145
151
  max_attempt = retries + 1
146
152
  ctx.set_max_attempt(max_attempt)
147
153
 
@@ -163,8 +169,9 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
163
169
  session.get_task_status(task).mark_as_completed()
164
170
 
165
171
  # Store result in XCom
166
- task_xcom: Xcom = ctx.xcom.get(task.name)
167
- task_xcom.push(result)
172
+ task_xcom: Xcom | None = ctx.xcom.get(task.name)
173
+ if task_xcom is not None:
174
+ task_xcom.push(result)
168
175
 
169
176
  # Skip fallbacks and execute successors on success
170
177
  skip_fallbacks(task, session)
@@ -201,7 +208,7 @@ async def run_default_action(task: "BaseTask", ctx: AnyContext) -> Any:
201
208
  This is the default implementation called by BaseTask._exec_action.
202
209
  Subclasses like LLMTask override _exec_action with their own logic.
203
210
  """
204
- action = getattr(task, "_action", None)
211
+ action = task._action
205
212
  if action is None:
206
213
  ctx.log_debug("No action defined for this task.")
207
214
  return None
@@ -12,7 +12,8 @@ from zrb.util.run import run_async
12
12
  async def run_and_cleanup(
13
13
  task: AnyTask,
14
14
  session: AnySession | None = None,
15
- str_kwargs: dict[str, str] = {},
15
+ str_kwargs: dict[str, str] | None = None,
16
+ kwargs: dict[str, Any] | None = None,
16
17
  ) -> Any:
17
18
  """
18
19
  Wrapper for async_run that ensures session termination and cleanup of
@@ -23,7 +24,9 @@ async def run_and_cleanup(
23
24
  session = Session(shared_ctx=SharedContext())
24
25
 
25
26
  # Create the main task execution coroutine
26
- main_task_coro = asyncio.create_task(run_task_async(task, session, str_kwargs))
27
+ main_task_coro = asyncio.create_task(
28
+ run_task_async(task, session, str_kwargs, kwargs)
29
+ )
27
30
 
28
31
  try:
29
32
  result = await main_task_coro
@@ -67,7 +70,8 @@ async def run_and_cleanup(
67
70
  async def run_task_async(
68
71
  task: AnyTask,
69
72
  session: AnySession | None = None,
70
- str_kwargs: dict[str, str] = {},
73
+ str_kwargs: dict[str, str] | None = None,
74
+ kwargs: dict[str, Any] | None = None,
71
75
  ) -> Any:
72
76
  """
73
77
  Asynchronous entry point for running a task (`task.async_run()`).
@@ -77,7 +81,7 @@ async def run_task_async(
77
81
  session = Session(shared_ctx=SharedContext())
78
82
 
79
83
  # Populate shared context with inputs and environment variables
80
- fill_shared_context_inputs(task, session.shared_ctx, str_kwargs)
84
+ fill_shared_context_inputs(session.shared_ctx, task, str_kwargs, kwargs)
81
85
  fill_shared_context_envs(session.shared_ctx) # Inject OS env vars
82
86
 
83
87
  # Start the execution chain from the root tasks