zrb 1.15.3__py3-none-any.whl → 1.21.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zrb might be problematic; see the registry's advisory page for more details.

Files changed (108)
  1. zrb/__init__.py +2 -6
  2. zrb/attr/type.py +10 -7
  3. zrb/builtin/__init__.py +2 -0
  4. zrb/builtin/git.py +12 -1
  5. zrb/builtin/group.py +31 -15
  6. zrb/builtin/llm/attachment.py +40 -0
  7. zrb/builtin/llm/chat_completion.py +274 -0
  8. zrb/builtin/llm/chat_session.py +126 -167
  9. zrb/builtin/llm/chat_session_cmd.py +288 -0
  10. zrb/builtin/llm/chat_trigger.py +79 -0
  11. zrb/builtin/llm/history.py +4 -4
  12. zrb/builtin/llm/llm_ask.py +217 -135
  13. zrb/builtin/llm/tool/api.py +74 -70
  14. zrb/builtin/llm/tool/cli.py +35 -21
  15. zrb/builtin/llm/tool/code.py +55 -73
  16. zrb/builtin/llm/tool/file.py +278 -344
  17. zrb/builtin/llm/tool/note.py +84 -0
  18. zrb/builtin/llm/tool/rag.py +27 -34
  19. zrb/builtin/llm/tool/sub_agent.py +54 -41
  20. zrb/builtin/llm/tool/web.py +74 -98
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  23. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  24. zrb/builtin/searxng/config/settings.yml +5671 -0
  25. zrb/builtin/searxng/start.py +21 -0
  26. zrb/builtin/shell/autocomplete/bash.py +4 -3
  27. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  28. zrb/config/config.py +202 -27
  29. zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
  30. zrb/config/default_prompt/interactive_system_prompt.md +24 -30
  31. zrb/config/default_prompt/persona.md +1 -1
  32. zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
  33. zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
  34. zrb/config/default_prompt/summarization_prompt.md +57 -16
  35. zrb/config/default_prompt/system_prompt.md +36 -30
  36. zrb/config/llm_config.py +119 -23
  37. zrb/config/llm_context/config.py +127 -90
  38. zrb/config/llm_context/config_parser.py +1 -7
  39. zrb/config/llm_context/workflow.py +81 -0
  40. zrb/config/llm_rate_limitter.py +100 -47
  41. zrb/context/any_shared_context.py +7 -1
  42. zrb/context/context.py +8 -2
  43. zrb/context/shared_context.py +3 -7
  44. zrb/group/any_group.py +3 -3
  45. zrb/group/group.py +3 -3
  46. zrb/input/any_input.py +5 -1
  47. zrb/input/base_input.py +18 -6
  48. zrb/input/option_input.py +13 -1
  49. zrb/input/text_input.py +7 -24
  50. zrb/runner/cli.py +21 -20
  51. zrb/runner/common_util.py +24 -19
  52. zrb/runner/web_route/task_input_api_route.py +5 -5
  53. zrb/runner/web_util/user.py +7 -3
  54. zrb/session/any_session.py +12 -6
  55. zrb/session/session.py +39 -18
  56. zrb/task/any_task.py +24 -3
  57. zrb/task/base/context.py +17 -9
  58. zrb/task/base/execution.py +15 -8
  59. zrb/task/base/lifecycle.py +8 -4
  60. zrb/task/base/monitoring.py +12 -7
  61. zrb/task/base_task.py +69 -5
  62. zrb/task/base_trigger.py +12 -5
  63. zrb/task/llm/agent.py +128 -167
  64. zrb/task/llm/agent_runner.py +152 -0
  65. zrb/task/llm/config.py +39 -20
  66. zrb/task/llm/conversation_history.py +110 -29
  67. zrb/task/llm/conversation_history_model.py +4 -179
  68. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  69. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  70. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  71. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  72. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  73. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  74. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  75. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  76. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  77. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  78. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  79. zrb/task/llm/file_replacement.py +206 -0
  80. zrb/task/llm/file_tool_model.py +57 -0
  81. zrb/task/llm/history_processor.py +206 -0
  82. zrb/task/llm/history_summarization.py +2 -193
  83. zrb/task/llm/print_node.py +184 -64
  84. zrb/task/llm/prompt.py +175 -179
  85. zrb/task/llm/subagent_conversation_history.py +41 -0
  86. zrb/task/llm/tool_wrapper.py +226 -85
  87. zrb/task/llm/workflow.py +76 -0
  88. zrb/task/llm_task.py +109 -71
  89. zrb/task/make_task.py +2 -3
  90. zrb/task/rsync_task.py +25 -10
  91. zrb/task/scheduler.py +4 -4
  92. zrb/util/attr.py +54 -39
  93. zrb/util/cli/markdown.py +12 -0
  94. zrb/util/cli/text.py +30 -0
  95. zrb/util/file.py +12 -3
  96. zrb/util/git.py +2 -2
  97. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  98. zrb/util/string/conversion.py +1 -1
  99. zrb/util/truncate.py +23 -0
  100. zrb/util/yaml.py +204 -0
  101. zrb/xcom/xcom.py +10 -0
  102. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
  103. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
  104. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
  105. zrb/task/llm/default_workflow/coding.md +0 -24
  106. zrb/task/llm/default_workflow/copywriting.md +0 -17
  107. zrb/task/llm/default_workflow/researching.md +0 -18
  108. {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
zrb/group/group.py CHANGED
@@ -33,15 +33,15 @@ class Group(AnyGroup):
33
33
  def subgroups(self) -> dict[str, AnyGroup]:
34
34
  names = list(self._groups.keys())
35
35
  names.sort()
36
- return {name: self._groups.get(name) for name in names}
36
+ return {name: self._groups[name] for name in names}
37
37
 
38
38
  @property
39
39
  def subtasks(self) -> dict[str, AnyTask]:
40
40
  alias = list(self._tasks.keys())
41
41
  alias.sort()
42
- return {name: self._tasks.get(name) for name in alias}
42
+ return {name: self._tasks[name] for name in alias}
43
43
 
44
- def add_group(self, group: AnyGroup | str, alias: str | None = None) -> AnyGroup:
44
+ def add_group(self, group: AnyGroup, alias: str | None = None) -> AnyGroup:
45
45
  real_group = Group(group) if isinstance(group, str) else group
46
46
  alias = alias if alias is not None else real_group.name
47
47
  self._groups[alias] = real_group
zrb/input/any_input.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
+ from typing import Any
2
3
 
3
4
  from zrb.context.any_shared_context import AnySharedContext
4
5
 
@@ -35,7 +36,10 @@ class AnyInput(ABC):
35
36
 
36
37
  @abstractmethod
37
38
  def update_shared_context(
38
- self, shared_ctx: AnySharedContext, str_value: str | None = None
39
+ self,
40
+ shared_ctx: AnySharedContext,
41
+ str_value: str | None = None,
42
+ value: Any = None,
39
43
  ):
40
44
  pass
41
45
 
zrb/input/base_input.py CHANGED
@@ -58,11 +58,15 @@ class BaseInput(AnyInput):
58
58
  return f'<input name="{name}" placeholder="{description}" value="{default}" />'
59
59
 
60
60
  def update_shared_context(
61
- self, shared_ctx: AnySharedContext, str_value: str | None = None
61
+ self,
62
+ shared_ctx: AnySharedContext,
63
+ str_value: str | None = None,
64
+ value: Any = None,
62
65
  ):
63
- if str_value is None:
64
- str_value = self.get_default_str(shared_ctx)
65
- value = self._parse_str_value(str_value)
66
+ if value is None:
67
+ if str_value is None:
68
+ str_value = self.get_default_str(shared_ctx)
69
+ value = self._parse_str_value(str_value)
66
70
  if self.name in shared_ctx.input:
67
71
  raise ValueError(f"Input already defined in the context: {self.name}")
68
72
  shared_ctx.input[self.name] = value
@@ -91,12 +95,20 @@ class BaseInput(AnyInput):
91
95
  default_str = self.get_default_str(shared_ctx)
92
96
  if default_str != "":
93
97
  prompt_message = f"{prompt_message} [{default_str}]"
94
- print(f"{prompt_message}: ", end="")
95
- value = input()
98
+ value = self._read_line(shared_ctx, prompt_message)
96
99
  if value.strip() == "":
97
100
  value = default_str
98
101
  return value
99
102
 
103
+ def _read_line(self, shared_ctx: AnySharedContext, prompt_message: str) -> str:
104
+ if not shared_ctx.is_tty:
105
+ print(f"{prompt_message}: ", end="")
106
+ return input()
107
+ from prompt_toolkit import PromptSession
108
+
109
+ reader = PromptSession()
110
+ return reader.prompt(f"{prompt_message}: ")
111
+
100
112
  def get_default_str(self, shared_ctx: AnySharedContext) -> str:
101
113
  """Get default value as str"""
102
114
  default_value = get_attr(
zrb/input/option_input.py CHANGED
@@ -47,9 +47,21 @@ class OptionInput(BaseInput):
47
47
  option_str = ", ".join(options)
48
48
  if default_value != "":
49
49
  prompt_message = f"{prompt_message} ({option_str}) [{default_value}]"
50
- value = input(f"{prompt_message}: ")
50
+ value = self._get_value_from_user_input(shared_ctx, prompt_message, options)
51
51
  if value.strip() != "" and value.strip() not in options:
52
52
  value = self._prompt_cli_str(shared_ctx)
53
53
  if value.strip() == "":
54
54
  value = default_value
55
55
  return value
56
+
57
+ def _get_value_from_user_input(
58
+ self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
59
+ ) -> str:
60
+ from prompt_toolkit import PromptSession
61
+ from prompt_toolkit.completion import WordCompleter
62
+
63
+ if shared_ctx.is_tty:
64
+ reader = PromptSession()
65
+ option_completer = WordCompleter(options, ignore_case=True)
66
+ return reader.prompt(f"{prompt_message}: ", completer=option_completer)
67
+ return input(f"{prompt_message}: ")
zrb/input/text_input.py CHANGED
@@ -1,12 +1,9 @@
1
- import os
2
- import subprocess
3
- import tempfile
4
1
  from collections.abc import Callable
5
2
 
6
3
  from zrb.config.config import CFG
7
4
  from zrb.context.any_shared_context import AnySharedContext
8
5
  from zrb.input.base_input import BaseInput
9
- from zrb.util.file import read_file
6
+ from zrb.util.cli.text import edit_text
10
7
 
11
8
 
12
9
  class TextInput(BaseInput):
@@ -85,24 +82,10 @@ class TextInput(BaseInput):
85
82
  comment_prompt_message = (
86
83
  f"{self.comment_start}{prompt_message}{self.comment_end}"
87
84
  )
88
- comment_prompt_message_eol = f"{comment_prompt_message}\n"
89
85
  default_value = self.get_default_str(shared_ctx)
90
- with tempfile.NamedTemporaryFile(
91
- delete=False, suffix=self._extension
92
- ) as temp_file:
93
- temp_file_name = temp_file.name
94
- temp_file.write(comment_prompt_message_eol.encode())
95
- # Pre-fill with default content
96
- if default_value:
97
- temp_file.write(default_value.encode())
98
- temp_file.flush()
99
- subprocess.call([self.editor_cmd, temp_file_name])
100
- # Read the edited content
101
- edited_content = read_file(temp_file_name)
102
- parts = [
103
- text.strip() for text in edited_content.split(comment_prompt_message, 1)
104
- ]
105
- edited_content = "\n".join(parts).lstrip()
106
- os.remove(temp_file_name)
107
- print(f"{prompt_message}: {edited_content}")
108
- return edited_content
86
+ return edit_text(
87
+ prompt_message=comment_prompt_message,
88
+ value=default_value,
89
+ editor=self.editor_cmd,
90
+ extension=self._extension,
91
+ )
zrb/runner/cli.py CHANGED
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
7
7
  from zrb.context.shared_context import SharedContext
8
8
  from zrb.group.any_group import AnyGroup
9
9
  from zrb.group.group import Group
10
- from zrb.runner.common_util import get_run_kwargs
10
+ from zrb.runner.common_util import get_task_str_kwargs
11
11
  from zrb.session.session import Session
12
12
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
13
13
  from zrb.task.any_task import AnyTask
@@ -38,23 +38,25 @@ class Cli(Group):
38
38
  def banner(self) -> str:
39
39
  return CFG.BANNER
40
40
 
41
- def run(self, args: list[str] = []):
42
- kwargs, args = self._extract_kwargs_from_args(args)
43
- node, node_path, args = extract_node_from_args(self, args)
41
+ def run(self, str_args: list[str] = []):
42
+ str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
43
+ node, node_path, str_args = extract_node_from_args(self, str_args)
44
44
  if isinstance(node, AnyGroup):
45
45
  self._show_group_info(node)
46
46
  return
47
- if "h" in kwargs or "help" in kwargs:
47
+ if "h" in str_kwargs or "help" in str_kwargs:
48
48
  self._show_task_info(node)
49
49
  return
50
- run_kwargs = get_run_kwargs(task=node, args=args, kwargs=kwargs, cli_mode=True)
50
+ task_str_kwargs = get_task_str_kwargs(
51
+ task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
52
+ )
51
53
  try:
52
- result = self._run_task(node, args, run_kwargs)
54
+ result = self._run_task(node, str_args, task_str_kwargs)
53
55
  if result is not None:
54
56
  print(result)
55
57
  return result
56
58
  finally:
57
- run_command = self._get_run_command(node_path, run_kwargs)
59
+ run_command = self._get_run_command(node_path, task_str_kwargs)
58
60
  self._print_run_command(run_command)
59
61
 
60
62
  def _print_run_command(self, run_command: str):
@@ -64,11 +66,14 @@ class Cli(Group):
64
66
  file=sys.stderr,
65
67
  )
66
68
 
67
- def _get_run_command(self, node_path: list[str], run_kwargs: dict[str, str]) -> str:
69
+ def _get_run_command(
70
+ self, node_path: list[str], task_str_kwargs: dict[str, str]
71
+ ) -> str:
68
72
  parts = [self.name] + node_path
69
- if len(run_kwargs) > 0:
73
+ if len(task_str_kwargs) > 0:
70
74
  parts += [
71
- self._get_run_command_param(key, val) for key, val in run_kwargs.items()
75
+ self._get_run_command_param(key, val)
76
+ for key, val in task_str_kwargs.items()
72
77
  ]
73
78
  return " ".join(parts)
74
79
 
@@ -81,13 +86,9 @@ class Cli(Group):
81
86
  self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
82
87
  ) -> tuple[Any]:
83
88
  shared_ctx = SharedContext(args=args)
84
- for task_input in task.inputs:
85
- if task_input.name in run_kwargs:
86
- task_input.update_shared_context(
87
- shared_ctx, run_kwargs[task_input.name]
88
- )
89
- continue
90
- return task.run(Session(shared_ctx=shared_ctx, root_group=self))
89
+ return task.run(
90
+ Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
91
+ )
91
92
 
92
93
  def _show_task_info(self, task: AnyTask):
93
94
  description = task.description
@@ -150,11 +151,11 @@ class Cli(Group):
150
151
  kwargs[key] = args[i + 1]
151
152
  i += 1 # Skip the next argument as it's a value
152
153
  else:
153
- kwargs[key] = True
154
+ kwargs[key] = "true"
154
155
  elif arg.startswith("-"):
155
156
  # Handle short flags like -t or -n
156
157
  key = arg[1:]
157
- kwargs[key] = True
158
+ kwargs[key] = "true"
158
159
  else:
159
160
  # Anything else is considered a positional argument
160
161
  residual_args.append(arg)
zrb/runner/common_util.py CHANGED
@@ -1,31 +1,36 @@
1
- from typing import Any
2
-
3
1
  from zrb.context.shared_context import SharedContext
4
2
  from zrb.task.any_task import AnyTask
5
3
 
6
4
 
7
- def get_run_kwargs(
8
- task: AnyTask, args: list[str], kwargs: dict[str, str], cli_mode: bool
5
+ def get_task_str_kwargs(
6
+ task: AnyTask, str_args: list[str], str_kwargs: dict[str, str], cli_mode: bool
9
7
  ) -> dict[str, str]:
10
8
  arg_index = 0
11
- str_kwargs = {key: f"{val}" for key, val in kwargs.items()}
12
- run_kwargs = {**str_kwargs}
13
- shared_ctx = SharedContext(args=args)
9
+ dummmy_shared_ctx = SharedContext()
10
+ task_str_kwargs = {}
14
11
  for task_input in task.inputs:
12
+ task_name = task_input.name
15
13
  if task_input.name in str_kwargs:
16
- # Update shared context for next input default value
17
- task_input.update_shared_context(shared_ctx, str_kwargs[task_input.name])
18
- elif arg_index < len(args) and task_input.allow_positional_parsing:
19
- run_kwargs[task_input.name] = args[arg_index]
20
- # Update shared context for next input default value
21
- task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
14
+ task_str_kwargs[task_input.name] = str_kwargs[task_name]
15
+ # Update dummy shared context for next input default value
16
+ task_input.update_shared_context(
17
+ dummmy_shared_ctx, str_value=str_kwargs[task_name]
18
+ )
19
+ elif arg_index < len(str_args) and task_input.allow_positional_parsing:
20
+ task_str_kwargs[task_name] = str_args[arg_index]
21
+ # Update dummy shared context for next input default value
22
+ task_input.update_shared_context(
23
+ dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
24
+ )
22
25
  arg_index += 1
23
26
  else:
24
27
  if cli_mode and task_input.always_prompt:
25
- str_value = task_input.prompt_cli_str(shared_ctx)
28
+ str_value = task_input.prompt_cli_str(dummmy_shared_ctx)
26
29
  else:
27
- str_value = task_input.get_default_str(shared_ctx)
28
- run_kwargs[task_input.name] = str_value
29
- # Update shared context for next input default value
30
- task_input.update_shared_context(shared_ctx, run_kwargs[task_input.name])
31
- return run_kwargs
30
+ str_value = task_input.get_default_str(dummmy_shared_ctx)
31
+ task_str_kwargs[task_name] = str_value
32
+ # Update dummy shared context for next input default value
33
+ task_input.update_shared_context(
34
+ dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
35
+ )
36
+ return task_str_kwargs
zrb/runner/web_route/task_input_api_route.py CHANGED
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
3
3
 
4
4
  from zrb.config.web_auth_config import WebAuthConfig
5
5
  from zrb.group.any_group import AnyGroup
6
- from zrb.runner.common_util import get_run_kwargs
6
+ from zrb.runner.common_util import get_task_str_kwargs
7
7
  from zrb.runner.web_util.user import get_user_from_request
8
8
  from zrb.task.any_task import AnyTask
9
9
  from zrb.util.group import NodeNotFoundError, extract_node_from_args
@@ -39,9 +39,9 @@ def serve_task_input_api(
39
39
  if isinstance(task, AnyTask):
40
40
  if not user.can_access_task(task):
41
41
  return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
42
- query_dict = json.loads(query)
43
- run_kwargs = get_run_kwargs(
44
- task=task, args=[], kwargs=query_dict, cli_mode=False
42
+ str_kwargs = json.loads(query)
43
+ task_str_kwargs = get_task_str_kwargs(
44
+ task=task, str_args=[], str_kwargs=str_kwargs, cli_mode=False
45
45
  )
46
- return run_kwargs
46
+ return task_str_kwargs
47
47
  return JSONResponse(content={"detail": "Not found"}, status_code=404)
zrb/runner/web_util/user.py CHANGED
@@ -19,7 +19,7 @@ def get_user_by_credentials(
19
19
 
20
20
  async def get_user_from_request(
21
21
  web_auth_config: WebAuthConfig, request: "Request"
22
- ) -> User | None:
22
+ ) -> User:
23
23
  from fastapi.security import OAuth2PasswordBearer
24
24
 
25
25
  if not web_auth_config.enable_auth:
@@ -45,7 +45,11 @@ def _get_user_from_cookie(
45
45
  return None
46
46
 
47
47
 
48
- def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | None:
48
+ def _get_user_from_token(
49
+ web_auth_config: WebAuthConfig, token: str | None
50
+ ) -> User | None:
51
+ if token is None:
52
+ return None
49
53
  try:
50
54
  from jose import jwt
51
55
 
@@ -54,7 +58,7 @@ def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | N
54
58
  web_auth_config.secret_key,
55
59
  options={"require_sub": True, "require_exp": True},
56
60
  )
57
- username: str = payload.get("sub")
61
+ username: str | None = payload.get("sub")
58
62
  if username is None:
59
63
  return None
60
64
  user = web_auth_config.find_user_by_username(username)
zrb/session/any_session.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
+ import asyncio
3
4
  from abc import ABC, abstractmethod
4
5
  from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
5
6
 
@@ -62,12 +63,13 @@ class AnySession(ABC):
62
63
 
63
64
  @property
64
65
  @abstractmethod
65
- def parent(self) -> TAnySession | None:
66
+ def parent(self) -> "AnySession | None":
66
67
  """Parent session"""
67
68
  pass
68
69
 
70
+ @property
69
71
  @abstractmethod
70
- def task_path(self) -> str:
72
+ def task_path(self) -> list[str]:
71
73
  """Main task's path"""
72
74
  pass
73
75
 
@@ -105,7 +107,9 @@ class AnySession(ABC):
105
107
  pass
106
108
 
107
109
  @abstractmethod
108
- def defer_monitoring(self, task: "AnyTask", coro: Coroutine):
110
+ def defer_monitoring(
111
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
112
+ ):
109
113
  """Defers the execution of a task's monitoring coroutine for later processing.
110
114
 
111
115
  Args:
@@ -115,7 +119,9 @@ class AnySession(ABC):
115
119
  pass
116
120
 
117
121
  @abstractmethod
118
- def defer_action(self, task: "AnyTask", coro: Coroutine):
122
+ def defer_action(
123
+ self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
124
+ ):
119
125
  """Defers the execution of a task's coroutine for later processing.
120
126
 
121
127
  Args:
@@ -125,7 +131,7 @@ class AnySession(ABC):
125
131
  pass
126
132
 
127
133
  @abstractmethod
128
- def defer_coro(self, coro: Coroutine):
134
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
129
135
  """Defers the execution of a coroutine for later processing.
130
136
 
131
137
  Args:
@@ -185,7 +191,7 @@ class AnySession(ABC):
185
191
  pass
186
192
 
187
193
  @abstractmethod
188
- def is_allowed_to_run(self, task: "AnyTask"):
194
+ def is_allowed_to_run(self, task: "AnyTask") -> bool:
189
195
  """Determines if the specified task is allowed to run based on its current state.
190
196
 
191
197
  Args:
zrb/session/session.py CHANGED
@@ -1,10 +1,12 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  from typing import TYPE_CHECKING, Any, Coroutine
3
5
 
4
6
  from zrb.context.any_shared_context import AnySharedContext
5
7
  from zrb.context.context import AnyContext, Context
6
8
  from zrb.group.any_group import AnyGroup
7
- from zrb.session.any_session import AnySession
9
+ from zrb.session.any_session import AnySession, TAnySession
8
10
  from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
9
11
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
10
12
  from zrb.task.any_task import AnyTask
@@ -48,10 +50,10 @@ class Session(AnySession):
48
50
  self._context: dict[AnyTask, Context] = {}
49
51
  self._shared_ctx = shared_ctx
50
52
  self._shared_ctx.set_session(self)
51
- self._parent = parent
52
- self._action_coros: dict[AnyTask, asyncio.Task] = {}
53
- self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
54
- self._coros: list[asyncio.Task] = []
53
+ self._parent: AnySession | None = parent
54
+ self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
55
+ self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
56
+ self._coros: list[asyncio.Task[Any]] = []
55
57
  self._colors = [
56
58
  GREEN,
57
59
  YELLOW,
@@ -114,11 +116,13 @@ class Session(AnySession):
114
116
  return self._parent
115
117
 
116
118
  @property
117
- def task_path(self) -> str:
119
+ def task_path(self) -> list[str]:
118
120
  return self._main_task_path
119
121
 
120
122
  @property
121
123
  def final_result(self) -> Any:
124
+ if self._main_task is None:
125
+ return None
122
126
  xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
123
127
  try:
124
128
  return xcom.peek()
@@ -134,7 +138,11 @@ class Session(AnySession):
134
138
  def set_main_task(self, main_task: AnyTask):
135
139
  self.register_task(main_task)
136
140
  self._main_task = main_task
137
- main_task_path = get_node_path(self._root_group, main_task)
141
+ main_task_path = (
142
+ None
143
+ if self._root_group is None
144
+ else get_node_path(self._root_group, main_task)
145
+ )
138
146
  self._main_task_path = [] if main_task_path is None else main_task_path
139
147
 
140
148
  def as_state_log(self) -> "SessionStateLog":
@@ -171,7 +179,7 @@ class Session(AnySession):
171
179
  return SessionStateLog(
172
180
  name=self.name,
173
181
  start_time=log_start_time,
174
- main_task_name=self._main_task.name,
182
+ main_task_name="" if self._main_task is None else self._main_task.name,
175
183
  path=self.task_path,
176
184
  final_result=(
177
185
  remove_style(f"{self.final_result}")
@@ -188,16 +196,29 @@ class Session(AnySession):
188
196
  self._register_single_task(task)
189
197
  return self._context[task]
190
198
 
191
- def defer_monitoring(self, task: AnyTask, coro: Coroutine):
199
+ def defer_monitoring(
200
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
201
+ ):
192
202
  self._register_single_task(task)
193
- self._monitoring_coros[task] = coro
203
+ if isinstance(coro, asyncio.Task):
204
+ self._monitoring_coros[task] = coro
205
+ else:
206
+ self._monitoring_coros[task] = asyncio.create_task(coro)
194
207
 
195
- def defer_action(self, task: AnyTask, coro: Coroutine):
208
+ def defer_action(
209
+ self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
210
+ ):
196
211
  self._register_single_task(task)
197
- self._action_coros[task] = coro
212
+ if isinstance(coro, asyncio.Task):
213
+ self._action_coros[task] = coro
214
+ else:
215
+ self._action_coros[task] = asyncio.create_task(coro)
198
216
 
199
- def defer_coro(self, coro: Coroutine):
200
- self._coros.append(coro)
217
+ def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
218
+ if isinstance(coro, asyncio.Task):
219
+ self._coros.append(coro)
220
+ else:
221
+ self._coros.append(asyncio.create_task(coro))
201
222
  self._coros = [
202
223
  existing_coro for existing_coro in self._coros if not existing_coro.done()
203
224
  ]
@@ -246,15 +267,15 @@ class Session(AnySession):
246
267
 
247
268
  def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
248
269
  self._register_single_task(task)
249
- return self._downstreams.get(task)
270
+ return self._downstreams.get(task, [])
250
271
 
251
272
  def get_task_status(self, task: AnyTask) -> TaskStatus:
252
273
  self._register_single_task(task)
253
274
  return self._task_status[task]
254
275
 
255
276
  def _register_single_task(self, task: AnyTask):
256
- if task.name not in self._shared_ctx._xcom:
257
- self._shared_ctx._xcom[task.name] = Xcom([])
277
+ if task.name not in self._shared_ctx.xcom:
278
+ self._shared_ctx.xcom[task.name] = Xcom([])
258
279
  if task not in self._context:
259
280
  self._context[task] = Context(
260
281
  shared_ctx=self._shared_ctx,
@@ -278,7 +299,7 @@ class Session(AnySession):
278
299
  self._color_index = 0
279
300
  return chosen
280
301
 
281
- def _get_icon(self, task: AnyTask) -> int:
302
+ def _get_icon(self, task: AnyTask) -> str:
282
303
  if task.icon is not None:
283
304
  return task.icon
284
305
  chosen = self._icons[self._icon_index]
zrb/task/any_task.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations # Enables forward references
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any
4
+ from typing import TYPE_CHECKING, Any, Callable
5
5
 
6
6
  from zrb.env.any_env import AnyEnv
7
7
  from zrb.input.any_input import AnyInput
@@ -36,6 +36,14 @@ class AnyTask(ABC):
36
36
  the actual implementation for these abstract members.
37
37
  """
38
38
 
39
+ @abstractmethod
40
+ def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
41
+ pass
42
+
43
+ @abstractmethod
44
+ def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
45
+ pass
46
+
39
47
  @property
40
48
  @abstractmethod
41
49
  def name(self) -> str:
@@ -148,13 +156,17 @@ class AnyTask(ABC):
148
156
 
149
157
  @abstractmethod
150
158
  def run(
151
- self, session: "AnySession | None" = None, str_kwargs: dict[str, str] = {}
159
+ self,
160
+ session: "AnySession | None" = None,
161
+ str_kwargs: dict[str, str] | None = None,
162
+ kwargs: dict[str, Any] | None = None,
152
163
  ) -> Any:
153
164
  """Runs the task synchronously.
154
165
 
155
166
  Args:
156
167
  session (AnySession): The shared session.
157
168
  str_kwargs(dict[str, str]): The input string values.
169
+ kwargs(dict[str, Any]): The input values.
158
170
 
159
171
  Returns:
160
172
  Any: The result of the task execution.
@@ -163,13 +175,17 @@ class AnyTask(ABC):
163
175
 
164
176
  @abstractmethod
165
177
  async def async_run(
166
- self, session: "AnySession | None" = None, str_kwargs: dict[str, str] = {}
178
+ self,
179
+ session: "AnySession | None" = None,
180
+ str_kwargs: dict[str, str] | None = None,
181
+ kwargs: dict[str, Any] | None = None,
167
182
  ) -> Any:
168
183
  """Runs the task asynchronously.
169
184
 
170
185
  Args:
171
186
  session (AnySession): The shared session.
172
187
  str_kwargs(dict[str, str]): The input string values.
188
+ kwargs(dict[str, Any]): The input values.
173
189
 
174
190
  Returns:
175
191
  Any: The result of the task execution.
@@ -203,3 +219,8 @@ class AnyTask(ABC):
203
219
  session (AnySession): The shared session.
204
220
  """
205
221
  pass
222
+
223
+ @abstractmethod
224
+ def to_function(self) -> Callable[..., Any]:
225
+ """Turn a task into a function"""
226
+ pass
zrb/task/base/context.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import TYPE_CHECKING
2
+ from typing import TYPE_CHECKING, Any
3
3
 
4
4
  from zrb.context.any_context import AnyContext
5
5
  from zrb.context.any_shared_context import AnySharedContext
@@ -26,25 +26,33 @@ def build_task_context(task: AnyTask, session: AnySession) -> AnyContext:
26
26
 
27
27
 
28
28
  def fill_shared_context_inputs(
29
- task: AnyTask, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
29
+ shared_ctx: AnySharedContext,
30
+ task: AnyTask,
31
+ str_kwargs: dict[str, str] | None = None,
32
+ kwargs: dict[str, Any] | None = None,
30
33
  ):
31
34
  """
32
- Populates the shared context with input values provided via kwargs.
35
+ Populates the shared context with input values provided via str_kwargs.
33
36
  """
37
+ str_kwarg_dict = str_kwargs if str_kwargs is not None else {}
38
+ kwarg_dict = kwargs if kwargs is not None else {}
34
39
  for task_input in task.inputs:
35
- if task_input.name not in shared_context.input:
36
- str_value = str_kwargs.get(task_input.name, None)
37
- task_input.update_shared_context(shared_context, str_value)
40
+ if task_input.name not in shared_ctx.input:
41
+ task_input.update_shared_context(
42
+ shared_ctx,
43
+ value=kwarg_dict.get(task_input.name, None),
44
+ str_value=str_kwarg_dict.get(task_input.name, None),
45
+ )
38
46
 
39
47
 
40
- def fill_shared_context_envs(shared_context: AnySharedContext):
48
+ def fill_shared_context_envs(shared_ctx: AnySharedContext):
41
49
  """
42
50
  Injects OS environment variables into the shared context if they don't already exist.
43
51
  """
44
52
  os_env_map = {
45
- key: val for key, val in os.environ.items() if key not in shared_context.env
53
+ key: val for key, val in os.environ.items() if key not in shared_ctx.env
46
54
  }
47
- shared_context.env.update(os_env_map)
55
+ shared_ctx.env.update(os_env_map)
48
56
 
49
57
 
50
58
  def combine_inputs(