zrb 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__init__.py +48 -39
- zrb/__main__.py +3 -3
- zrb/attr/type.py +2 -1
- zrb/builtin/__init__.py +40 -2
- zrb/builtin/base64.py +32 -0
- zrb/builtin/git.py +156 -0
- zrb/builtin/git_subtree.py +88 -0
- zrb/builtin/group.py +34 -0
- zrb/builtin/llm.py +31 -0
- zrb/builtin/md5.py +34 -0
- zrb/builtin/project/__init__.py +0 -0
- zrb/builtin/project/add/__init__.py +0 -0
- zrb/builtin/project/add/fastapp.py +72 -0
- zrb/builtin/project/add/fastapp_template/.gitignore +4 -0
- zrb/builtin/project/add/fastapp_template/README.md +7 -0
- zrb/builtin/project/add/fastapp_template/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/_zrb/config.py +17 -0
- zrb/builtin/project/add/fastapp_template/_zrb/group.py +16 -0
- zrb/builtin/project/add/fastapp_template/_zrb/helper.py +97 -0
- zrb/builtin/project/add/fastapp_template/_zrb/main.py +132 -0
- zrb/builtin/project/add/fastapp_template/_zrb/venv_task.py +22 -0
- zrb/builtin/project/add/fastapp_template/common/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/common/app.py +18 -0
- zrb/builtin/project/add/fastapp_template/common/db_engine.py +5 -0
- zrb/builtin/project/add/fastapp_template/common/db_repository.py +134 -0
- zrb/builtin/project/add/fastapp_template/common/error.py +8 -0
- zrb/builtin/project/add/fastapp_template/common/schema.py +5 -0
- zrb/builtin/project/add/fastapp_template/common/usecase.py +232 -0
- zrb/builtin/project/add/fastapp_template/config.py +29 -0
- zrb/builtin/project/add/fastapp_template/main.py +7 -0
- zrb/builtin/project/add/fastapp_template/migrate.py +3 -0
- zrb/builtin/project/add/fastapp_template/module/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/module/auth/alembic.ini +117 -0
- zrb/builtin/project/add/fastapp_template/module/auth/client/api_client.py +7 -0
- zrb/builtin/project/add/fastapp_template/module/auth/client/base_client.py +27 -0
- zrb/builtin/project/add/fastapp_template/module/auth/client/direct_client.py +6 -0
- zrb/builtin/project/add/fastapp_template/module/auth/client/factory.py +9 -0
- zrb/builtin/project/add/fastapp_template/module/auth/migration/README +1 -0
- zrb/builtin/project/add/fastapp_template/module/auth/migration/env.py +108 -0
- zrb/builtin/project/add/fastapp_template/module/auth/migration/script.py.mako +26 -0
- zrb/builtin/project/add/fastapp_template/module/auth/migration/versions/3093c7336477_add_user_table.py +37 -0
- zrb/builtin/project/add/fastapp_template/module/auth/migration_metadata.py +6 -0
- zrb/builtin/project/add/fastapp_template/module/auth/route.py +22 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/db_repository.py +39 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/factory.py +13 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/repository.py +34 -0
- zrb/builtin/project/add/fastapp_template/module/auth/service/user/usecase.py +45 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/alembic.ini +117 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/migration/README +1 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/migration/env.py +108 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/migration/script.py.mako +26 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/migration/versions/.gitkeep +0 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/migration_metadata.py +3 -0
- zrb/builtin/project/add/fastapp_template/module/gateway/route.py +27 -0
- zrb/builtin/project/add/fastapp_template/requirements.txt +6 -0
- zrb/builtin/project/add/fastapp_template/schema/__init__.py +0 -0
- zrb/builtin/project/add/fastapp_template/schema/role.py +31 -0
- zrb/builtin/project/add/fastapp_template/schema/user.py +31 -0
- zrb/builtin/project/add/fastapp_template/template.env +2 -0
- zrb/builtin/project/create/__init__.py +0 -0
- zrb/builtin/project/create/create.py +41 -0
- zrb/builtin/project/create/project-template/README.md +3 -0
- zrb/builtin/project/create/project-template/zrb_init.py +7 -0
- zrb/builtin/python.py +11 -0
- zrb/builtin/shell/__init__.py +0 -5
- zrb/builtin/shell/autocomplete/__init__.py +0 -9
- zrb/builtin/shell/autocomplete/bash.py +5 -6
- zrb/builtin/shell/autocomplete/subcmd.py +7 -8
- zrb/builtin/shell/autocomplete/zsh.py +5 -6
- zrb/builtin/todo.py +186 -0
- zrb/callback/any_callback.py +1 -1
- zrb/callback/callback.py +5 -5
- zrb/cmd/cmd_val.py +2 -2
- zrb/config.py +4 -1
- zrb/content_transformer/any_content_transformer.py +1 -1
- zrb/content_transformer/content_transformer.py +2 -2
- zrb/context/any_context.py +5 -1
- zrb/context/any_shared_context.py +3 -3
- zrb/context/context.py +15 -9
- zrb/context/shared_context.py +9 -8
- zrb/env/__init__.py +0 -3
- zrb/env/any_env.py +2 -2
- zrb/env/env.py +4 -5
- zrb/env/env_file.py +4 -4
- zrb/env/env_map.py +4 -4
- zrb/group/__init__.py +0 -3
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +7 -6
- zrb/input/any_input.py +1 -1
- zrb/input/base_input.py +4 -4
- zrb/input/bool_input.py +5 -5
- zrb/input/float_input.py +3 -3
- zrb/input/int_input.py +3 -3
- zrb/input/option_input.py +51 -0
- zrb/input/password_input.py +2 -2
- zrb/input/str_input.py +1 -1
- zrb/input/text_input.py +12 -10
- zrb/runner/cli.py +79 -44
- zrb/runner/web_app/group_info_ui/controller.py +7 -8
- zrb/runner/web_app/group_info_ui/view.html +2 -2
- zrb/runner/web_app/home_page/controller.py +7 -6
- zrb/runner/web_app/home_page/view.html +2 -2
- zrb/runner/web_app/task_ui/controller.py +13 -13
- zrb/runner/web_app/task_ui/partial/common-util.js +37 -0
- zrb/runner/web_app/task_ui/partial/main.js +9 -2
- zrb/runner/web_app/task_ui/partial/show-existing-session.js +20 -5
- zrb/runner/web_app/task_ui/partial/visualize-history.js +1 -41
- zrb/runner/web_app/task_ui/view.html +4 -2
- zrb/runner/web_server.py +137 -211
- zrb/runner/web_util.py +5 -35
- zrb/session/any_session.py +13 -7
- zrb/session/session.py +80 -41
- zrb/session_state_log/session_state_log.py +7 -5
- zrb/session_state_logger/any_session_state_logger.py +1 -1
- zrb/session_state_logger/default_session_state_logger.py +2 -2
- zrb/session_state_logger/file_session_state_logger.py +19 -27
- zrb/task/any_task.py +8 -3
- zrb/task/base_task.py +47 -33
- zrb/task/base_trigger.py +11 -12
- zrb/task/cmd_task.py +55 -43
- zrb/task/http_check.py +8 -8
- zrb/task/llm_task.py +160 -0
- zrb/task/make_task.py +9 -9
- zrb/task/rsync_task.py +7 -7
- zrb/task/scaffolder.py +14 -11
- zrb/task/scheduler.py +6 -7
- zrb/task/task.py +1 -1
- zrb/task/tcp_check.py +8 -8
- zrb/util/attr.py +19 -3
- zrb/util/cli/style.py +71 -2
- zrb/util/cli/subcommand.py +2 -2
- zrb/util/codemod/__init__.py +0 -0
- zrb/util/codemod/add_code_to_class.py +35 -0
- zrb/util/codemod/add_code_to_function.py +36 -0
- zrb/util/codemod/add_code_to_method.py +55 -0
- zrb/util/codemod/add_key_to_dict.py +51 -0
- zrb/util/codemod/add_param_to_function_call.py +39 -0
- zrb/util/codemod/add_property_to_class.py +55 -0
- zrb/util/git.py +156 -0
- zrb/util/git_subtree.py +94 -0
- zrb/util/group.py +2 -2
- zrb/util/llm/tool.py +63 -0
- zrb/util/string/conversion.py +7 -0
- zrb/util/todo.py +135 -0
- {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/METADATA +11 -7
- zrb-1.0.0a3.dist-info/RECORD +194 -0
- zrb/builtin/shell/_group.py +0 -9
- zrb/builtin/shell/autocomplete/_group.py +0 -6
- zrb/runner/web_app/any_request_handler.py +0 -24
- zrb/runner/web_server.bak.py +0 -208
- zrb-1.0.0a1.dist-info/RECORD +0 -120
- {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/WHEEL +0 -0
- {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/entry_points.txt +0 -0
@@ -1,9 +1,8 @@
|
|
1
1
|
import datetime
|
2
|
-
import json
|
3
2
|
import os
|
4
3
|
|
5
|
-
from
|
6
|
-
from .any_session_state_logger import AnySessionStateLogger
|
4
|
+
from zrb.session_state_log.session_state_log import SessionStateLog, SessionStateLogList
|
5
|
+
from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
|
7
6
|
|
8
7
|
|
9
8
|
class FileSessionStateLogger(AnySessionStateLogger):
|
@@ -12,21 +11,24 @@ class FileSessionStateLogger(AnySessionStateLogger):
|
|
12
11
|
self._session_log_dir = session_log_dir
|
13
12
|
|
14
13
|
def write(self, session_log: SessionStateLog):
|
15
|
-
session_file_path = self._get_session_file_path(session_log
|
14
|
+
session_file_path = self._get_session_file_path(session_log.name)
|
15
|
+
session_dir_path = os.path.dirname(session_file_path)
|
16
|
+
if not os.path.isdir(session_dir_path):
|
17
|
+
os.makedirs(session_dir_path, exist_ok=True)
|
16
18
|
with open(session_file_path, "w") as f:
|
17
|
-
f.write(
|
18
|
-
start_time =
|
19
|
-
if start_time
|
19
|
+
f.write(session_log.model_dump_json())
|
20
|
+
start_time = session_log.start_time
|
21
|
+
if start_time == "":
|
20
22
|
return
|
21
|
-
timeline_dir_path = self._get_timeline_dir_path(session_log
|
23
|
+
timeline_dir_path = self._get_timeline_dir_path(session_log)
|
22
24
|
os.makedirs(timeline_dir_path, exist_ok=True)
|
23
|
-
with open(os.path.join(timeline_dir_path, session_log
|
25
|
+
with open(os.path.join(timeline_dir_path, session_log.name), "w"):
|
24
26
|
pass
|
25
27
|
|
26
28
|
def read(self, session_name: str) -> SessionStateLog:
|
27
29
|
session_file_path = self._get_session_file_path(session_name)
|
28
30
|
with open(session_file_path, "r") as f:
|
29
|
-
return
|
31
|
+
return SessionStateLog.model_validate_json(f.read())
|
30
32
|
|
31
33
|
def list(
|
32
34
|
self,
|
@@ -59,21 +61,20 @@ class FileSessionStateLogger(AnySessionStateLogger):
|
|
59
61
|
paginated_sessions = matching_sessions[start_index:end_index]
|
60
62
|
# Extract session logs from the sorted list of tuples
|
61
63
|
data = [session_log for _, session_log in paginated_sessions]
|
62
|
-
return
|
64
|
+
return SessionStateLogList(total=total, data=data)
|
63
65
|
|
64
66
|
def _get_session_file_path(self, session_name: str) -> str:
|
65
67
|
return os.path.join(self._session_log_dir, f"{session_name}.json")
|
66
68
|
|
67
|
-
def _get_timeline_dir_path(
|
68
|
-
|
69
|
-
) -> str:
|
69
|
+
def _get_timeline_dir_path(self, session_log: SessionStateLog) -> str:
|
70
|
+
start_time = self._get_start_time(session_log)
|
70
71
|
year = start_time.year
|
71
72
|
month = start_time.month
|
72
73
|
day = start_time.day
|
73
74
|
hour = start_time.hour
|
74
75
|
minute = start_time.minute
|
75
76
|
second = start_time.second
|
76
|
-
paths = session_log
|
77
|
+
paths = session_log.path + [
|
77
78
|
f"{year}",
|
78
79
|
f"{month}",
|
79
80
|
f"{day}",
|
@@ -84,15 +85,6 @@ class FileSessionStateLogger(AnySessionStateLogger):
|
|
84
85
|
return os.path.join(self._session_log_dir, "_timeline", *paths)
|
85
86
|
|
86
87
|
def _get_start_time(self, session_log: SessionStateLog) -> datetime.datetime:
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
if len(histories) == 0:
|
91
|
-
continue
|
92
|
-
first_history = histories[0]
|
93
|
-
first_time = datetime.datetime.strptime(
|
94
|
-
first_history["time"], "%Y-%m-%d %H:%M:%S.%f"
|
95
|
-
)
|
96
|
-
if result is None or first_time < result:
|
97
|
-
result = first_time
|
98
|
-
return result
|
88
|
+
return datetime.datetime.strptime(
|
89
|
+
session_log.start_time, "%Y-%m-%d %H:%M:%S.%f"
|
90
|
+
)
|
zrb/task/any_task.py
CHANGED
@@ -3,11 +3,12 @@ from __future__ import annotations # Enables forward references
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from typing import TYPE_CHECKING, Any
|
5
5
|
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from zrb.env.any_env import AnyEnv
|
7
|
+
from zrb.input.any_input import AnyInput
|
8
8
|
|
9
9
|
if TYPE_CHECKING:
|
10
|
-
from
|
10
|
+
from zrb.context import any_context
|
11
|
+
from zrb.session import session
|
11
12
|
|
12
13
|
|
13
14
|
class AnyTask(ABC):
|
@@ -89,6 +90,10 @@ class AnyTask(ABC):
|
|
89
90
|
"""
|
90
91
|
pass
|
91
92
|
|
93
|
+
@abstractmethod
|
94
|
+
def get_ctx(self, session: session.AnySession) -> any_context.AnyContext:
|
95
|
+
pass
|
96
|
+
|
92
97
|
@abstractmethod
|
93
98
|
def run(
|
94
99
|
self, session: session.AnySession | None = None, str_kwargs: dict[str, str] = {}
|
zrb/task/base_task.py
CHANGED
@@ -3,17 +3,17 @@ import os
|
|
3
3
|
from collections.abc import Callable
|
4
4
|
from typing import Any
|
5
5
|
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from .
|
6
|
+
from zrb.attr.type import BoolAttr, fstring
|
7
|
+
from zrb.context.any_context import AnyContext
|
8
|
+
from zrb.context.shared_context import AnySharedContext, SharedContext
|
9
|
+
from zrb.env.any_env import AnyEnv
|
10
|
+
from zrb.input.any_input import AnyInput
|
11
|
+
from zrb.session.any_session import AnySession
|
12
|
+
from zrb.session.session import Session
|
13
|
+
from zrb.task.any_task import AnyTask
|
14
|
+
from zrb.util.attr import get_bool_attr
|
15
|
+
from zrb.util.run import run_async
|
16
|
+
from zrb.xcom.xcom import Xcom
|
17
17
|
|
18
18
|
|
19
19
|
class BaseTask(AnyTask):
|
@@ -24,8 +24,8 @@ class BaseTask(AnyTask):
|
|
24
24
|
icon: str | None = None,
|
25
25
|
description: str | None = None,
|
26
26
|
cli_only: bool = False,
|
27
|
-
input: list[AnyInput] | AnyInput | None = None,
|
28
|
-
env: list[AnyEnv] | AnyEnv | None = None,
|
27
|
+
input: list[AnyInput | None] | AnyInput | None = None,
|
28
|
+
env: list[AnyEnv | None] | AnyEnv | None = None,
|
29
29
|
action: fstring | Callable[[AnyContext], Any] | None = None,
|
30
30
|
execute_condition: BoolAttr = True,
|
31
31
|
retries: int = 2,
|
@@ -63,16 +63,22 @@ class BaseTask(AnyTask):
|
|
63
63
|
return f"<{self.__class__.__name__} name={self._name}>"
|
64
64
|
|
65
65
|
def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
|
66
|
-
|
67
|
-
other
|
68
|
-
|
69
|
-
|
70
|
-
task
|
71
|
-
|
66
|
+
try:
|
67
|
+
if isinstance(other, AnyTask):
|
68
|
+
other.append_upstreams(self)
|
69
|
+
elif isinstance(other, list):
|
70
|
+
for task in other:
|
71
|
+
task.append_upstreams(self)
|
72
|
+
return other
|
73
|
+
except Exception as e:
|
74
|
+
raise ValueError(f"Invalid operation {self} >> {other}: {e}")
|
72
75
|
|
73
76
|
def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
|
74
|
-
|
75
|
-
|
77
|
+
try:
|
78
|
+
self.append_upstreams(other)
|
79
|
+
return self
|
80
|
+
except Exception as e:
|
81
|
+
raise ValueError(f"Invalid operation {self} << {other}: {e}")
|
76
82
|
|
77
83
|
@property
|
78
84
|
def name(self) -> str:
|
@@ -103,7 +109,7 @@ class BaseTask(AnyTask):
|
|
103
109
|
envs.append(self._envs)
|
104
110
|
elif self._envs is not None:
|
105
111
|
envs += self._envs
|
106
|
-
return envs
|
112
|
+
return [env for env in envs if env is not None]
|
107
113
|
|
108
114
|
@property
|
109
115
|
def inputs(self) -> list[AnyInput]:
|
@@ -112,7 +118,7 @@ class BaseTask(AnyTask):
|
|
112
118
|
self.__combine_inputs(inputs, upstream.inputs)
|
113
119
|
if self._inputs is not None:
|
114
120
|
self.__combine_inputs(inputs, self._inputs)
|
115
|
-
return inputs
|
121
|
+
return [task_input for task_input in inputs if task_input is not None]
|
116
122
|
|
117
123
|
def __combine_inputs(
|
118
124
|
self, inputs: list[AnyInput], other_inputs: list[AnyInput] | AnyInput
|
@@ -120,7 +126,11 @@ class BaseTask(AnyTask):
|
|
120
126
|
input_names = [task_input.name for task_input in inputs]
|
121
127
|
if isinstance(other_inputs, AnyInput):
|
122
128
|
other_inputs = [other_inputs]
|
129
|
+
elif other_inputs is None:
|
130
|
+
other_inputs = []
|
123
131
|
for task_input in other_inputs:
|
132
|
+
if task_input is None:
|
133
|
+
continue
|
124
134
|
if task_input.name not in input_names:
|
125
135
|
inputs.append(task_input)
|
126
136
|
|
@@ -163,6 +173,13 @@ class BaseTask(AnyTask):
|
|
163
173
|
if upstream not in self._upstreams:
|
164
174
|
self._upstreams.append(upstream)
|
165
175
|
|
176
|
+
def get_ctx(self, session: AnySession) -> AnyContext:
|
177
|
+
ctx = session.get_ctx(self)
|
178
|
+
# Enhance session ctx with current task env
|
179
|
+
for env in self.envs:
|
180
|
+
env.update_context(ctx)
|
181
|
+
return ctx
|
182
|
+
|
166
183
|
def run(
|
167
184
|
self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
|
168
185
|
) -> Any:
|
@@ -194,9 +211,6 @@ class BaseTask(AnyTask):
|
|
194
211
|
if key not in shared_context._env
|
195
212
|
}
|
196
213
|
shared_context._env.update(os_env_map)
|
197
|
-
# Inject environment from task's envs
|
198
|
-
for env in self.envs:
|
199
|
-
env.update_shared_context(shared_context)
|
200
214
|
|
201
215
|
async def exec_root_tasks(self, session: AnySession):
|
202
216
|
session.set_main_task(self)
|
@@ -219,12 +233,12 @@ class BaseTask(AnyTask):
|
|
219
233
|
except IndexError:
|
220
234
|
return None
|
221
235
|
except asyncio.CancelledError:
|
222
|
-
ctx =
|
236
|
+
ctx = self.get_ctx(session)
|
223
237
|
ctx.log_info("Session terminated")
|
224
238
|
finally:
|
225
239
|
session.terminate()
|
226
240
|
session.state_logger.write(session.as_state_log())
|
227
|
-
ctx =
|
241
|
+
ctx = self.get_ctx(session)
|
228
242
|
ctx.log_debug(session)
|
229
243
|
|
230
244
|
async def _log_session_state(self, session: AnySession):
|
@@ -247,7 +261,7 @@ class BaseTask(AnyTask):
|
|
247
261
|
return await asyncio.gather(*next_coros)
|
248
262
|
|
249
263
|
async def exec(self, session: AnySession):
|
250
|
-
ctx =
|
264
|
+
ctx = self.get_ctx(session)
|
251
265
|
if not session.is_allowed_to_run(self):
|
252
266
|
# Task is not allowed to run, skip it for now.
|
253
267
|
# This will be triggered later
|
@@ -262,11 +276,11 @@ class BaseTask(AnyTask):
|
|
262
276
|
await run_async(self.__exec_action_until_ready(session))
|
263
277
|
|
264
278
|
def __get_execute_condition(self, session: Session) -> bool:
|
265
|
-
ctx =
|
279
|
+
ctx = self.get_ctx(session)
|
266
280
|
return get_bool_attr(ctx, self._execute_condition, True, auto_render=True)
|
267
281
|
|
268
282
|
async def __exec_action_until_ready(self, session: AnySession):
|
269
|
-
ctx =
|
283
|
+
ctx = self.get_ctx(session)
|
270
284
|
readiness_checks = self.readiness_checks
|
271
285
|
if len(readiness_checks) == 0:
|
272
286
|
ctx.log_info("No readiness checks")
|
@@ -301,7 +315,7 @@ class BaseTask(AnyTask):
|
|
301
315
|
async def __exec_monitoring(self, session: AnySession, action_coro: asyncio.Task):
|
302
316
|
readiness_checks = self.readiness_checks
|
303
317
|
failure_count = 0
|
304
|
-
ctx =
|
318
|
+
ctx = self.get_ctx(session)
|
305
319
|
while not session.is_terminated:
|
306
320
|
await asyncio.sleep(self._readiness_check_period)
|
307
321
|
if failure_count < self._readiness_failure_threshold:
|
@@ -343,7 +357,7 @@ class BaseTask(AnyTask):
|
|
343
357
|
ctx.log_info("Continue monitoring")
|
344
358
|
|
345
359
|
async def __exec_action_and_retry(self, session: AnySession) -> Any:
|
346
|
-
ctx =
|
360
|
+
ctx = self.get_ctx(session)
|
347
361
|
max_attempt = self._retries + 1
|
348
362
|
ctx.set_max_attempt(max_attempt)
|
349
363
|
for attempt in range(max_attempt):
|
zrb/task/base_trigger.py
CHANGED
@@ -2,21 +2,20 @@ import asyncio
|
|
2
2
|
from collections.abc import Callable
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
+
from zrb.attr.type import fstring
|
6
|
+
from zrb.callback.any_callback import AnyCallback
|
5
7
|
from zrb.context.any_context import AnyContext
|
6
8
|
from zrb.context.any_shared_context import AnySharedContext
|
9
|
+
from zrb.context.shared_context import SharedContext
|
10
|
+
from zrb.dot_dict.dot_dict import DotDict
|
7
11
|
from zrb.env.any_env import AnyEnv
|
8
12
|
from zrb.input.any_input import AnyInput
|
13
|
+
from zrb.session.any_session import AnySession
|
14
|
+
from zrb.session.session import Session
|
9
15
|
from zrb.task.any_task import AnyTask
|
10
|
-
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from ..context.shared_context import SharedContext
|
14
|
-
from ..dot_dict.dot_dict import DotDict
|
15
|
-
from ..session.any_session import AnySession
|
16
|
-
from ..session.session import Session
|
17
|
-
from ..util.cli.style import CYAN
|
18
|
-
from ..xcom.xcom import Xcom
|
19
|
-
from .base_task import BaseTask
|
16
|
+
from zrb.task.base_task import BaseTask
|
17
|
+
from zrb.util.cli.style import CYAN
|
18
|
+
from zrb.xcom.xcom import Xcom
|
20
19
|
|
21
20
|
|
22
21
|
class BaseTrigger(BaseTask):
|
@@ -28,8 +27,8 @@ class BaseTrigger(BaseTask):
|
|
28
27
|
icon: str | None = None,
|
29
28
|
description: str | None = None,
|
30
29
|
cli_only: bool = False,
|
31
|
-
input: list[AnyInput] | AnyInput | None = None,
|
32
|
-
env: list[AnyEnv] | AnyEnv | None = None,
|
30
|
+
input: list[AnyInput | None] | AnyInput | None = None,
|
31
|
+
env: list[AnyEnv | None] | AnyEnv | None = None,
|
33
32
|
action: fstring | Callable[[AnyContext], Any] | None = None,
|
34
33
|
execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
|
35
34
|
queue_name: fstring | None = None,
|
zrb/task/cmd_task.py
CHANGED
@@ -2,17 +2,17 @@ import asyncio
|
|
2
2
|
import os
|
3
3
|
import sys
|
4
4
|
|
5
|
-
from
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from
|
13
|
-
from
|
14
|
-
from .
|
15
|
-
from .
|
5
|
+
from zrb.attr.type import BoolAttr, IntAttr, StrAttr
|
6
|
+
from zrb.cmd.cmd_result import CmdResult
|
7
|
+
from zrb.cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
|
8
|
+
from zrb.config import DEFAULT_SHELL
|
9
|
+
from zrb.context.any_context import AnyContext
|
10
|
+
from zrb.env.any_env import AnyEnv
|
11
|
+
from zrb.input.any_input import AnyInput
|
12
|
+
from zrb.task.any_task import AnyTask
|
13
|
+
from zrb.task.base_task import BaseTask
|
14
|
+
from zrb.util.attr import get_int_attr, get_str_attr
|
15
|
+
from zrb.util.cmd.remote import get_remote_cmd_script
|
16
16
|
|
17
17
|
|
18
18
|
class CmdTask(BaseTask):
|
@@ -23,8 +23,8 @@ class CmdTask(BaseTask):
|
|
23
23
|
icon: str | None = None,
|
24
24
|
description: str | None = None,
|
25
25
|
cli_only: bool = False,
|
26
|
-
input: list[AnyInput] | AnyInput | None = None,
|
27
|
-
env: list[AnyEnv] | AnyEnv | None = None,
|
26
|
+
input: list[AnyInput | None] | AnyInput | None = None,
|
27
|
+
env: list[AnyEnv | None] | AnyEnv | None = None,
|
28
28
|
shell: StrAttr | None = None,
|
29
29
|
auto_render_shell: bool = True,
|
30
30
|
shell_flag: StrAttr | None = None,
|
@@ -105,44 +105,56 @@ class CmdTask(BaseTask):
|
|
105
105
|
Returns:
|
106
106
|
Any: The result of the action execution.
|
107
107
|
"""
|
108
|
+
ctx.log_info("Running script")
|
108
109
|
cmd_script = self._get_cmd_script(ctx)
|
109
|
-
|
110
|
+
ctx.log_debug(f"Script: {self.__get_multiline_repr(cmd_script)}")
|
110
111
|
shell = self._get_shell(ctx)
|
111
|
-
shell_flag = self._get_shell_flag(ctx)
|
112
|
-
ctx.log_info("Running script")
|
113
112
|
ctx.log_debug(f"Shell: {shell}")
|
114
|
-
|
113
|
+
shell_flag = self._get_shell_flag(ctx)
|
114
|
+
cwd = self._get_cwd(ctx)
|
115
115
|
ctx.log_debug(f"Working directory: {cwd}")
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
116
|
+
env_map = self.__get_env_map(ctx)
|
117
|
+
ctx.log_debug(f"Environment map: {env_map}")
|
118
|
+
cmd_process = None
|
119
|
+
try:
|
120
|
+
cmd_process = await asyncio.create_subprocess_exec(
|
121
|
+
shell,
|
122
|
+
shell_flag,
|
123
|
+
cmd_script,
|
124
|
+
cwd=cwd,
|
125
|
+
stdin=sys.stdin if sys.stdin.isatty() else None,
|
126
|
+
stdout=asyncio.subprocess.PIPE,
|
127
|
+
stderr=asyncio.subprocess.PIPE,
|
128
|
+
env=env_map,
|
129
|
+
bufsize=0,
|
130
|
+
)
|
131
|
+
stdout_task = asyncio.create_task(
|
132
|
+
self.__read_stream(cmd_process.stdout, ctx.print, self._max_output_line)
|
133
|
+
)
|
134
|
+
stderr_task = asyncio.create_task(
|
135
|
+
self.__read_stream(cmd_process.stderr, ctx.print, self._max_error_line)
|
136
|
+
)
|
137
|
+
# Wait for process to complete and gather stdout/stderr
|
138
|
+
return_code = await cmd_process.wait()
|
139
|
+
stdout = await stdout_task
|
140
|
+
stderr = await stderr_task
|
141
|
+
# Check for errors
|
142
|
+
if return_code != 0:
|
143
|
+
ctx.log_error(f"Exit status: {return_code}")
|
144
|
+
raise Exception(
|
145
|
+
f"Process {self._name} exited ({return_code}): {stderr}"
|
146
|
+
)
|
147
|
+
ctx.log_info(f"Exit status: {return_code}")
|
148
|
+
return CmdResult(stdout, stderr)
|
149
|
+
finally:
|
150
|
+
if cmd_process is not None and cmd_process.returncode is None:
|
151
|
+
cmd_process.terminate()
|
142
152
|
|
143
153
|
def __get_env_map(self, ctx: AnyContext) -> dict[str, str]:
|
144
154
|
envs = {key: val for key, val in ctx.env.items()}
|
145
155
|
envs["_ZRB_SSH_PASSWORD"] = self._get_remote_password(ctx)
|
156
|
+
envs["PYTHONBUFFERED"] = "1"
|
157
|
+
return envs
|
146
158
|
|
147
159
|
async def __read_stream(self, stream, log_method, max_lines):
|
148
160
|
lines = []
|
zrb/task/http_check.py
CHANGED
@@ -3,14 +3,14 @@ from collections.abc import Callable
|
|
3
3
|
|
4
4
|
import requests
|
5
5
|
|
6
|
-
from
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from .
|
13
|
-
from .
|
6
|
+
from zrb.attr.type import StrAttr
|
7
|
+
from zrb.context.any_context import AnyContext
|
8
|
+
from zrb.context.context import Context
|
9
|
+
from zrb.env.any_env import AnyEnv
|
10
|
+
from zrb.input.any_input import AnyInput
|
11
|
+
from zrb.task.any_task import AnyTask
|
12
|
+
from zrb.task.base_task import BaseTask
|
13
|
+
from zrb.util.attr import get_str_attr
|
14
14
|
|
15
15
|
|
16
16
|
class HttpCheck(BaseTask):
|
zrb/task/llm_task.py
ADDED
@@ -0,0 +1,160 @@
|
|
1
|
+
import json
|
2
|
+
from collections.abc import Callable
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from zrb.attr.type import StrAttr
|
6
|
+
from zrb.config import LLM_MODEL, LLM_SYSTEM_PROMPT
|
7
|
+
from zrb.context.any_context import AnyContext
|
8
|
+
from zrb.context.any_shared_context import AnySharedContext
|
9
|
+
from zrb.env.any_env import AnyEnv
|
10
|
+
from zrb.input.any_input import AnyInput
|
11
|
+
from zrb.task.any_task import AnyTask
|
12
|
+
from zrb.task.base_task import BaseTask
|
13
|
+
from zrb.util.attr import get_str_attr
|
14
|
+
from zrb.util.llm.tool import callable_to_tool_schema
|
15
|
+
|
16
|
+
# Shorthand for a list of string-keyed dicts; used below for tool schemas and
# chat-message history.
DictList = list[dict[str, Any]]
|
17
|
+
|
18
|
+
|
19
|
+
def scratchpad(thought: str) -> str:
    """Use this tool to note your thought and planning"""
    # Identity "tool": whatever the model writes is echoed back unchanged and
    # becomes the content of the resulting tool message. The docstring above
    # and the parameter name may be used to build the tool schema
    # (callable_to_tool_schema) — keep them unchanged.
    noted = thought
    return noted
|
22
|
+
|
23
|
+
|
24
|
+
class LLMTask(BaseTask):
    """Task that chats with an LLM through litellm and services its tool calls.

    The conversation sent to the model is: the system prompt, then
    ``history``, then the user ``message``. While the model replies with tool
    calls, each call is dispatched to the matching callable from ``tools``
    (plus the built-in ``scratchpad`` tool), and the result is appended as a
    ``tool`` message. The content of the first reply without tool calls is
    returned as the task result.

    LLM-specific arguments (all others are forwarded to ``BaseTask``):
        model: Model name attribute, rendered with ``get_str_attr``
            (default fallback ``"ollama_chat/llama3.1"``).
        system_prompt: System prompt attribute
            (default fallback ``"You are a helpful assistant"``).
        message: User message attribute (default fallback ``"How are you?"``).
        tools: Mapping of tool name -> callable, or a callable invoked with
            the task context that returns such a mapping. ``None`` = no tools.
        tool_schema: Explicit tool schemas, or a callable producing them.
            Every tool without an entry whose ``function.name`` matches its
            name gets a schema generated by ``callable_to_tool_schema``.
        history: Prior chat messages (or a callable producing them), inserted
            between the system prompt and the user message.
    """

    def __init__(
        self,
        name: str,
        color: int | None = None,
        icon: str | None = None,
        description: str | None = None,
        cli_only: bool = False,
        input: list[AnyInput | None] | AnyInput | None = None,
        env: list[AnyEnv | None] | AnyEnv | None = None,
        model: StrAttr | None = LLM_MODEL,
        system_prompt: StrAttr | None = LLM_SYSTEM_PROMPT,
        message: StrAttr | None = None,
        tools: (
            dict[str, Callable]
            | Callable[[AnySharedContext], dict[str, Callable]]
            | None
        ) = None,
        tool_schema: DictList | Callable[[AnySharedContext], DictList] | None = None,
        history: DictList | Callable[[AnySharedContext], DictList] | None = None,
        execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
        retries: int = 2,
        retry_period: float = 0,
        readiness_check: list[AnyTask] | AnyTask | None = None,
        readiness_check_delay: float = 0.5,
        readiness_check_period: float = 5,
        readiness_failure_threshold: int = 1,
        readiness_timeout: int = 60,
        monitor_readiness: bool = False,
        upstream: list[AnyTask] | AnyTask | None = None,
        fallback: list[AnyTask] | AnyTask | None = None,
    ):
        super().__init__(
            name=name,
            color=color,
            icon=icon,
            description=description,
            cli_only=cli_only,
            input=input,
            env=env,
            execute_condition=execute_condition,
            retries=retries,
            retry_period=retry_period,
            readiness_check=readiness_check,
            readiness_check_delay=readiness_check_delay,
            readiness_check_period=readiness_check_period,
            readiness_failure_threshold=readiness_failure_threshold,
            readiness_timeout=readiness_timeout,
            monitor_readiness=monitor_readiness,
            upstream=upstream,
            fallback=fallback,
        )
        self._model = model
        self._system_prompt = system_prompt
        self._message = message
        # ``None`` means "not provided". Each instance gets its own empty
        # container instead of sharing one mutable default across instances.
        self._tools = tools if tools is not None else {}
        self._tool_schema = tool_schema if tool_schema is not None else []
        self._history = history if history is not None else []

    async def _exec_action(self, ctx: AnyContext) -> Any:
        """Run the chat/tool-call loop and return the model's final content."""
        # Local import keeps litellm an on-demand dependency.
        from litellm import acompletion

        system_prompt = self._get_system_prompt(ctx)
        ctx.log_debug("SYSTEM PROMPT", system_prompt)
        history = self._get_history(ctx)
        ctx.log_debug("HISTORY PROMPT", history)
        user_message = self._get_message(ctx)
        ctx.log_debug("USER MESSAGE", user_message)
        messages = (
            [{"role": "system", "content": system_prompt}]
            + list(history)
            + [{"role": "user", "content": user_message}]
        )
        # Work on copies: the configured tools/schema (or whatever a callable
        # returned) must not be mutated, since they outlive this run and may
        # be shared with other task instances.
        available_tools = dict(self._get_tools(ctx))
        available_tools["scratchpad"] = scratchpad
        tool_schema = self._complete_tool_schema(
            available_tools, list(self._get_tool_schema(ctx))
        )
        ctx.log_debug("TOOL SCHEMA", tool_schema)
        model = self._get_model(ctx)
        while True:
            # acompletion is litellm's async counterpart of completion; awaiting
            # it keeps the event loop free while waiting for the model.
            response = await acompletion(
                model=model, messages=messages, tools=tool_schema
            )
            response_message = response.choices[0].message
            ctx.print(response_message)
            messages.append(response_message)
            tool_calls = response_message.tool_calls
            if not tool_calls:
                return response_message.content
            # noqa Reference: https://docs.litellm.ai/docs/completion/function_call#full-code---parallel-function-calling-with-gpt-35-turbo-1106
            for tool_call in tool_calls:
                tool_call_message = self._run_tool_call(available_tools, tool_call)
                ctx.print(tool_call_message)
                messages.append(tool_call_message)

    @staticmethod
    def _complete_tool_schema(
        available_tools: dict[str, Callable], tool_schema: DictList
    ) -> DictList:
        """Append a generated schema for each tool lacking an explicit one.

        ``tool_schema`` is extended in place and returned; callers pass a copy.
        """
        for tool_name, tool in available_tools.items():
            has_schema = any(
                "function" in schema
                and "name" in schema["function"]
                and schema["function"]["name"] == tool_name
                for schema in tool_schema
            )
            if not has_schema:
                tool_schema.append(callable_to_tool_schema(tool))
        return tool_schema

    def _run_tool_call(
        self, available_tools: dict[str, Callable], tool_call: Any
    ) -> dict[str, Any]:
        """Invoke the tool requested by ``tool_call`` and build the tool message.

        Raises:
            KeyError: if the model requested a tool name that is not available.
        """
        function_name = tool_call.function.name
        function_to_call = available_tools[function_name]
        raw_arguments = tool_call.function.arguments
        # json.loads rejects an empty string; treat empty/missing as no kwargs.
        function_kwargs = json.loads(raw_arguments) if raw_arguments else {}
        function_response = function_to_call(**function_kwargs)
        return {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": function_name,
            "content": self._to_tool_content(function_response),
        }

    @staticmethod
    def _to_tool_content(value: Any) -> str:
        """Coerce a tool's return value to the string content a chat message needs."""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value)
        except (TypeError, ValueError):
            return str(value)

    def _get_model(self, ctx: AnyContext) -> str:
        return get_str_attr(ctx, self._model, "ollama_chat/llama3.1", auto_render=True)

    def _get_system_prompt(self, ctx: AnyContext) -> str:
        return get_str_attr(
            ctx, self._system_prompt, "You are a helpful assistant", auto_render=True
        )

    def _get_message(self, ctx: AnyContext) -> str:
        return get_str_attr(ctx, self._message, "How are you?", auto_render=True)

    def _get_tools(self, ctx: AnyContext) -> dict[str, Callable]:
        # A callable is resolved against the current context on every run.
        if callable(self._tools):
            return self._tools(ctx)
        return self._tools

    def _get_tool_schema(self, ctx: AnyContext) -> DictList:
        if callable(self._tool_schema):
            return self._tool_schema(ctx)
        return self._tool_schema

    def _get_history(self, ctx: AnyContext) -> DictList:
        if callable(self._history):
            return self._history(ctx)
        return self._history
|