zrb 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156) hide show
  1. zrb/__init__.py +48 -39
  2. zrb/__main__.py +3 -3
  3. zrb/attr/type.py +2 -1
  4. zrb/builtin/__init__.py +40 -2
  5. zrb/builtin/base64.py +32 -0
  6. zrb/builtin/git.py +156 -0
  7. zrb/builtin/git_subtree.py +88 -0
  8. zrb/builtin/group.py +34 -0
  9. zrb/builtin/llm.py +31 -0
  10. zrb/builtin/md5.py +34 -0
  11. zrb/builtin/project/__init__.py +0 -0
  12. zrb/builtin/project/add/__init__.py +0 -0
  13. zrb/builtin/project/add/fastapp.py +72 -0
  14. zrb/builtin/project/add/fastapp_template/.gitignore +4 -0
  15. zrb/builtin/project/add/fastapp_template/README.md +7 -0
  16. zrb/builtin/project/add/fastapp_template/__init__.py +0 -0
  17. zrb/builtin/project/add/fastapp_template/_zrb/config.py +17 -0
  18. zrb/builtin/project/add/fastapp_template/_zrb/group.py +16 -0
  19. zrb/builtin/project/add/fastapp_template/_zrb/helper.py +97 -0
  20. zrb/builtin/project/add/fastapp_template/_zrb/main.py +132 -0
  21. zrb/builtin/project/add/fastapp_template/_zrb/venv_task.py +22 -0
  22. zrb/builtin/project/add/fastapp_template/common/__init__.py +0 -0
  23. zrb/builtin/project/add/fastapp_template/common/app.py +18 -0
  24. zrb/builtin/project/add/fastapp_template/common/db_engine.py +5 -0
  25. zrb/builtin/project/add/fastapp_template/common/db_repository.py +134 -0
  26. zrb/builtin/project/add/fastapp_template/common/error.py +8 -0
  27. zrb/builtin/project/add/fastapp_template/common/schema.py +5 -0
  28. zrb/builtin/project/add/fastapp_template/common/usecase.py +232 -0
  29. zrb/builtin/project/add/fastapp_template/config.py +29 -0
  30. zrb/builtin/project/add/fastapp_template/main.py +7 -0
  31. zrb/builtin/project/add/fastapp_template/migrate.py +3 -0
  32. zrb/builtin/project/add/fastapp_template/module/__init__.py +0 -0
  33. zrb/builtin/project/add/fastapp_template/module/auth/alembic.ini +117 -0
  34. zrb/builtin/project/add/fastapp_template/module/auth/client/api_client.py +7 -0
  35. zrb/builtin/project/add/fastapp_template/module/auth/client/base_client.py +27 -0
  36. zrb/builtin/project/add/fastapp_template/module/auth/client/direct_client.py +6 -0
  37. zrb/builtin/project/add/fastapp_template/module/auth/client/factory.py +9 -0
  38. zrb/builtin/project/add/fastapp_template/module/auth/migration/README +1 -0
  39. zrb/builtin/project/add/fastapp_template/module/auth/migration/env.py +108 -0
  40. zrb/builtin/project/add/fastapp_template/module/auth/migration/script.py.mako +26 -0
  41. zrb/builtin/project/add/fastapp_template/module/auth/migration/versions/3093c7336477_add_user_table.py +37 -0
  42. zrb/builtin/project/add/fastapp_template/module/auth/migration_metadata.py +6 -0
  43. zrb/builtin/project/add/fastapp_template/module/auth/route.py +22 -0
  44. zrb/builtin/project/add/fastapp_template/module/auth/service/__init__.py +0 -0
  45. zrb/builtin/project/add/fastapp_template/module/auth/service/user/__init__.py +0 -0
  46. zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/__init__.py +0 -0
  47. zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/db_repository.py +39 -0
  48. zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/factory.py +13 -0
  49. zrb/builtin/project/add/fastapp_template/module/auth/service/user/repository/repository.py +34 -0
  50. zrb/builtin/project/add/fastapp_template/module/auth/service/user/usecase.py +45 -0
  51. zrb/builtin/project/add/fastapp_template/module/gateway/alembic.ini +117 -0
  52. zrb/builtin/project/add/fastapp_template/module/gateway/migration/README +1 -0
  53. zrb/builtin/project/add/fastapp_template/module/gateway/migration/env.py +108 -0
  54. zrb/builtin/project/add/fastapp_template/module/gateway/migration/script.py.mako +26 -0
  55. zrb/builtin/project/add/fastapp_template/module/gateway/migration/versions/.gitkeep +0 -0
  56. zrb/builtin/project/add/fastapp_template/module/gateway/migration_metadata.py +3 -0
  57. zrb/builtin/project/add/fastapp_template/module/gateway/route.py +27 -0
  58. zrb/builtin/project/add/fastapp_template/requirements.txt +6 -0
  59. zrb/builtin/project/add/fastapp_template/schema/__init__.py +0 -0
  60. zrb/builtin/project/add/fastapp_template/schema/role.py +31 -0
  61. zrb/builtin/project/add/fastapp_template/schema/user.py +31 -0
  62. zrb/builtin/project/add/fastapp_template/template.env +2 -0
  63. zrb/builtin/project/create/__init__.py +0 -0
  64. zrb/builtin/project/create/create.py +41 -0
  65. zrb/builtin/project/create/project-template/README.md +3 -0
  66. zrb/builtin/project/create/project-template/zrb_init.py +7 -0
  67. zrb/builtin/python.py +11 -0
  68. zrb/builtin/shell/__init__.py +0 -5
  69. zrb/builtin/shell/autocomplete/__init__.py +0 -9
  70. zrb/builtin/shell/autocomplete/bash.py +5 -6
  71. zrb/builtin/shell/autocomplete/subcmd.py +7 -8
  72. zrb/builtin/shell/autocomplete/zsh.py +5 -6
  73. zrb/builtin/todo.py +186 -0
  74. zrb/callback/any_callback.py +1 -1
  75. zrb/callback/callback.py +5 -5
  76. zrb/cmd/cmd_val.py +2 -2
  77. zrb/config.py +4 -1
  78. zrb/content_transformer/any_content_transformer.py +1 -1
  79. zrb/content_transformer/content_transformer.py +2 -2
  80. zrb/context/any_context.py +5 -1
  81. zrb/context/any_shared_context.py +3 -3
  82. zrb/context/context.py +15 -9
  83. zrb/context/shared_context.py +9 -8
  84. zrb/env/__init__.py +0 -3
  85. zrb/env/any_env.py +2 -2
  86. zrb/env/env.py +4 -5
  87. zrb/env/env_file.py +4 -4
  88. zrb/env/env_map.py +4 -4
  89. zrb/group/__init__.py +0 -3
  90. zrb/group/any_group.py +3 -3
  91. zrb/group/group.py +7 -6
  92. zrb/input/any_input.py +1 -1
  93. zrb/input/base_input.py +4 -4
  94. zrb/input/bool_input.py +5 -5
  95. zrb/input/float_input.py +3 -3
  96. zrb/input/int_input.py +3 -3
  97. zrb/input/option_input.py +51 -0
  98. zrb/input/password_input.py +2 -2
  99. zrb/input/str_input.py +1 -1
  100. zrb/input/text_input.py +12 -10
  101. zrb/runner/cli.py +79 -44
  102. zrb/runner/web_app/group_info_ui/controller.py +7 -8
  103. zrb/runner/web_app/group_info_ui/view.html +2 -2
  104. zrb/runner/web_app/home_page/controller.py +7 -6
  105. zrb/runner/web_app/home_page/view.html +2 -2
  106. zrb/runner/web_app/task_ui/controller.py +13 -13
  107. zrb/runner/web_app/task_ui/partial/common-util.js +37 -0
  108. zrb/runner/web_app/task_ui/partial/main.js +9 -2
  109. zrb/runner/web_app/task_ui/partial/show-existing-session.js +20 -5
  110. zrb/runner/web_app/task_ui/partial/visualize-history.js +1 -41
  111. zrb/runner/web_app/task_ui/view.html +4 -2
  112. zrb/runner/web_server.py +137 -211
  113. zrb/runner/web_util.py +5 -35
  114. zrb/session/any_session.py +13 -7
  115. zrb/session/session.py +80 -41
  116. zrb/session_state_log/session_state_log.py +7 -5
  117. zrb/session_state_logger/any_session_state_logger.py +1 -1
  118. zrb/session_state_logger/default_session_state_logger.py +2 -2
  119. zrb/session_state_logger/file_session_state_logger.py +19 -27
  120. zrb/task/any_task.py +8 -3
  121. zrb/task/base_task.py +47 -33
  122. zrb/task/base_trigger.py +11 -12
  123. zrb/task/cmd_task.py +55 -43
  124. zrb/task/http_check.py +8 -8
  125. zrb/task/llm_task.py +160 -0
  126. zrb/task/make_task.py +9 -9
  127. zrb/task/rsync_task.py +7 -7
  128. zrb/task/scaffolder.py +14 -11
  129. zrb/task/scheduler.py +6 -7
  130. zrb/task/task.py +1 -1
  131. zrb/task/tcp_check.py +8 -8
  132. zrb/util/attr.py +19 -3
  133. zrb/util/cli/style.py +71 -2
  134. zrb/util/cli/subcommand.py +2 -2
  135. zrb/util/codemod/__init__.py +0 -0
  136. zrb/util/codemod/add_code_to_class.py +35 -0
  137. zrb/util/codemod/add_code_to_function.py +36 -0
  138. zrb/util/codemod/add_code_to_method.py +55 -0
  139. zrb/util/codemod/add_key_to_dict.py +51 -0
  140. zrb/util/codemod/add_param_to_function_call.py +39 -0
  141. zrb/util/codemod/add_property_to_class.py +55 -0
  142. zrb/util/git.py +156 -0
  143. zrb/util/git_subtree.py +94 -0
  144. zrb/util/group.py +2 -2
  145. zrb/util/llm/tool.py +63 -0
  146. zrb/util/string/conversion.py +7 -0
  147. zrb/util/todo.py +135 -0
  148. {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/METADATA +11 -7
  149. zrb-1.0.0a3.dist-info/RECORD +194 -0
  150. zrb/builtin/shell/_group.py +0 -9
  151. zrb/builtin/shell/autocomplete/_group.py +0 -6
  152. zrb/runner/web_app/any_request_handler.py +0 -24
  153. zrb/runner/web_server.bak.py +0 -208
  154. zrb-1.0.0a1.dist-info/RECORD +0 -120
  155. {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/WHEEL +0 -0
  156. {zrb-1.0.0a1.dist-info → zrb-1.0.0a3.dist-info}/entry_points.txt +0 -0
@@ -1,9 +1,8 @@
1
1
  import datetime
2
- import json
3
2
  import os
4
3
 
5
- from ..session_state_log.session_state_log import SessionStateLog, SessionStateLogList
6
- from .any_session_state_logger import AnySessionStateLogger
4
+ from zrb.session_state_log.session_state_log import SessionStateLog, SessionStateLogList
5
+ from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
7
6
 
8
7
 
9
8
  class FileSessionStateLogger(AnySessionStateLogger):
@@ -12,21 +11,24 @@ class FileSessionStateLogger(AnySessionStateLogger):
12
11
  self._session_log_dir = session_log_dir
13
12
 
14
13
  def write(self, session_log: SessionStateLog):
15
- session_file_path = self._get_session_file_path(session_log["name"])
14
+ session_file_path = self._get_session_file_path(session_log.name)
15
+ session_dir_path = os.path.dirname(session_file_path)
16
+ if not os.path.isdir(session_dir_path):
17
+ os.makedirs(session_dir_path, exist_ok=True)
16
18
  with open(session_file_path, "w") as f:
17
- f.write(json.dumps(session_log))
18
- start_time = self._get_start_time(session_log)
19
- if start_time is None:
19
+ f.write(session_log.model_dump_json())
20
+ start_time = session_log.start_time
21
+ if start_time == "":
20
22
  return
21
- timeline_dir_path = self._get_timeline_dir_path(session_log, start_time)
23
+ timeline_dir_path = self._get_timeline_dir_path(session_log)
22
24
  os.makedirs(timeline_dir_path, exist_ok=True)
23
- with open(os.path.join(timeline_dir_path, session_log["name"]), "w"):
25
+ with open(os.path.join(timeline_dir_path, session_log.name), "w"):
24
26
  pass
25
27
 
26
28
  def read(self, session_name: str) -> SessionStateLog:
27
29
  session_file_path = self._get_session_file_path(session_name)
28
30
  with open(session_file_path, "r") as f:
29
- return json.loads(f.read())
31
+ return SessionStateLog.model_validate_json(f.read())
30
32
 
31
33
  def list(
32
34
  self,
@@ -59,21 +61,20 @@ class FileSessionStateLogger(AnySessionStateLogger):
59
61
  paginated_sessions = matching_sessions[start_index:end_index]
60
62
  # Extract session logs from the sorted list of tuples
61
63
  data = [session_log for _, session_log in paginated_sessions]
62
- return {"total": total, "data": data}
64
+ return SessionStateLogList(total=total, data=data)
63
65
 
64
66
  def _get_session_file_path(self, session_name: str) -> str:
65
67
  return os.path.join(self._session_log_dir, f"{session_name}.json")
66
68
 
67
- def _get_timeline_dir_path(
68
- self, session_log: SessionStateLog, start_time: datetime.datetime
69
- ) -> str:
69
+ def _get_timeline_dir_path(self, session_log: SessionStateLog) -> str:
70
+ start_time = self._get_start_time(session_log)
70
71
  year = start_time.year
71
72
  month = start_time.month
72
73
  day = start_time.day
73
74
  hour = start_time.hour
74
75
  minute = start_time.minute
75
76
  second = start_time.second
76
- paths = session_log["path"] + [
77
+ paths = session_log.path + [
77
78
  f"{year}",
78
79
  f"{month}",
79
80
  f"{day}",
@@ -84,15 +85,6 @@ class FileSessionStateLogger(AnySessionStateLogger):
84
85
  return os.path.join(self._session_log_dir, "_timeline", *paths)
85
86
 
86
87
  def _get_start_time(self, session_log: SessionStateLog) -> datetime.datetime:
87
- result: datetime.datetime | None = None
88
- for task_status in session_log["task_status"].values():
89
- histories = task_status["history"]
90
- if len(histories) == 0:
91
- continue
92
- first_history = histories[0]
93
- first_time = datetime.datetime.strptime(
94
- first_history["time"], "%Y-%m-%d %H:%M:%S.%f"
95
- )
96
- if result is None or first_time < result:
97
- result = first_time
98
- return result
88
+ return datetime.datetime.strptime(
89
+ session_log.start_time, "%Y-%m-%d %H:%M:%S.%f"
90
+ )
zrb/task/any_task.py CHANGED
@@ -3,11 +3,12 @@ from __future__ import annotations # Enables forward references
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
- from ..env.any_env import AnyEnv
7
- from ..input.any_input import AnyInput
6
+ from zrb.env.any_env import AnyEnv
7
+ from zrb.input.any_input import AnyInput
8
8
 
9
9
  if TYPE_CHECKING:
10
- from ..session import session
10
+ from zrb.context import any_context
11
+ from zrb.session import session
11
12
 
12
13
 
13
14
  class AnyTask(ABC):
@@ -89,6 +90,10 @@ class AnyTask(ABC):
89
90
  """
90
91
  pass
91
92
 
93
@abstractmethod
def get_ctx(self, session: session.AnySession) -> any_context.AnyContext:
    """Return the context this task uses while running inside ``session``.

    Abstract: concrete task classes must provide the implementation.
    """
96
+
92
97
  @abstractmethod
93
98
  def run(
94
99
  self, session: session.AnySession | None = None, str_kwargs: dict[str, str] = {}
zrb/task/base_task.py CHANGED
@@ -3,17 +3,17 @@ import os
3
3
  from collections.abc import Callable
4
4
  from typing import Any
5
5
 
6
- from ..attr.type import BoolAttr, fstring
7
- from ..context.any_context import AnyContext
8
- from ..context.shared_context import AnySharedContext, SharedContext
9
- from ..env.any_env import AnyEnv
10
- from ..input.any_input import AnyInput
11
- from ..session.any_session import AnySession
12
- from ..session.session import Session
13
- from ..util.attr import get_bool_attr
14
- from ..util.run import run_async
15
- from ..xcom.xcom import Xcom
16
- from .any_task import AnyTask
6
+ from zrb.attr.type import BoolAttr, fstring
7
+ from zrb.context.any_context import AnyContext
8
+ from zrb.context.shared_context import AnySharedContext, SharedContext
9
+ from zrb.env.any_env import AnyEnv
10
+ from zrb.input.any_input import AnyInput
11
+ from zrb.session.any_session import AnySession
12
+ from zrb.session.session import Session
13
+ from zrb.task.any_task import AnyTask
14
+ from zrb.util.attr import get_bool_attr
15
+ from zrb.util.run import run_async
16
+ from zrb.xcom.xcom import Xcom
17
17
 
18
18
 
19
19
  class BaseTask(AnyTask):
@@ -24,8 +24,8 @@ class BaseTask(AnyTask):
24
24
  icon: str | None = None,
25
25
  description: str | None = None,
26
26
  cli_only: bool = False,
27
- input: list[AnyInput] | AnyInput | None = None,
28
- env: list[AnyEnv] | AnyEnv | None = None,
27
+ input: list[AnyInput | None] | AnyInput | None = None,
28
+ env: list[AnyEnv | None] | AnyEnv | None = None,
29
29
  action: fstring | Callable[[AnyContext], Any] | None = None,
30
30
  execute_condition: BoolAttr = True,
31
31
  retries: int = 2,
@@ -63,16 +63,22 @@ class BaseTask(AnyTask):
63
63
  return f"<{self.__class__.__name__} name={self._name}>"
64
64
 
65
65
  def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
66
- if isinstance(other, AnyTask):
67
- other.append_upstreams(self)
68
- elif isinstance(other, list):
69
- for task in other:
70
- task.append_upstreams(self)
71
- return other
66
+ try:
67
+ if isinstance(other, AnyTask):
68
+ other.append_upstreams(self)
69
+ elif isinstance(other, list):
70
+ for task in other:
71
+ task.append_upstreams(self)
72
+ return other
73
+ except Exception as e:
74
+ raise ValueError(f"Invalid operation {self} >> {other}: {e}")
72
75
 
73
76
  def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
74
- self.append_upstreams(other)
75
- return self
77
+ try:
78
+ self.append_upstreams(other)
79
+ return self
80
+ except Exception as e:
81
+ raise ValueError(f"Invalid operation {self} << {other}: {e}")
76
82
 
77
83
  @property
78
84
  def name(self) -> str:
@@ -103,7 +109,7 @@ class BaseTask(AnyTask):
103
109
  envs.append(self._envs)
104
110
  elif self._envs is not None:
105
111
  envs += self._envs
106
- return envs
112
+ return [env for env in envs if env is not None]
107
113
 
108
114
  @property
109
115
  def inputs(self) -> list[AnyInput]:
@@ -112,7 +118,7 @@ class BaseTask(AnyTask):
112
118
  self.__combine_inputs(inputs, upstream.inputs)
113
119
  if self._inputs is not None:
114
120
  self.__combine_inputs(inputs, self._inputs)
115
- return inputs
121
+ return [task_input for task_input in inputs if task_input is not None]
116
122
 
117
123
  def __combine_inputs(
118
124
  self, inputs: list[AnyInput], other_inputs: list[AnyInput] | AnyInput
@@ -120,7 +126,11 @@ class BaseTask(AnyTask):
120
126
  input_names = [task_input.name for task_input in inputs]
121
127
  if isinstance(other_inputs, AnyInput):
122
128
  other_inputs = [other_inputs]
129
+ elif other_inputs is None:
130
+ other_inputs = []
123
131
  for task_input in other_inputs:
132
+ if task_input is None:
133
+ continue
124
134
  if task_input.name not in input_names:
125
135
  inputs.append(task_input)
126
136
 
@@ -163,6 +173,13 @@ class BaseTask(AnyTask):
163
173
  if upstream not in self._upstreams:
164
174
  self._upstreams.append(upstream)
165
175
 
176
def get_ctx(self, session: AnySession) -> AnyContext:
    """Return ``session``'s context for this task, enriched with its envs.

    Every env in ``self.envs`` is applied to the context through
    ``update_context`` before it is handed back.
    """
    task_ctx = session.get_ctx(self)
    task_envs = self.envs
    for task_env in task_envs:
        task_env.update_context(task_ctx)
    return task_ctx
182
+
166
183
  def run(
167
184
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
168
185
  ) -> Any:
@@ -194,9 +211,6 @@ class BaseTask(AnyTask):
194
211
  if key not in shared_context._env
195
212
  }
196
213
  shared_context._env.update(os_env_map)
197
- # Inject environment from task's envs
198
- for env in self.envs:
199
- env.update_shared_context(shared_context)
200
214
 
201
215
  async def exec_root_tasks(self, session: AnySession):
202
216
  session.set_main_task(self)
@@ -219,12 +233,12 @@ class BaseTask(AnyTask):
219
233
  except IndexError:
220
234
  return None
221
235
  except asyncio.CancelledError:
222
- ctx = session.get_ctx(self)
236
+ ctx = self.get_ctx(session)
223
237
  ctx.log_info("Session terminated")
224
238
  finally:
225
239
  session.terminate()
226
240
  session.state_logger.write(session.as_state_log())
227
- ctx = session.get_ctx(self)
241
+ ctx = self.get_ctx(session)
228
242
  ctx.log_debug(session)
229
243
 
230
244
  async def _log_session_state(self, session: AnySession):
@@ -247,7 +261,7 @@ class BaseTask(AnyTask):
247
261
  return await asyncio.gather(*next_coros)
248
262
 
249
263
  async def exec(self, session: AnySession):
250
- ctx = session.get_ctx(self)
264
+ ctx = self.get_ctx(session)
251
265
  if not session.is_allowed_to_run(self):
252
266
  # Task is not allowed to run, skip it for now.
253
267
  # This will be triggered later
@@ -262,11 +276,11 @@ class BaseTask(AnyTask):
262
276
  await run_async(self.__exec_action_until_ready(session))
263
277
 
264
278
  def __get_execute_condition(self, session: Session) -> bool:
265
- ctx = session.get_ctx(self)
279
+ ctx = self.get_ctx(session)
266
280
  return get_bool_attr(ctx, self._execute_condition, True, auto_render=True)
267
281
 
268
282
  async def __exec_action_until_ready(self, session: AnySession):
269
- ctx = session.get_ctx(self)
283
+ ctx = self.get_ctx(session)
270
284
  readiness_checks = self.readiness_checks
271
285
  if len(readiness_checks) == 0:
272
286
  ctx.log_info("No readiness checks")
@@ -301,7 +315,7 @@ class BaseTask(AnyTask):
301
315
  async def __exec_monitoring(self, session: AnySession, action_coro: asyncio.Task):
302
316
  readiness_checks = self.readiness_checks
303
317
  failure_count = 0
304
- ctx = session.get_ctx(self)
318
+ ctx = self.get_ctx(session)
305
319
  while not session.is_terminated:
306
320
  await asyncio.sleep(self._readiness_check_period)
307
321
  if failure_count < self._readiness_failure_threshold:
@@ -343,7 +357,7 @@ class BaseTask(AnyTask):
343
357
  ctx.log_info("Continue monitoring")
344
358
 
345
359
  async def __exec_action_and_retry(self, session: AnySession) -> Any:
346
- ctx = session.get_ctx(self)
360
+ ctx = self.get_ctx(session)
347
361
  max_attempt = self._retries + 1
348
362
  ctx.set_max_attempt(max_attempt)
349
363
  for attempt in range(max_attempt):
zrb/task/base_trigger.py CHANGED
@@ -2,21 +2,20 @@ import asyncio
2
2
  from collections.abc import Callable
3
3
  from typing import Any
4
4
 
5
+ from zrb.attr.type import fstring
6
+ from zrb.callback.any_callback import AnyCallback
5
7
  from zrb.context.any_context import AnyContext
6
8
  from zrb.context.any_shared_context import AnySharedContext
9
+ from zrb.context.shared_context import SharedContext
10
+ from zrb.dot_dict.dot_dict import DotDict
7
11
  from zrb.env.any_env import AnyEnv
8
12
  from zrb.input.any_input import AnyInput
13
+ from zrb.session.any_session import AnySession
14
+ from zrb.session.session import Session
9
15
  from zrb.task.any_task import AnyTask
10
-
11
- from ..attr.type import fstring
12
- from ..callback.any_callback import AnyCallback
13
- from ..context.shared_context import SharedContext
14
- from ..dot_dict.dot_dict import DotDict
15
- from ..session.any_session import AnySession
16
- from ..session.session import Session
17
- from ..util.cli.style import CYAN
18
- from ..xcom.xcom import Xcom
19
- from .base_task import BaseTask
16
+ from zrb.task.base_task import BaseTask
17
+ from zrb.util.cli.style import CYAN
18
+ from zrb.xcom.xcom import Xcom
20
19
 
21
20
 
22
21
  class BaseTrigger(BaseTask):
@@ -28,8 +27,8 @@ class BaseTrigger(BaseTask):
28
27
  icon: str | None = None,
29
28
  description: str | None = None,
30
29
  cli_only: bool = False,
31
- input: list[AnyInput] | AnyInput | None = None,
32
- env: list[AnyEnv] | AnyEnv | None = None,
30
+ input: list[AnyInput | None] | AnyInput | None = None,
31
+ env: list[AnyEnv | None] | AnyEnv | None = None,
33
32
  action: fstring | Callable[[AnyContext], Any] | None = None,
34
33
  execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
35
34
  queue_name: fstring | None = None,
zrb/task/cmd_task.py CHANGED
@@ -2,17 +2,17 @@ import asyncio
2
2
  import os
3
3
  import sys
4
4
 
5
- from ..attr.type import BoolAttr, IntAttr, StrAttr
6
- from ..cmd.cmd_result import CmdResult
7
- from ..cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
8
- from ..config import DEFAULT_SHELL
9
- from ..context.any_context import AnyContext
10
- from ..env.any_env import AnyEnv
11
- from ..input.any_input import AnyInput
12
- from ..util.attr import get_int_attr, get_str_attr
13
- from ..util.cmd.remote import get_remote_cmd_script
14
- from .any_task import AnyTask
15
- from .base_task import BaseTask
5
+ from zrb.attr.type import BoolAttr, IntAttr, StrAttr
6
+ from zrb.cmd.cmd_result import CmdResult
7
+ from zrb.cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
8
+ from zrb.config import DEFAULT_SHELL
9
+ from zrb.context.any_context import AnyContext
10
+ from zrb.env.any_env import AnyEnv
11
+ from zrb.input.any_input import AnyInput
12
+ from zrb.task.any_task import AnyTask
13
+ from zrb.task.base_task import BaseTask
14
+ from zrb.util.attr import get_int_attr, get_str_attr
15
+ from zrb.util.cmd.remote import get_remote_cmd_script
16
16
 
17
17
 
18
18
  class CmdTask(BaseTask):
@@ -23,8 +23,8 @@ class CmdTask(BaseTask):
23
23
  icon: str | None = None,
24
24
  description: str | None = None,
25
25
  cli_only: bool = False,
26
- input: list[AnyInput] | AnyInput | None = None,
27
- env: list[AnyEnv] | AnyEnv | None = None,
26
+ input: list[AnyInput | None] | AnyInput | None = None,
27
+ env: list[AnyEnv | None] | AnyEnv | None = None,
28
28
  shell: StrAttr | None = None,
29
29
  auto_render_shell: bool = True,
30
30
  shell_flag: StrAttr | None = None,
@@ -105,44 +105,56 @@ class CmdTask(BaseTask):
105
105
  Returns:
106
106
  Any: The result of the action execution.
107
107
  """
108
+ ctx.log_info("Running script")
108
109
  cmd_script = self._get_cmd_script(ctx)
109
- cwd = self._get_cwd(ctx)
110
+ ctx.log_debug(f"Script: {self.__get_multiline_repr(cmd_script)}")
110
111
  shell = self._get_shell(ctx)
111
- shell_flag = self._get_shell_flag(ctx)
112
- ctx.log_info("Running script")
113
112
  ctx.log_debug(f"Shell: {shell}")
114
- ctx.log_debug(f"Script: {self.__get_multiline_repr(cmd_script)}")
113
+ shell_flag = self._get_shell_flag(ctx)
114
+ cwd = self._get_cwd(ctx)
115
115
  ctx.log_debug(f"Working directory: {cwd}")
116
- cmd_process = await asyncio.create_subprocess_exec(
117
- shell,
118
- shell_flag,
119
- cmd_script,
120
- cwd=cwd,
121
- stdin=sys.stdin if sys.stdin.isatty() else None,
122
- stdout=asyncio.subprocess.PIPE,
123
- stderr=asyncio.subprocess.PIPE,
124
- env=self.__get_env_map(ctx),
125
- bufsize=0,
126
- )
127
- stdout_task = asyncio.create_task(
128
- self.__read_stream(cmd_process.stdout, ctx.print, self._max_output_line)
129
- )
130
- stderr_task = asyncio.create_task(
131
- self.__read_stream(cmd_process.stderr, ctx.print, self._max_error_line)
132
- )
133
- # Wait for process to complete and gather stdout/stderr
134
- return_code = await cmd_process.wait()
135
- stdout = await stdout_task
136
- stderr = await stderr_task
137
- # Check for errors
138
- if return_code != 0:
139
- ctx.log_error(f"Exit status: {return_code}")
140
- raise Exception(f"Process {self._name} exited ({return_code}): {stderr}")
141
- return CmdResult(stdout, stderr)
116
+ env_map = self.__get_env_map(ctx)
117
+ ctx.log_debug(f"Environment map: {env_map}")
118
+ cmd_process = None
119
+ try:
120
+ cmd_process = await asyncio.create_subprocess_exec(
121
+ shell,
122
+ shell_flag,
123
+ cmd_script,
124
+ cwd=cwd,
125
+ stdin=sys.stdin if sys.stdin.isatty() else None,
126
+ stdout=asyncio.subprocess.PIPE,
127
+ stderr=asyncio.subprocess.PIPE,
128
+ env=env_map,
129
+ bufsize=0,
130
+ )
131
+ stdout_task = asyncio.create_task(
132
+ self.__read_stream(cmd_process.stdout, ctx.print, self._max_output_line)
133
+ )
134
+ stderr_task = asyncio.create_task(
135
+ self.__read_stream(cmd_process.stderr, ctx.print, self._max_error_line)
136
+ )
137
+ # Wait for process to complete and gather stdout/stderr
138
+ return_code = await cmd_process.wait()
139
+ stdout = await stdout_task
140
+ stderr = await stderr_task
141
+ # Check for errors
142
+ if return_code != 0:
143
+ ctx.log_error(f"Exit status: {return_code}")
144
+ raise Exception(
145
+ f"Process {self._name} exited ({return_code}): {stderr}"
146
+ )
147
+ ctx.log_info(f"Exit status: {return_code}")
148
+ return CmdResult(stdout, stderr)
149
+ finally:
150
+ if cmd_process is not None and cmd_process.returncode is None:
151
+ cmd_process.terminate()
142
152
 
143
153
  def __get_env_map(self, ctx: AnyContext) -> dict[str, str]:
144
154
  envs = {key: val for key, val in ctx.env.items()}
145
155
  envs["_ZRB_SSH_PASSWORD"] = self._get_remote_password(ctx)
156
+ envs["PYTHONBUFFERED"] = "1"
157
+ return envs
146
158
 
147
159
  async def __read_stream(self, stream, log_method, max_lines):
148
160
  lines = []
zrb/task/http_check.py CHANGED
@@ -3,14 +3,14 @@ from collections.abc import Callable
3
3
 
4
4
  import requests
5
5
 
6
- from ..attr.type import StrAttr
7
- from ..context.any_context import AnyContext
8
- from ..context.context import Context
9
- from ..env.any_env import AnyEnv
10
- from ..input.any_input import AnyInput
11
- from ..util.attr import get_str_attr
12
- from .any_task import AnyTask
13
- from .base_task import BaseTask
6
+ from zrb.attr.type import StrAttr
7
+ from zrb.context.any_context import AnyContext
8
+ from zrb.context.context import Context
9
+ from zrb.env.any_env import AnyEnv
10
+ from zrb.input.any_input import AnyInput
11
+ from zrb.task.any_task import AnyTask
12
+ from zrb.task.base_task import BaseTask
13
+ from zrb.util.attr import get_str_attr
14
14
 
15
15
 
16
16
  class HttpCheck(BaseTask):
zrb/task/llm_task.py ADDED
@@ -0,0 +1,160 @@
1
+ import json
2
+ from collections.abc import Callable
3
+ from typing import Any
4
+
5
+ from zrb.attr.type import StrAttr
6
+ from zrb.config import LLM_MODEL, LLM_SYSTEM_PROMPT
7
+ from zrb.context.any_context import AnyContext
8
+ from zrb.context.any_shared_context import AnySharedContext
9
+ from zrb.env.any_env import AnyEnv
10
+ from zrb.input.any_input import AnyInput
11
+ from zrb.task.any_task import AnyTask
12
+ from zrb.task.base_task import BaseTask
13
+ from zrb.util.attr import get_str_attr
14
+ from zrb.util.llm.tool import callable_to_tool_schema
15
+
16
# Alias for a list of dicts; used below for chat history and tool schemas.
DictList = list[dict[str, Any]]


# Built-in tool that LLMTask always adds to the available tools; it simply
# echoes the model's note back. NOTE(review): the docstring is presumably
# read by callable_to_tool_schema as the tool description, so treat it as
# runtime text rather than documentation — confirm before editing it.
def scratchpad(thought: str) -> str:
    """Use this tool to note your thought and planning"""
    return thought
+ return thought
22
+
23
+
24
class LLMTask(BaseTask):
    """Task that converses with an LLM via litellm and runs requested tools.

    The prompt is built from ``system_prompt``, then ``history``, then the
    rendered ``message``. While the model replies with tool calls, the
    matching callables from ``tools`` (plus the built-in ``scratchpad``) are
    invoked and their results are appended to the conversation. The content
    of the first reply without tool calls is returned.

    ``tools``, ``tool_schema`` and ``history`` may be given directly or as
    callables receiving the context. A tool with no entry in ``tool_schema``
    (matched on ``schema["function"]["name"]``) gets a schema generated by
    ``callable_to_tool_schema``.
    """

    def __init__(
        self,
        name: str,
        color: int | None = None,
        icon: str | None = None,
        description: str | None = None,
        cli_only: bool = False,
        input: list[AnyInput | None] | AnyInput | None = None,
        env: list[AnyEnv | None] | AnyEnv | None = None,
        model: StrAttr | None = LLM_MODEL,
        system_prompt: StrAttr | None = LLM_SYSTEM_PROMPT,
        message: StrAttr | None = None,
        tools: (
            dict[str, Callable] | Callable[[AnySharedContext], dict[str, Callable]]
        ) = {},
        tool_schema: DictList | Callable[[AnySharedContext], DictList] = [],
        history: DictList | Callable[[AnySharedContext], DictList] = [],
        execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
        retries: int = 2,
        retry_period: float = 0,
        readiness_check: list[AnyTask] | AnyTask | None = None,
        readiness_check_delay: float = 0.5,
        readiness_check_period: float = 5,
        readiness_failure_threshold: int = 1,
        readiness_timeout: int = 60,
        monitor_readiness: bool = False,
        upstream: list[AnyTask] | AnyTask | None = None,
        fallback: list[AnyTask] | AnyTask | None = None,
    ):
        # NOTE: the {} / [] defaults above are shared by every instance. They
        # are safe only because _exec_action copies them before modifying.
        super().__init__(
            name=name,
            color=color,
            icon=icon,
            description=description,
            cli_only=cli_only,
            input=input,
            env=env,
            execute_condition=execute_condition,
            retries=retries,
            retry_period=retry_period,
            readiness_check=readiness_check,
            readiness_check_delay=readiness_check_delay,
            readiness_check_period=readiness_check_period,
            readiness_failure_threshold=readiness_failure_threshold,
            readiness_timeout=readiness_timeout,
            monitor_readiness=monitor_readiness,
            upstream=upstream,
            fallback=fallback,
        )
        self._model = model
        self._system_prompt = system_prompt
        self._message = message
        self._tools = tools
        self._tool_schema = tool_schema
        self._history = history

    async def _exec_action(self, ctx: AnyContext) -> Any:
        """Run the chat / tool-call loop and return the final reply content."""
        # Imported lazily, presumably so litellm is only required when an
        # LLMTask actually runs.
        from litellm import acompletion

        system_prompt = self._get_system_prompt(ctx)
        ctx.log_debug("SYSTEM PROMPT", system_prompt)
        history = self._get_history(ctx)
        ctx.log_debug("HISTORY PROMPT", history)
        user_message = self._get_message(ctx)
        ctx.log_debug("USER MESSAGE", user_message)
        messages = (
            [{"role": "system", "content": system_prompt}]
            + history
            + [{"role": "user", "content": user_message}]
        )
        # Copy before adding entries. The configured dict/list may be the
        # shared mutable default of __init__ or a caller-owned object, and
        # mutating it would leak tools and schemas into other tasks and runs.
        available_tools = dict(self._get_tools(ctx))
        available_tools["scratchpad"] = scratchpad
        tool_schema = list(self._get_tool_schema(ctx))
        for tool_name, tool in available_tools.items():
            has_schema = any(
                "function" in schema
                and "name" in schema["function"]
                and schema["function"]["name"] == tool_name
                for schema in tool_schema
            )
            if not has_schema:
                tool_schema.append(callable_to_tool_schema(tool))
        ctx.log_debug("TOOL SCHEMA", tool_schema)
        model = self._get_model(ctx)
        while True:
            # acompletion keeps the event loop free during the HTTP round
            # trip; the synchronous completion() would block every other
            # coroutine (e.g. readiness monitoring) until the model answers.
            response = await acompletion(
                model=model, messages=messages, tools=tool_schema
            )
            response_message = response.choices[0].message
            ctx.print(response_message)
            messages.append(response_message)
            tool_calls = response_message.tool_calls
            if not tool_calls:
                return response_message.content
            # noqa Reference: https://docs.litellm.ai/docs/completion/function_call#full-code---parallel-function-calling-with-gpt-35-turbo-1106
            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_tools[function_name]
                function_kwargs = json.loads(tool_call.function.arguments)
                function_response = function_to_call(**function_kwargs)
                tool_call_message = {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": self._stringify_tool_result(function_response),
                }
                ctx.print(tool_call_message)
                messages.append(tool_call_message)

    @staticmethod
    def _stringify_tool_result(result: Any) -> str:
        """Convert a tool's return value into string message content.

        Strings pass through unchanged. Other values are JSON-encoded when
        possible and fall back to ``str()`` otherwise.
        """
        if isinstance(result, str):
            return result
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            return str(result)

    def _get_model(self, ctx: AnyContext) -> str:
        """Render the model name, defaulting to ``ollama_chat/llama3.1``."""
        return get_str_attr(ctx, self._model, "ollama_chat/llama3.1", auto_render=True)

    def _get_system_prompt(self, ctx: AnyContext) -> str:
        """Render the system prompt, with a generic assistant default."""
        return get_str_attr(
            ctx, self._system_prompt, "You are a helpful assistant", auto_render=True
        )

    def _get_message(self, ctx: AnyContext) -> str:
        """Render the user message, defaulting to ``"How are you?"``."""
        return get_str_attr(ctx, self._message, "How are you?", auto_render=True)

    def _get_tools(self, ctx: AnyContext) -> dict[str, Callable]:
        """Return the tool mapping, calling the factory if one was given."""
        if callable(self._tools):
            return self._tools(ctx)
        return self._tools

    def _get_tool_schema(self, ctx: AnyContext) -> DictList:
        """Return the explicit tool schemas, calling the factory if given."""
        if callable(self._tool_schema):
            return self._tool_schema(ctx)
        return self._tool_schema

    def _get_history(self, ctx: AnyContext) -> DictList:
        """Return prior chat messages, calling the factory if one was given."""
        if callable(self._history):
            return self._history(ctx)
        return self._history