zrb 1.13.1__py3-none-any.whl → 1.21.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +8 -8
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/group.py +31 -15
- zrb/builtin/http.py +7 -8
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_session.py +130 -144
- zrb/builtin/llm/chat_session_cmd.py +226 -0
- zrb/builtin/llm/chat_trigger.py +73 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +218 -110
- zrb/builtin/llm/tool/api.py +74 -62
- zrb/builtin/llm/tool/cli.py +35 -16
- zrb/builtin/llm/tool/code.py +49 -47
- zrb/builtin/llm/tool/file.py +262 -251
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +25 -18
- zrb/builtin/llm/tool/sub_agent.py +29 -22
- zrb/builtin/llm/tool/web.py +135 -143
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/setup/latex/ubuntu.py +1 -0
- zrb/builtin/setup/ubuntu.py +1 -1
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +255 -78
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +8 -13
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +129 -24
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +89 -45
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +6 -8
- zrb/group/any_group.py +12 -5
- zrb/group/group.py +67 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_route/task_session_api_route.py +1 -4
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +138 -52
- zrb/task/llm/config.py +45 -13
- zrb/task/llm/conversation_history.py +76 -6
- zrb/task/llm/conversation_history_model.py +0 -168
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_summarization.py +22 -35
- zrb/task/llm/history_summarization_tool.py +24 -0
- zrb/task/llm/print_node.py +182 -63
- zrb/task/llm/prompt.py +213 -153
- zrb/task/llm/tool_wrapper.py +210 -53
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +98 -47
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +50 -40
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +27 -11
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/METADATA +40 -20
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/RECORD +102 -79
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/entry_points.txt +0 -0
zrb/runner/common_util.py
CHANGED
|
@@ -1,31 +1,36 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
1
|
from zrb.context.shared_context import SharedContext
|
|
4
2
|
from zrb.task.any_task import AnyTask
|
|
5
3
|
|
|
6
4
|
|
|
7
|
-
def
|
|
8
|
-
task: AnyTask,
|
|
5
|
+
def get_task_str_kwargs(
|
|
6
|
+
task: AnyTask, str_args: list[str], str_kwargs: dict[str, str], cli_mode: bool
|
|
9
7
|
) -> dict[str, str]:
|
|
10
8
|
arg_index = 0
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
shared_ctx = SharedContext(args=args)
|
|
9
|
+
dummmy_shared_ctx = SharedContext()
|
|
10
|
+
task_str_kwargs = {}
|
|
14
11
|
for task_input in task.inputs:
|
|
12
|
+
task_name = task_input.name
|
|
15
13
|
if task_input.name in str_kwargs:
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
14
|
+
task_str_kwargs[task_input.name] = str_kwargs[task_name]
|
|
15
|
+
# Update dummy shared context for next input default value
|
|
16
|
+
task_input.update_shared_context(
|
|
17
|
+
dummmy_shared_ctx, str_value=str_kwargs[task_name]
|
|
18
|
+
)
|
|
19
|
+
elif arg_index < len(str_args) and task_input.allow_positional_parsing:
|
|
20
|
+
task_str_kwargs[task_name] = str_args[arg_index]
|
|
21
|
+
# Update dummy shared context for next input default value
|
|
22
|
+
task_input.update_shared_context(
|
|
23
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
24
|
+
)
|
|
22
25
|
arg_index += 1
|
|
23
26
|
else:
|
|
24
27
|
if cli_mode and task_input.always_prompt:
|
|
25
|
-
str_value = task_input.prompt_cli_str(
|
|
28
|
+
str_value = task_input.prompt_cli_str(dummmy_shared_ctx)
|
|
26
29
|
else:
|
|
27
|
-
str_value = task_input.get_default_str(
|
|
28
|
-
|
|
29
|
-
# Update shared context for next input default value
|
|
30
|
-
task_input.update_shared_context(
|
|
31
|
-
|
|
30
|
+
str_value = task_input.get_default_str(dummmy_shared_ctx)
|
|
31
|
+
task_str_kwargs[task_name] = str_value
|
|
32
|
+
# Update dummy shared context for next input default value
|
|
33
|
+
task_input.update_shared_context(
|
|
34
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
35
|
+
)
|
|
36
|
+
return task_str_kwargs
|
|
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
|
|
3
3
|
|
|
4
4
|
from zrb.config.web_auth_config import WebAuthConfig
|
|
5
5
|
from zrb.group.any_group import AnyGroup
|
|
6
|
-
from zrb.runner.common_util import
|
|
6
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
7
7
|
from zrb.runner.web_util.user import get_user_from_request
|
|
8
8
|
from zrb.task.any_task import AnyTask
|
|
9
9
|
from zrb.util.group import NodeNotFoundError, extract_node_from_args
|
|
@@ -39,9 +39,9 @@ def serve_task_input_api(
|
|
|
39
39
|
if isinstance(task, AnyTask):
|
|
40
40
|
if not user.can_access_task(task):
|
|
41
41
|
return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
task=task,
|
|
42
|
+
str_kwargs = json.loads(query)
|
|
43
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
44
|
+
task=task, str_args=[], str_kwargs=str_kwargs, cli_mode=False
|
|
45
45
|
)
|
|
46
|
-
return
|
|
46
|
+
return task_str_kwargs
|
|
47
47
|
return JSONResponse(content={"detail": "Not found"}, status_code=404)
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import os
|
|
3
2
|
from datetime import datetime, timedelta
|
|
4
3
|
from typing import TYPE_CHECKING, Any
|
|
5
4
|
|
|
@@ -57,9 +56,7 @@ def serve_task_session_api(
|
|
|
57
56
|
return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
|
|
58
57
|
session_name = residual_args[0] if residual_args else None
|
|
59
58
|
if not session_name:
|
|
60
|
-
shared_ctx = SharedContext(
|
|
61
|
-
env={**dict(os.environ), "_ZRB_IS_WEB_MODE": "1"}
|
|
62
|
-
)
|
|
59
|
+
shared_ctx = SharedContext(is_web_mode=True)
|
|
63
60
|
session = Session(shared_ctx=shared_ctx, root_group=root_group)
|
|
64
61
|
coro = asyncio.create_task(task.async_run(session, str_kwargs=inputs))
|
|
65
62
|
coroutines.append(coro)
|
zrb/runner/web_util/user.py
CHANGED
|
@@ -19,7 +19,7 @@ def get_user_by_credentials(
|
|
|
19
19
|
|
|
20
20
|
async def get_user_from_request(
|
|
21
21
|
web_auth_config: WebAuthConfig, request: "Request"
|
|
22
|
-
) -> User
|
|
22
|
+
) -> User:
|
|
23
23
|
from fastapi.security import OAuth2PasswordBearer
|
|
24
24
|
|
|
25
25
|
if not web_auth_config.enable_auth:
|
|
@@ -45,7 +45,11 @@ def _get_user_from_cookie(
|
|
|
45
45
|
return None
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
def _get_user_from_token(
|
|
48
|
+
def _get_user_from_token(
|
|
49
|
+
web_auth_config: WebAuthConfig, token: str | None
|
|
50
|
+
) -> User | None:
|
|
51
|
+
if token is None:
|
|
52
|
+
return None
|
|
49
53
|
try:
|
|
50
54
|
from jose import jwt
|
|
51
55
|
|
|
@@ -54,7 +58,7 @@ def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | N
|
|
|
54
58
|
web_auth_config.secret_key,
|
|
55
59
|
options={"require_sub": True, "require_exp": True},
|
|
56
60
|
)
|
|
57
|
-
username: str = payload.get("sub")
|
|
61
|
+
username: str | None = payload.get("sub")
|
|
58
62
|
if username is None:
|
|
59
63
|
return None
|
|
60
64
|
user = web_auth_config.find_user_by_username(username)
|
zrb/session/any_session.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
from abc import ABC, abstractmethod
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
|
|
5
6
|
|
|
@@ -62,12 +63,13 @@ class AnySession(ABC):
|
|
|
62
63
|
|
|
63
64
|
@property
|
|
64
65
|
@abstractmethod
|
|
65
|
-
def parent(self) ->
|
|
66
|
+
def parent(self) -> "AnySession | None":
|
|
66
67
|
"""Parent session"""
|
|
67
68
|
pass
|
|
68
69
|
|
|
70
|
+
@property
|
|
69
71
|
@abstractmethod
|
|
70
|
-
def task_path(self) -> str:
|
|
72
|
+
def task_path(self) -> list[str]:
|
|
71
73
|
"""Main task's path"""
|
|
72
74
|
pass
|
|
73
75
|
|
|
@@ -105,7 +107,9 @@ class AnySession(ABC):
|
|
|
105
107
|
pass
|
|
106
108
|
|
|
107
109
|
@abstractmethod
|
|
108
|
-
def defer_monitoring(
|
|
110
|
+
def defer_monitoring(
|
|
111
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
112
|
+
):
|
|
109
113
|
"""Defers the execution of a task's monitoring coroutine for later processing.
|
|
110
114
|
|
|
111
115
|
Args:
|
|
@@ -115,7 +119,9 @@ class AnySession(ABC):
|
|
|
115
119
|
pass
|
|
116
120
|
|
|
117
121
|
@abstractmethod
|
|
118
|
-
def defer_action(
|
|
122
|
+
def defer_action(
|
|
123
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
124
|
+
):
|
|
119
125
|
"""Defers the execution of a task's coroutine for later processing.
|
|
120
126
|
|
|
121
127
|
Args:
|
|
@@ -125,7 +131,7 @@ class AnySession(ABC):
|
|
|
125
131
|
pass
|
|
126
132
|
|
|
127
133
|
@abstractmethod
|
|
128
|
-
def defer_coro(self, coro: Coroutine):
|
|
134
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
129
135
|
"""Defers the execution of a coroutine for later processing.
|
|
130
136
|
|
|
131
137
|
Args:
|
|
@@ -185,7 +191,7 @@ class AnySession(ABC):
|
|
|
185
191
|
pass
|
|
186
192
|
|
|
187
193
|
@abstractmethod
|
|
188
|
-
def is_allowed_to_run(self, task: "AnyTask"):
|
|
194
|
+
def is_allowed_to_run(self, task: "AnyTask") -> bool:
|
|
189
195
|
"""Determines if the specified task is allowed to run based on its current state.
|
|
190
196
|
|
|
191
197
|
Args:
|
zrb/session/session.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import asyncio
|
|
2
4
|
from typing import TYPE_CHECKING, Any, Coroutine
|
|
3
5
|
|
|
4
6
|
from zrb.context.any_shared_context import AnySharedContext
|
|
5
7
|
from zrb.context.context import AnyContext, Context
|
|
6
8
|
from zrb.group.any_group import AnyGroup
|
|
7
|
-
from zrb.session.any_session import AnySession
|
|
9
|
+
from zrb.session.any_session import AnySession, TAnySession
|
|
8
10
|
from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
|
|
9
11
|
from zrb.session_state_logger.session_state_logger_factory import session_state_logger
|
|
10
12
|
from zrb.task.any_task import AnyTask
|
|
@@ -48,10 +50,10 @@ class Session(AnySession):
|
|
|
48
50
|
self._context: dict[AnyTask, Context] = {}
|
|
49
51
|
self._shared_ctx = shared_ctx
|
|
50
52
|
self._shared_ctx.set_session(self)
|
|
51
|
-
self._parent = parent
|
|
52
|
-
self._action_coros: dict[AnyTask, asyncio.Task] = {}
|
|
53
|
-
self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
|
|
54
|
-
self._coros: list[asyncio.Task] = []
|
|
53
|
+
self._parent: AnySession | None = parent
|
|
54
|
+
self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
55
|
+
self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
56
|
+
self._coros: list[asyncio.Task[Any]] = []
|
|
55
57
|
self._colors = [
|
|
56
58
|
GREEN,
|
|
57
59
|
YELLOW,
|
|
@@ -114,11 +116,13 @@ class Session(AnySession):
|
|
|
114
116
|
return self._parent
|
|
115
117
|
|
|
116
118
|
@property
|
|
117
|
-
def task_path(self) -> str:
|
|
119
|
+
def task_path(self) -> list[str]:
|
|
118
120
|
return self._main_task_path
|
|
119
121
|
|
|
120
122
|
@property
|
|
121
123
|
def final_result(self) -> Any:
|
|
124
|
+
if self._main_task is None:
|
|
125
|
+
return None
|
|
122
126
|
xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
|
|
123
127
|
try:
|
|
124
128
|
return xcom.peek()
|
|
@@ -134,7 +138,11 @@ class Session(AnySession):
|
|
|
134
138
|
def set_main_task(self, main_task: AnyTask):
|
|
135
139
|
self.register_task(main_task)
|
|
136
140
|
self._main_task = main_task
|
|
137
|
-
main_task_path =
|
|
141
|
+
main_task_path = (
|
|
142
|
+
None
|
|
143
|
+
if self._root_group is None
|
|
144
|
+
else get_node_path(self._root_group, main_task)
|
|
145
|
+
)
|
|
138
146
|
self._main_task_path = [] if main_task_path is None else main_task_path
|
|
139
147
|
|
|
140
148
|
def as_state_log(self) -> "SessionStateLog":
|
|
@@ -171,7 +179,7 @@ class Session(AnySession):
|
|
|
171
179
|
return SessionStateLog(
|
|
172
180
|
name=self.name,
|
|
173
181
|
start_time=log_start_time,
|
|
174
|
-
main_task_name=self._main_task.name,
|
|
182
|
+
main_task_name="" if self._main_task is None else self._main_task.name,
|
|
175
183
|
path=self.task_path,
|
|
176
184
|
final_result=(
|
|
177
185
|
remove_style(f"{self.final_result}")
|
|
@@ -188,16 +196,29 @@ class Session(AnySession):
|
|
|
188
196
|
self._register_single_task(task)
|
|
189
197
|
return self._context[task]
|
|
190
198
|
|
|
191
|
-
def defer_monitoring(
|
|
199
|
+
def defer_monitoring(
|
|
200
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
201
|
+
):
|
|
192
202
|
self._register_single_task(task)
|
|
193
|
-
|
|
203
|
+
if isinstance(coro, asyncio.Task):
|
|
204
|
+
self._monitoring_coros[task] = coro
|
|
205
|
+
else:
|
|
206
|
+
self._monitoring_coros[task] = asyncio.create_task(coro)
|
|
194
207
|
|
|
195
|
-
def defer_action(
|
|
208
|
+
def defer_action(
|
|
209
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
210
|
+
):
|
|
196
211
|
self._register_single_task(task)
|
|
197
|
-
|
|
212
|
+
if isinstance(coro, asyncio.Task):
|
|
213
|
+
self._action_coros[task] = coro
|
|
214
|
+
else:
|
|
215
|
+
self._action_coros[task] = asyncio.create_task(coro)
|
|
198
216
|
|
|
199
|
-
def defer_coro(self, coro: Coroutine):
|
|
200
|
-
|
|
217
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
218
|
+
if isinstance(coro, asyncio.Task):
|
|
219
|
+
self._coros.append(coro)
|
|
220
|
+
else:
|
|
221
|
+
self._coros.append(asyncio.create_task(coro))
|
|
201
222
|
self._coros = [
|
|
202
223
|
existing_coro for existing_coro in self._coros if not existing_coro.done()
|
|
203
224
|
]
|
|
@@ -246,15 +267,15 @@ class Session(AnySession):
|
|
|
246
267
|
|
|
247
268
|
def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
|
|
248
269
|
self._register_single_task(task)
|
|
249
|
-
return self._downstreams.get(task)
|
|
270
|
+
return self._downstreams.get(task, [])
|
|
250
271
|
|
|
251
272
|
def get_task_status(self, task: AnyTask) -> TaskStatus:
|
|
252
273
|
self._register_single_task(task)
|
|
253
274
|
return self._task_status[task]
|
|
254
275
|
|
|
255
276
|
def _register_single_task(self, task: AnyTask):
|
|
256
|
-
if task.name not in self._shared_ctx.
|
|
257
|
-
self._shared_ctx.
|
|
277
|
+
if task.name not in self._shared_ctx.xcom:
|
|
278
|
+
self._shared_ctx.xcom[task.name] = Xcom([])
|
|
258
279
|
if task not in self._context:
|
|
259
280
|
self._context[task] = Context(
|
|
260
281
|
shared_ctx=self._shared_ctx,
|
|
@@ -278,7 +299,7 @@ class Session(AnySession):
|
|
|
278
299
|
self._color_index = 0
|
|
279
300
|
return chosen
|
|
280
301
|
|
|
281
|
-
def _get_icon(self, task: AnyTask) ->
|
|
302
|
+
def _get_icon(self, task: AnyTask) -> str:
|
|
282
303
|
if task.icon is not None:
|
|
283
304
|
return task.icon
|
|
284
305
|
chosen = self._icons[self._icon_index]
|
zrb/task/any_task.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, Any
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
5
5
|
|
|
6
6
|
from zrb.env.any_env import AnyEnv
|
|
7
7
|
from zrb.input.any_input import AnyInput
|
|
@@ -36,6 +36,14 @@ class AnyTask(ABC):
|
|
|
36
36
|
the actual implementation for these abstract members.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
|
|
45
|
+
pass
|
|
46
|
+
|
|
39
47
|
@property
|
|
40
48
|
@abstractmethod
|
|
41
49
|
def name(self) -> str:
|
|
@@ -148,13 +156,17 @@ class AnyTask(ABC):
|
|
|
148
156
|
|
|
149
157
|
@abstractmethod
|
|
150
158
|
def run(
|
|
151
|
-
self,
|
|
159
|
+
self,
|
|
160
|
+
session: "AnySession | None" = None,
|
|
161
|
+
str_kwargs: dict[str, str] | None = None,
|
|
162
|
+
kwargs: dict[str, Any] | None = None,
|
|
152
163
|
) -> Any:
|
|
153
164
|
"""Runs the task synchronously.
|
|
154
165
|
|
|
155
166
|
Args:
|
|
156
167
|
session (AnySession): The shared session.
|
|
157
168
|
str_kwargs(dict[str, str]): The input string values.
|
|
169
|
+
kwargs(dict[str, Any]): The input values.
|
|
158
170
|
|
|
159
171
|
Returns:
|
|
160
172
|
Any: The result of the task execution.
|
|
@@ -163,13 +175,17 @@ class AnyTask(ABC):
|
|
|
163
175
|
|
|
164
176
|
@abstractmethod
|
|
165
177
|
async def async_run(
|
|
166
|
-
self,
|
|
178
|
+
self,
|
|
179
|
+
session: "AnySession | None" = None,
|
|
180
|
+
str_kwargs: dict[str, str] | None = None,
|
|
181
|
+
kwargs: dict[str, Any] | None = None,
|
|
167
182
|
) -> Any:
|
|
168
183
|
"""Runs the task asynchronously.
|
|
169
184
|
|
|
170
185
|
Args:
|
|
171
186
|
session (AnySession): The shared session.
|
|
172
187
|
str_kwargs(dict[str, str]): The input string values.
|
|
188
|
+
kwargs(dict[str, Any]): The input values.
|
|
173
189
|
|
|
174
190
|
Returns:
|
|
175
191
|
Any: The result of the task execution.
|
|
@@ -203,3 +219,8 @@ class AnyTask(ABC):
|
|
|
203
219
|
session (AnySession): The shared session.
|
|
204
220
|
"""
|
|
205
221
|
pass
|
|
222
|
+
|
|
223
|
+
@abstractmethod
|
|
224
|
+
def to_function(self) -> Callable[..., Any]:
|
|
225
|
+
"""Turn a task into a function"""
|
|
226
|
+
pass
|
zrb/task/base/context.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import TYPE_CHECKING
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
3
|
|
|
4
4
|
from zrb.context.any_context import AnyContext
|
|
5
5
|
from zrb.context.any_shared_context import AnySharedContext
|
|
@@ -26,25 +26,33 @@ def build_task_context(task: AnyTask, session: AnySession) -> AnyContext:
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def fill_shared_context_inputs(
|
|
29
|
-
|
|
29
|
+
shared_ctx: AnySharedContext,
|
|
30
|
+
task: AnyTask,
|
|
31
|
+
str_kwargs: dict[str, str] | None = None,
|
|
32
|
+
kwargs: dict[str, Any] | None = None,
|
|
30
33
|
):
|
|
31
34
|
"""
|
|
32
|
-
Populates the shared context with input values provided via
|
|
35
|
+
Populates the shared context with input values provided via str_kwargs.
|
|
33
36
|
"""
|
|
37
|
+
str_kwarg_dict = str_kwargs if str_kwargs is not None else {}
|
|
38
|
+
kwarg_dict = kwargs if kwargs is not None else {}
|
|
34
39
|
for task_input in task.inputs:
|
|
35
|
-
if task_input.name not in
|
|
36
|
-
|
|
37
|
-
|
|
40
|
+
if task_input.name not in shared_ctx.input:
|
|
41
|
+
task_input.update_shared_context(
|
|
42
|
+
shared_ctx,
|
|
43
|
+
value=kwarg_dict.get(task_input.name, None),
|
|
44
|
+
str_value=str_kwarg_dict.get(task_input.name, None),
|
|
45
|
+
)
|
|
38
46
|
|
|
39
47
|
|
|
40
|
-
def fill_shared_context_envs(
|
|
48
|
+
def fill_shared_context_envs(shared_ctx: AnySharedContext):
|
|
41
49
|
"""
|
|
42
50
|
Injects OS environment variables into the shared context if they don't already exist.
|
|
43
51
|
"""
|
|
44
52
|
os_env_map = {
|
|
45
|
-
key: val for key, val in os.environ.items() if key not in
|
|
53
|
+
key: val for key, val in os.environ.items() if key not in shared_ctx.env
|
|
46
54
|
}
|
|
47
|
-
|
|
55
|
+
shared_ctx.env.update(os_env_map)
|
|
48
56
|
|
|
49
57
|
|
|
50
58
|
def combine_inputs(
|
zrb/task/base/execution.py
CHANGED
|
@@ -53,7 +53,9 @@ def check_execute_condition(task: "BaseTask", session: AnySession) -> bool:
|
|
|
53
53
|
Evaluates the task's execute_condition attribute.
|
|
54
54
|
"""
|
|
55
55
|
ctx = task.get_ctx(session)
|
|
56
|
-
execute_condition_attr =
|
|
56
|
+
execute_condition_attr = (
|
|
57
|
+
task._execute_condition if task._execute_condition is not None else True
|
|
58
|
+
)
|
|
57
59
|
return get_bool_attr(ctx, execute_condition_attr, True, auto_render=True)
|
|
58
60
|
|
|
59
61
|
|
|
@@ -63,8 +65,12 @@ async def execute_action_until_ready(task: "BaseTask", session: AnySession):
|
|
|
63
65
|
"""
|
|
64
66
|
ctx = task.get_ctx(session)
|
|
65
67
|
readiness_checks = task.readiness_checks
|
|
66
|
-
readiness_check_delay =
|
|
67
|
-
|
|
68
|
+
readiness_check_delay = (
|
|
69
|
+
task._readiness_check_delay if task._readiness_check_delay is not None else 0.5
|
|
70
|
+
)
|
|
71
|
+
monitor_readiness = (
|
|
72
|
+
task._monitor_readiness if task._monitor_readiness is not None else False
|
|
73
|
+
)
|
|
68
74
|
|
|
69
75
|
if not readiness_checks: # Simplified check for empty list
|
|
70
76
|
ctx.log_info("No readiness checks")
|
|
@@ -140,8 +146,8 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
|
|
|
140
146
|
handling success (triggering successors) and failure (triggering fallbacks).
|
|
141
147
|
"""
|
|
142
148
|
ctx = task.get_ctx(session)
|
|
143
|
-
retries =
|
|
144
|
-
retry_period =
|
|
149
|
+
retries = task._retries if task._retries is not None else 2
|
|
150
|
+
retry_period = task._retry_period if task._retry_period is not None else 0
|
|
145
151
|
max_attempt = retries + 1
|
|
146
152
|
ctx.set_max_attempt(max_attempt)
|
|
147
153
|
|
|
@@ -163,8 +169,9 @@ async def execute_action_with_retry(task: "BaseTask", session: AnySession) -> An
|
|
|
163
169
|
session.get_task_status(task).mark_as_completed()
|
|
164
170
|
|
|
165
171
|
# Store result in XCom
|
|
166
|
-
task_xcom: Xcom = ctx.xcom.get(task.name)
|
|
167
|
-
task_xcom
|
|
172
|
+
task_xcom: Xcom | None = ctx.xcom.get(task.name)
|
|
173
|
+
if task_xcom is not None:
|
|
174
|
+
task_xcom.push(result)
|
|
168
175
|
|
|
169
176
|
# Skip fallbacks and execute successors on success
|
|
170
177
|
skip_fallbacks(task, session)
|
|
@@ -201,7 +208,7 @@ async def run_default_action(task: "BaseTask", ctx: AnyContext) -> Any:
|
|
|
201
208
|
This is the default implementation called by BaseTask._exec_action.
|
|
202
209
|
Subclasses like LLMTask override _exec_action with their own logic.
|
|
203
210
|
"""
|
|
204
|
-
action =
|
|
211
|
+
action = task._action
|
|
205
212
|
if action is None:
|
|
206
213
|
ctx.log_debug("No action defined for this task.")
|
|
207
214
|
return None
|
zrb/task/base/lifecycle.py
CHANGED
|
@@ -12,7 +12,8 @@ from zrb.util.run import run_async
|
|
|
12
12
|
async def run_and_cleanup(
|
|
13
13
|
task: AnyTask,
|
|
14
14
|
session: AnySession | None = None,
|
|
15
|
-
str_kwargs: dict[str, str] =
|
|
15
|
+
str_kwargs: dict[str, str] | None = None,
|
|
16
|
+
kwargs: dict[str, Any] | None = None,
|
|
16
17
|
) -> Any:
|
|
17
18
|
"""
|
|
18
19
|
Wrapper for async_run that ensures session termination and cleanup of
|
|
@@ -23,7 +24,9 @@ async def run_and_cleanup(
|
|
|
23
24
|
session = Session(shared_ctx=SharedContext())
|
|
24
25
|
|
|
25
26
|
# Create the main task execution coroutine
|
|
26
|
-
main_task_coro = asyncio.create_task(
|
|
27
|
+
main_task_coro = asyncio.create_task(
|
|
28
|
+
run_task_async(task, session, str_kwargs, kwargs)
|
|
29
|
+
)
|
|
27
30
|
|
|
28
31
|
try:
|
|
29
32
|
result = await main_task_coro
|
|
@@ -67,7 +70,8 @@ async def run_and_cleanup(
|
|
|
67
70
|
async def run_task_async(
|
|
68
71
|
task: AnyTask,
|
|
69
72
|
session: AnySession | None = None,
|
|
70
|
-
str_kwargs: dict[str, str] =
|
|
73
|
+
str_kwargs: dict[str, str] | None = None,
|
|
74
|
+
kwargs: dict[str, Any] | None = None,
|
|
71
75
|
) -> Any:
|
|
72
76
|
"""
|
|
73
77
|
Asynchronous entry point for running a task (`task.async_run()`).
|
|
@@ -77,7 +81,7 @@ async def run_task_async(
|
|
|
77
81
|
session = Session(shared_ctx=SharedContext())
|
|
78
82
|
|
|
79
83
|
# Populate shared context with inputs and environment variables
|
|
80
|
-
fill_shared_context_inputs(
|
|
84
|
+
fill_shared_context_inputs(session.shared_ctx, task, str_kwargs, kwargs)
|
|
81
85
|
fill_shared_context_envs(session.shared_ctx) # Inject OS env vars
|
|
82
86
|
|
|
83
87
|
# Start the execution chain from the root tasks
|
zrb/task/base/monitoring.py
CHANGED
|
@@ -17,9 +17,13 @@ async def monitor_task_readiness(
|
|
|
17
17
|
"""
|
|
18
18
|
ctx = task.get_ctx(session)
|
|
19
19
|
readiness_checks = task.readiness_checks
|
|
20
|
-
readiness_check_period =
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
readiness_check_period = (
|
|
21
|
+
task._readiness_check_period if task._readiness_check_period else 5.0
|
|
22
|
+
)
|
|
23
|
+
readiness_failure_threshold = (
|
|
24
|
+
task._readiness_failure_threshold if task._readiness_failure_threshold else 1
|
|
25
|
+
)
|
|
26
|
+
readiness_timeout = task._readiness_timeout if task._readiness_timeout else 60
|
|
23
27
|
|
|
24
28
|
if not readiness_checks:
|
|
25
29
|
ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
|
|
@@ -41,8 +45,9 @@ async def monitor_task_readiness(
|
|
|
41
45
|
session.get_task_status(check).reset_history()
|
|
42
46
|
session.get_task_status(check).reset()
|
|
43
47
|
# Clear previous XCom data for the check task if needed
|
|
44
|
-
check_xcom: Xcom = ctx.xcom.get(check.name)
|
|
45
|
-
check_xcom
|
|
48
|
+
check_xcom: Xcom | None = ctx.xcom.get(check.name)
|
|
49
|
+
if check_xcom is not None:
|
|
50
|
+
check_xcom.clear()
|
|
46
51
|
|
|
47
52
|
readiness_check_coros = [
|
|
48
53
|
run_async(check.exec_chain(session)) for check in readiness_checks
|
|
@@ -77,7 +82,7 @@ async def monitor_task_readiness(
|
|
|
77
82
|
)
|
|
78
83
|
# Ensure check tasks are marked as failed on timeout
|
|
79
84
|
for check in readiness_checks:
|
|
80
|
-
if not session.get_task_status(check).
|
|
85
|
+
if not session.get_task_status(check).is_ready:
|
|
81
86
|
session.get_task_status(check).mark_as_failed()
|
|
82
87
|
|
|
83
88
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
|
@@ -92,7 +97,7 @@ async def monitor_task_readiness(
|
|
|
92
97
|
)
|
|
93
98
|
# Mark checks as failed
|
|
94
99
|
for check in readiness_checks:
|
|
95
|
-
if not session.get_task_status(check).
|
|
100
|
+
if not session.get_task_status(check).is_ready:
|
|
96
101
|
session.get_task_status(check).mark_as_failed()
|
|
97
102
|
|
|
98
103
|
# If failure threshold is reached
|