zrb 1.15.3__py3-none-any.whl → 1.21.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic. Click here for more details.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_completion.py +274 -0
- zrb/builtin/llm/chat_session.py +126 -167
- zrb/builtin/llm/chat_session_cmd.py +288 -0
- zrb/builtin/llm/chat_trigger.py +79 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +217 -135
- zrb/builtin/llm/tool/api.py +74 -70
- zrb/builtin/llm/tool/cli.py +35 -21
- zrb/builtin/llm/tool/code.py +55 -73
- zrb/builtin/llm/tool/file.py +278 -344
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +27 -34
- zrb/builtin/llm/tool/sub_agent.py +54 -41
- zrb/builtin/llm/tool/web.py +74 -98
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +202 -27
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +57 -16
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +119 -23
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +100 -47
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +3 -7
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +13 -1
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +128 -167
- zrb/task/llm/agent_runner.py +152 -0
- zrb/task/llm/config.py +39 -20
- zrb/task/llm/conversation_history.py +110 -29
- zrb/task/llm/conversation_history_model.py +4 -179
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_processor.py +206 -0
- zrb/task/llm/history_summarization.py +2 -193
- zrb/task/llm/print_node.py +184 -64
- zrb/task/llm/prompt.py +175 -179
- zrb/task/llm/subagent_conversation_history.py +41 -0
- zrb/task/llm/tool_wrapper.py +226 -85
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +109 -71
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +12 -3
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/METADATA +38 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/RECORD +105 -79
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.15.3.dist-info → zrb-1.21.29.dist-info}/entry_points.txt +0 -0
zrb/group/group.py
CHANGED
|
@@ -33,15 +33,15 @@ class Group(AnyGroup):
|
|
|
33
33
|
def subgroups(self) -> dict[str, AnyGroup]:
|
|
34
34
|
names = list(self._groups.keys())
|
|
35
35
|
names.sort()
|
|
36
|
-
return {name: self._groups
|
|
36
|
+
return {name: self._groups[name] for name in names}
|
|
37
37
|
|
|
38
38
|
@property
|
|
39
39
|
def subtasks(self) -> dict[str, AnyTask]:
|
|
40
40
|
alias = list(self._tasks.keys())
|
|
41
41
|
alias.sort()
|
|
42
|
-
return {name: self._tasks
|
|
42
|
+
return {name: self._tasks[name] for name in alias}
|
|
43
43
|
|
|
44
|
-
def add_group(self, group: AnyGroup
|
|
44
|
+
def add_group(self, group: AnyGroup, alias: str | None = None) -> AnyGroup:
|
|
45
45
|
real_group = Group(group) if isinstance(group, str) else group
|
|
46
46
|
alias = alias if alias is not None else real_group.name
|
|
47
47
|
self._groups[alias] = real_group
|
zrb/input/any_input.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
from zrb.context.any_shared_context import AnySharedContext
|
|
4
5
|
|
|
@@ -35,7 +36,10 @@ class AnyInput(ABC):
|
|
|
35
36
|
|
|
36
37
|
@abstractmethod
|
|
37
38
|
def update_shared_context(
|
|
38
|
-
self,
|
|
39
|
+
self,
|
|
40
|
+
shared_ctx: AnySharedContext,
|
|
41
|
+
str_value: str | None = None,
|
|
42
|
+
value: Any = None,
|
|
39
43
|
):
|
|
40
44
|
pass
|
|
41
45
|
|
zrb/input/base_input.py
CHANGED
|
@@ -58,11 +58,15 @@ class BaseInput(AnyInput):
|
|
|
58
58
|
return f'<input name="{name}" placeholder="{description}" value="{default}" />'
|
|
59
59
|
|
|
60
60
|
def update_shared_context(
|
|
61
|
-
self,
|
|
61
|
+
self,
|
|
62
|
+
shared_ctx: AnySharedContext,
|
|
63
|
+
str_value: str | None = None,
|
|
64
|
+
value: Any = None,
|
|
62
65
|
):
|
|
63
|
-
if
|
|
64
|
-
str_value
|
|
65
|
-
|
|
66
|
+
if value is None:
|
|
67
|
+
if str_value is None:
|
|
68
|
+
str_value = self.get_default_str(shared_ctx)
|
|
69
|
+
value = self._parse_str_value(str_value)
|
|
66
70
|
if self.name in shared_ctx.input:
|
|
67
71
|
raise ValueError(f"Input already defined in the context: {self.name}")
|
|
68
72
|
shared_ctx.input[self.name] = value
|
|
@@ -91,12 +95,20 @@ class BaseInput(AnyInput):
|
|
|
91
95
|
default_str = self.get_default_str(shared_ctx)
|
|
92
96
|
if default_str != "":
|
|
93
97
|
prompt_message = f"{prompt_message} [{default_str}]"
|
|
94
|
-
|
|
95
|
-
value = input()
|
|
98
|
+
value = self._read_line(shared_ctx, prompt_message)
|
|
96
99
|
if value.strip() == "":
|
|
97
100
|
value = default_str
|
|
98
101
|
return value
|
|
99
102
|
|
|
103
|
+
def _read_line(self, shared_ctx: AnySharedContext, prompt_message: str) -> str:
|
|
104
|
+
if not shared_ctx.is_tty:
|
|
105
|
+
print(f"{prompt_message}: ", end="")
|
|
106
|
+
return input()
|
|
107
|
+
from prompt_toolkit import PromptSession
|
|
108
|
+
|
|
109
|
+
reader = PromptSession()
|
|
110
|
+
return reader.prompt(f"{prompt_message}: ")
|
|
111
|
+
|
|
100
112
|
def get_default_str(self, shared_ctx: AnySharedContext) -> str:
|
|
101
113
|
"""Get default value as str"""
|
|
102
114
|
default_value = get_attr(
|
zrb/input/option_input.py
CHANGED
|
@@ -47,9 +47,21 @@ class OptionInput(BaseInput):
|
|
|
47
47
|
option_str = ", ".join(options)
|
|
48
48
|
if default_value != "":
|
|
49
49
|
prompt_message = f"{prompt_message} ({option_str}) [{default_value}]"
|
|
50
|
-
value =
|
|
50
|
+
value = self._get_value_from_user_input(shared_ctx, prompt_message, options)
|
|
51
51
|
if value.strip() != "" and value.strip() not in options:
|
|
52
52
|
value = self._prompt_cli_str(shared_ctx)
|
|
53
53
|
if value.strip() == "":
|
|
54
54
|
value = default_value
|
|
55
55
|
return value
|
|
56
|
+
|
|
57
|
+
def _get_value_from_user_input(
|
|
58
|
+
self, shared_ctx: AnySharedContext, prompt_message: str, options: list[str]
|
|
59
|
+
) -> str:
|
|
60
|
+
from prompt_toolkit import PromptSession
|
|
61
|
+
from prompt_toolkit.completion import WordCompleter
|
|
62
|
+
|
|
63
|
+
if shared_ctx.is_tty:
|
|
64
|
+
reader = PromptSession()
|
|
65
|
+
option_completer = WordCompleter(options, ignore_case=True)
|
|
66
|
+
return reader.prompt(f"{prompt_message}: ", completer=option_completer)
|
|
67
|
+
return input(f"{prompt_message}: ")
|
zrb/input/text_input.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import subprocess
|
|
3
|
-
import tempfile
|
|
4
1
|
from collections.abc import Callable
|
|
5
2
|
|
|
6
3
|
from zrb.config.config import CFG
|
|
7
4
|
from zrb.context.any_shared_context import AnySharedContext
|
|
8
5
|
from zrb.input.base_input import BaseInput
|
|
9
|
-
from zrb.util.
|
|
6
|
+
from zrb.util.cli.text import edit_text
|
|
10
7
|
|
|
11
8
|
|
|
12
9
|
class TextInput(BaseInput):
|
|
@@ -85,24 +82,10 @@ class TextInput(BaseInput):
|
|
|
85
82
|
comment_prompt_message = (
|
|
86
83
|
f"{self.comment_start}{prompt_message}{self.comment_end}"
|
|
87
84
|
)
|
|
88
|
-
comment_prompt_message_eol = f"{comment_prompt_message}\n"
|
|
89
85
|
default_value = self.get_default_str(shared_ctx)
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
if default_value:
|
|
97
|
-
temp_file.write(default_value.encode())
|
|
98
|
-
temp_file.flush()
|
|
99
|
-
subprocess.call([self.editor_cmd, temp_file_name])
|
|
100
|
-
# Read the edited content
|
|
101
|
-
edited_content = read_file(temp_file_name)
|
|
102
|
-
parts = [
|
|
103
|
-
text.strip() for text in edited_content.split(comment_prompt_message, 1)
|
|
104
|
-
]
|
|
105
|
-
edited_content = "\n".join(parts).lstrip()
|
|
106
|
-
os.remove(temp_file_name)
|
|
107
|
-
print(f"{prompt_message}: {edited_content}")
|
|
108
|
-
return edited_content
|
|
86
|
+
return edit_text(
|
|
87
|
+
prompt_message=comment_prompt_message,
|
|
88
|
+
value=default_value,
|
|
89
|
+
editor=self.editor_cmd,
|
|
90
|
+
extension=self._extension,
|
|
91
|
+
)
|
zrb/runner/cli.py
CHANGED
|
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
|
|
|
7
7
|
from zrb.context.shared_context import SharedContext
|
|
8
8
|
from zrb.group.any_group import AnyGroup
|
|
9
9
|
from zrb.group.group import Group
|
|
10
|
-
from zrb.runner.common_util import
|
|
10
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
11
11
|
from zrb.session.session import Session
|
|
12
12
|
from zrb.session_state_logger.session_state_logger_factory import session_state_logger
|
|
13
13
|
from zrb.task.any_task import AnyTask
|
|
@@ -38,23 +38,25 @@ class Cli(Group):
|
|
|
38
38
|
def banner(self) -> str:
|
|
39
39
|
return CFG.BANNER
|
|
40
40
|
|
|
41
|
-
def run(self,
|
|
42
|
-
|
|
43
|
-
node, node_path,
|
|
41
|
+
def run(self, str_args: list[str] = []):
|
|
42
|
+
str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
|
|
43
|
+
node, node_path, str_args = extract_node_from_args(self, str_args)
|
|
44
44
|
if isinstance(node, AnyGroup):
|
|
45
45
|
self._show_group_info(node)
|
|
46
46
|
return
|
|
47
|
-
if "h" in
|
|
47
|
+
if "h" in str_kwargs or "help" in str_kwargs:
|
|
48
48
|
self._show_task_info(node)
|
|
49
49
|
return
|
|
50
|
-
|
|
50
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
51
|
+
task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
|
|
52
|
+
)
|
|
51
53
|
try:
|
|
52
|
-
result = self._run_task(node,
|
|
54
|
+
result = self._run_task(node, str_args, task_str_kwargs)
|
|
53
55
|
if result is not None:
|
|
54
56
|
print(result)
|
|
55
57
|
return result
|
|
56
58
|
finally:
|
|
57
|
-
run_command = self._get_run_command(node_path,
|
|
59
|
+
run_command = self._get_run_command(node_path, task_str_kwargs)
|
|
58
60
|
self._print_run_command(run_command)
|
|
59
61
|
|
|
60
62
|
def _print_run_command(self, run_command: str):
|
|
@@ -64,11 +66,14 @@ class Cli(Group):
|
|
|
64
66
|
file=sys.stderr,
|
|
65
67
|
)
|
|
66
68
|
|
|
67
|
-
def _get_run_command(
|
|
69
|
+
def _get_run_command(
|
|
70
|
+
self, node_path: list[str], task_str_kwargs: dict[str, str]
|
|
71
|
+
) -> str:
|
|
68
72
|
parts = [self.name] + node_path
|
|
69
|
-
if len(
|
|
73
|
+
if len(task_str_kwargs) > 0:
|
|
70
74
|
parts += [
|
|
71
|
-
self._get_run_command_param(key, val)
|
|
75
|
+
self._get_run_command_param(key, val)
|
|
76
|
+
for key, val in task_str_kwargs.items()
|
|
72
77
|
]
|
|
73
78
|
return " ".join(parts)
|
|
74
79
|
|
|
@@ -81,13 +86,9 @@ class Cli(Group):
|
|
|
81
86
|
self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
|
|
82
87
|
) -> tuple[Any]:
|
|
83
88
|
shared_ctx = SharedContext(args=args)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
shared_ctx, run_kwargs[task_input.name]
|
|
88
|
-
)
|
|
89
|
-
continue
|
|
90
|
-
return task.run(Session(shared_ctx=shared_ctx, root_group=self))
|
|
89
|
+
return task.run(
|
|
90
|
+
Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
|
|
91
|
+
)
|
|
91
92
|
|
|
92
93
|
def _show_task_info(self, task: AnyTask):
|
|
93
94
|
description = task.description
|
|
@@ -150,11 +151,11 @@ class Cli(Group):
|
|
|
150
151
|
kwargs[key] = args[i + 1]
|
|
151
152
|
i += 1 # Skip the next argument as it's a value
|
|
152
153
|
else:
|
|
153
|
-
kwargs[key] =
|
|
154
|
+
kwargs[key] = "true"
|
|
154
155
|
elif arg.startswith("-"):
|
|
155
156
|
# Handle short flags like -t or -n
|
|
156
157
|
key = arg[1:]
|
|
157
|
-
kwargs[key] =
|
|
158
|
+
kwargs[key] = "true"
|
|
158
159
|
else:
|
|
159
160
|
# Anything else is considered a positional argument
|
|
160
161
|
residual_args.append(arg)
|
zrb/runner/common_util.py
CHANGED
|
@@ -1,31 +1,36 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
1
|
from zrb.context.shared_context import SharedContext
|
|
4
2
|
from zrb.task.any_task import AnyTask
|
|
5
3
|
|
|
6
4
|
|
|
7
|
-
def
|
|
8
|
-
task: AnyTask,
|
|
5
|
+
def get_task_str_kwargs(
|
|
6
|
+
task: AnyTask, str_args: list[str], str_kwargs: dict[str, str], cli_mode: bool
|
|
9
7
|
) -> dict[str, str]:
|
|
10
8
|
arg_index = 0
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
shared_ctx = SharedContext(args=args)
|
|
9
|
+
dummmy_shared_ctx = SharedContext()
|
|
10
|
+
task_str_kwargs = {}
|
|
14
11
|
for task_input in task.inputs:
|
|
12
|
+
task_name = task_input.name
|
|
15
13
|
if task_input.name in str_kwargs:
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
14
|
+
task_str_kwargs[task_input.name] = str_kwargs[task_name]
|
|
15
|
+
# Update dummy shared context for next input default value
|
|
16
|
+
task_input.update_shared_context(
|
|
17
|
+
dummmy_shared_ctx, str_value=str_kwargs[task_name]
|
|
18
|
+
)
|
|
19
|
+
elif arg_index < len(str_args) and task_input.allow_positional_parsing:
|
|
20
|
+
task_str_kwargs[task_name] = str_args[arg_index]
|
|
21
|
+
# Update dummy shared context for next input default value
|
|
22
|
+
task_input.update_shared_context(
|
|
23
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
24
|
+
)
|
|
22
25
|
arg_index += 1
|
|
23
26
|
else:
|
|
24
27
|
if cli_mode and task_input.always_prompt:
|
|
25
|
-
str_value = task_input.prompt_cli_str(
|
|
28
|
+
str_value = task_input.prompt_cli_str(dummmy_shared_ctx)
|
|
26
29
|
else:
|
|
27
|
-
str_value = task_input.get_default_str(
|
|
28
|
-
|
|
29
|
-
# Update shared context for next input default value
|
|
30
|
-
task_input.update_shared_context(
|
|
31
|
-
|
|
30
|
+
str_value = task_input.get_default_str(dummmy_shared_ctx)
|
|
31
|
+
task_str_kwargs[task_name] = str_value
|
|
32
|
+
# Update dummy shared context for next input default value
|
|
33
|
+
task_input.update_shared_context(
|
|
34
|
+
dummmy_shared_ctx, str_value=task_str_kwargs[task_name]
|
|
35
|
+
)
|
|
36
|
+
return task_str_kwargs
|
|
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
|
|
3
3
|
|
|
4
4
|
from zrb.config.web_auth_config import WebAuthConfig
|
|
5
5
|
from zrb.group.any_group import AnyGroup
|
|
6
|
-
from zrb.runner.common_util import
|
|
6
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
7
7
|
from zrb.runner.web_util.user import get_user_from_request
|
|
8
8
|
from zrb.task.any_task import AnyTask
|
|
9
9
|
from zrb.util.group import NodeNotFoundError, extract_node_from_args
|
|
@@ -39,9 +39,9 @@ def serve_task_input_api(
|
|
|
39
39
|
if isinstance(task, AnyTask):
|
|
40
40
|
if not user.can_access_task(task):
|
|
41
41
|
return JSONResponse(content={"detail": "Forbidden"}, status_code=403)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
task=task,
|
|
42
|
+
str_kwargs = json.loads(query)
|
|
43
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
44
|
+
task=task, str_args=[], str_kwargs=str_kwargs, cli_mode=False
|
|
45
45
|
)
|
|
46
|
-
return
|
|
46
|
+
return task_str_kwargs
|
|
47
47
|
return JSONResponse(content={"detail": "Not found"}, status_code=404)
|
zrb/runner/web_util/user.py
CHANGED
|
@@ -19,7 +19,7 @@ def get_user_by_credentials(
|
|
|
19
19
|
|
|
20
20
|
async def get_user_from_request(
|
|
21
21
|
web_auth_config: WebAuthConfig, request: "Request"
|
|
22
|
-
) -> User
|
|
22
|
+
) -> User:
|
|
23
23
|
from fastapi.security import OAuth2PasswordBearer
|
|
24
24
|
|
|
25
25
|
if not web_auth_config.enable_auth:
|
|
@@ -45,7 +45,11 @@ def _get_user_from_cookie(
|
|
|
45
45
|
return None
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
def _get_user_from_token(
|
|
48
|
+
def _get_user_from_token(
|
|
49
|
+
web_auth_config: WebAuthConfig, token: str | None
|
|
50
|
+
) -> User | None:
|
|
51
|
+
if token is None:
|
|
52
|
+
return None
|
|
49
53
|
try:
|
|
50
54
|
from jose import jwt
|
|
51
55
|
|
|
@@ -54,7 +58,7 @@ def _get_user_from_token(web_auth_config: WebAuthConfig, token: str) -> User | N
|
|
|
54
58
|
web_auth_config.secret_key,
|
|
55
59
|
options={"require_sub": True, "require_exp": True},
|
|
56
60
|
)
|
|
57
|
-
username: str = payload.get("sub")
|
|
61
|
+
username: str | None = payload.get("sub")
|
|
58
62
|
if username is None:
|
|
59
63
|
return None
|
|
60
64
|
user = web_auth_config.find_user_by_username(username)
|
zrb/session/any_session.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
from abc import ABC, abstractmethod
|
|
4
5
|
from typing import TYPE_CHECKING, Any, Coroutine, TypeVar
|
|
5
6
|
|
|
@@ -62,12 +63,13 @@ class AnySession(ABC):
|
|
|
62
63
|
|
|
63
64
|
@property
|
|
64
65
|
@abstractmethod
|
|
65
|
-
def parent(self) ->
|
|
66
|
+
def parent(self) -> "AnySession | None":
|
|
66
67
|
"""Parent session"""
|
|
67
68
|
pass
|
|
68
69
|
|
|
70
|
+
@property
|
|
69
71
|
@abstractmethod
|
|
70
|
-
def task_path(self) -> str:
|
|
72
|
+
def task_path(self) -> list[str]:
|
|
71
73
|
"""Main task's path"""
|
|
72
74
|
pass
|
|
73
75
|
|
|
@@ -105,7 +107,9 @@ class AnySession(ABC):
|
|
|
105
107
|
pass
|
|
106
108
|
|
|
107
109
|
@abstractmethod
|
|
108
|
-
def defer_monitoring(
|
|
110
|
+
def defer_monitoring(
|
|
111
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
112
|
+
):
|
|
109
113
|
"""Defers the execution of a task's monitoring coroutine for later processing.
|
|
110
114
|
|
|
111
115
|
Args:
|
|
@@ -115,7 +119,9 @@ class AnySession(ABC):
|
|
|
115
119
|
pass
|
|
116
120
|
|
|
117
121
|
@abstractmethod
|
|
118
|
-
def defer_action(
|
|
122
|
+
def defer_action(
|
|
123
|
+
self, task: "AnyTask", coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
124
|
+
):
|
|
119
125
|
"""Defers the execution of a task's coroutine for later processing.
|
|
120
126
|
|
|
121
127
|
Args:
|
|
@@ -125,7 +131,7 @@ class AnySession(ABC):
|
|
|
125
131
|
pass
|
|
126
132
|
|
|
127
133
|
@abstractmethod
|
|
128
|
-
def defer_coro(self, coro: Coroutine):
|
|
134
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
129
135
|
"""Defers the execution of a coroutine for later processing.
|
|
130
136
|
|
|
131
137
|
Args:
|
|
@@ -185,7 +191,7 @@ class AnySession(ABC):
|
|
|
185
191
|
pass
|
|
186
192
|
|
|
187
193
|
@abstractmethod
|
|
188
|
-
def is_allowed_to_run(self, task: "AnyTask"):
|
|
194
|
+
def is_allowed_to_run(self, task: "AnyTask") -> bool:
|
|
189
195
|
"""Determines if the specified task is allowed to run based on its current state.
|
|
190
196
|
|
|
191
197
|
Args:
|
zrb/session/session.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import asyncio
|
|
2
4
|
from typing import TYPE_CHECKING, Any, Coroutine
|
|
3
5
|
|
|
4
6
|
from zrb.context.any_shared_context import AnySharedContext
|
|
5
7
|
from zrb.context.context import AnyContext, Context
|
|
6
8
|
from zrb.group.any_group import AnyGroup
|
|
7
|
-
from zrb.session.any_session import AnySession
|
|
9
|
+
from zrb.session.any_session import AnySession, TAnySession
|
|
8
10
|
from zrb.session_state_logger.any_session_state_logger import AnySessionStateLogger
|
|
9
11
|
from zrb.session_state_logger.session_state_logger_factory import session_state_logger
|
|
10
12
|
from zrb.task.any_task import AnyTask
|
|
@@ -48,10 +50,10 @@ class Session(AnySession):
|
|
|
48
50
|
self._context: dict[AnyTask, Context] = {}
|
|
49
51
|
self._shared_ctx = shared_ctx
|
|
50
52
|
self._shared_ctx.set_session(self)
|
|
51
|
-
self._parent = parent
|
|
52
|
-
self._action_coros: dict[AnyTask, asyncio.Task] = {}
|
|
53
|
-
self._monitoring_coros: dict[AnyTask, asyncio.Task] = {}
|
|
54
|
-
self._coros: list[asyncio.Task] = []
|
|
53
|
+
self._parent: AnySession | None = parent
|
|
54
|
+
self._action_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
55
|
+
self._monitoring_coros: dict[AnyTask, asyncio.Task[Any]] = {}
|
|
56
|
+
self._coros: list[asyncio.Task[Any]] = []
|
|
55
57
|
self._colors = [
|
|
56
58
|
GREEN,
|
|
57
59
|
YELLOW,
|
|
@@ -114,11 +116,13 @@ class Session(AnySession):
|
|
|
114
116
|
return self._parent
|
|
115
117
|
|
|
116
118
|
@property
|
|
117
|
-
def task_path(self) -> str:
|
|
119
|
+
def task_path(self) -> list[str]:
|
|
118
120
|
return self._main_task_path
|
|
119
121
|
|
|
120
122
|
@property
|
|
121
123
|
def final_result(self) -> Any:
|
|
124
|
+
if self._main_task is None:
|
|
125
|
+
return None
|
|
122
126
|
xcom: Xcom = self.shared_ctx.xcom[self._main_task.name]
|
|
123
127
|
try:
|
|
124
128
|
return xcom.peek()
|
|
@@ -134,7 +138,11 @@ class Session(AnySession):
|
|
|
134
138
|
def set_main_task(self, main_task: AnyTask):
|
|
135
139
|
self.register_task(main_task)
|
|
136
140
|
self._main_task = main_task
|
|
137
|
-
main_task_path =
|
|
141
|
+
main_task_path = (
|
|
142
|
+
None
|
|
143
|
+
if self._root_group is None
|
|
144
|
+
else get_node_path(self._root_group, main_task)
|
|
145
|
+
)
|
|
138
146
|
self._main_task_path = [] if main_task_path is None else main_task_path
|
|
139
147
|
|
|
140
148
|
def as_state_log(self) -> "SessionStateLog":
|
|
@@ -171,7 +179,7 @@ class Session(AnySession):
|
|
|
171
179
|
return SessionStateLog(
|
|
172
180
|
name=self.name,
|
|
173
181
|
start_time=log_start_time,
|
|
174
|
-
main_task_name=self._main_task.name,
|
|
182
|
+
main_task_name="" if self._main_task is None else self._main_task.name,
|
|
175
183
|
path=self.task_path,
|
|
176
184
|
final_result=(
|
|
177
185
|
remove_style(f"{self.final_result}")
|
|
@@ -188,16 +196,29 @@ class Session(AnySession):
|
|
|
188
196
|
self._register_single_task(task)
|
|
189
197
|
return self._context[task]
|
|
190
198
|
|
|
191
|
-
def defer_monitoring(
|
|
199
|
+
def defer_monitoring(
|
|
200
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
201
|
+
):
|
|
192
202
|
self._register_single_task(task)
|
|
193
|
-
|
|
203
|
+
if isinstance(coro, asyncio.Task):
|
|
204
|
+
self._monitoring_coros[task] = coro
|
|
205
|
+
else:
|
|
206
|
+
self._monitoring_coros[task] = asyncio.create_task(coro)
|
|
194
207
|
|
|
195
|
-
def defer_action(
|
|
208
|
+
def defer_action(
|
|
209
|
+
self, task: AnyTask, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]
|
|
210
|
+
):
|
|
196
211
|
self._register_single_task(task)
|
|
197
|
-
|
|
212
|
+
if isinstance(coro, asyncio.Task):
|
|
213
|
+
self._action_coros[task] = coro
|
|
214
|
+
else:
|
|
215
|
+
self._action_coros[task] = asyncio.create_task(coro)
|
|
198
216
|
|
|
199
|
-
def defer_coro(self, coro: Coroutine):
|
|
200
|
-
|
|
217
|
+
def defer_coro(self, coro: Coroutine[Any, Any, Any] | asyncio.Task[Any]):
|
|
218
|
+
if isinstance(coro, asyncio.Task):
|
|
219
|
+
self._coros.append(coro)
|
|
220
|
+
else:
|
|
221
|
+
self._coros.append(asyncio.create_task(coro))
|
|
201
222
|
self._coros = [
|
|
202
223
|
existing_coro for existing_coro in self._coros if not existing_coro.done()
|
|
203
224
|
]
|
|
@@ -246,15 +267,15 @@ class Session(AnySession):
|
|
|
246
267
|
|
|
247
268
|
def get_next_tasks(self, task: AnyTask) -> list[AnyTask]:
|
|
248
269
|
self._register_single_task(task)
|
|
249
|
-
return self._downstreams.get(task)
|
|
270
|
+
return self._downstreams.get(task, [])
|
|
250
271
|
|
|
251
272
|
def get_task_status(self, task: AnyTask) -> TaskStatus:
|
|
252
273
|
self._register_single_task(task)
|
|
253
274
|
return self._task_status[task]
|
|
254
275
|
|
|
255
276
|
def _register_single_task(self, task: AnyTask):
|
|
256
|
-
if task.name not in self._shared_ctx.
|
|
257
|
-
self._shared_ctx.
|
|
277
|
+
if task.name not in self._shared_ctx.xcom:
|
|
278
|
+
self._shared_ctx.xcom[task.name] = Xcom([])
|
|
258
279
|
if task not in self._context:
|
|
259
280
|
self._context[task] = Context(
|
|
260
281
|
shared_ctx=self._shared_ctx,
|
|
@@ -278,7 +299,7 @@ class Session(AnySession):
|
|
|
278
299
|
self._color_index = 0
|
|
279
300
|
return chosen
|
|
280
301
|
|
|
281
|
-
def _get_icon(self, task: AnyTask) ->
|
|
302
|
+
def _get_icon(self, task: AnyTask) -> str:
|
|
282
303
|
if task.icon is not None:
|
|
283
304
|
return task.icon
|
|
284
305
|
chosen = self._icons[self._icon_index]
|
zrb/task/any_task.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations # Enables forward references
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import TYPE_CHECKING, Any
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
5
5
|
|
|
6
6
|
from zrb.env.any_env import AnyEnv
|
|
7
7
|
from zrb.input.any_input import AnyInput
|
|
@@ -36,6 +36,14 @@ class AnyTask(ABC):
|
|
|
36
36
|
the actual implementation for these abstract members.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def __rshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask | list[AnyTask]":
|
|
41
|
+
pass
|
|
42
|
+
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def __lshift__(self, other: "AnyTask | list[AnyTask]") -> "AnyTask":
|
|
45
|
+
pass
|
|
46
|
+
|
|
39
47
|
@property
|
|
40
48
|
@abstractmethod
|
|
41
49
|
def name(self) -> str:
|
|
@@ -148,13 +156,17 @@ class AnyTask(ABC):
|
|
|
148
156
|
|
|
149
157
|
@abstractmethod
|
|
150
158
|
def run(
|
|
151
|
-
self,
|
|
159
|
+
self,
|
|
160
|
+
session: "AnySession | None" = None,
|
|
161
|
+
str_kwargs: dict[str, str] | None = None,
|
|
162
|
+
kwargs: dict[str, Any] | None = None,
|
|
152
163
|
) -> Any:
|
|
153
164
|
"""Runs the task synchronously.
|
|
154
165
|
|
|
155
166
|
Args:
|
|
156
167
|
session (AnySession): The shared session.
|
|
157
168
|
str_kwargs(dict[str, str]): The input string values.
|
|
169
|
+
kwargs(dict[str, Any]): The input values.
|
|
158
170
|
|
|
159
171
|
Returns:
|
|
160
172
|
Any: The result of the task execution.
|
|
@@ -163,13 +175,17 @@ class AnyTask(ABC):
|
|
|
163
175
|
|
|
164
176
|
@abstractmethod
|
|
165
177
|
async def async_run(
|
|
166
|
-
self,
|
|
178
|
+
self,
|
|
179
|
+
session: "AnySession | None" = None,
|
|
180
|
+
str_kwargs: dict[str, str] | None = None,
|
|
181
|
+
kwargs: dict[str, Any] | None = None,
|
|
167
182
|
) -> Any:
|
|
168
183
|
"""Runs the task asynchronously.
|
|
169
184
|
|
|
170
185
|
Args:
|
|
171
186
|
session (AnySession): The shared session.
|
|
172
187
|
str_kwargs(dict[str, str]): The input string values.
|
|
188
|
+
kwargs(dict[str, Any]): The input values.
|
|
173
189
|
|
|
174
190
|
Returns:
|
|
175
191
|
Any: The result of the task execution.
|
|
@@ -203,3 +219,8 @@ class AnyTask(ABC):
|
|
|
203
219
|
session (AnySession): The shared session.
|
|
204
220
|
"""
|
|
205
221
|
pass
|
|
222
|
+
|
|
223
|
+
@abstractmethod
|
|
224
|
+
def to_function(self) -> Callable[..., Any]:
|
|
225
|
+
"""Turn a task into a function"""
|
|
226
|
+
pass
|
zrb/task/base/context.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import TYPE_CHECKING
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
3
|
|
|
4
4
|
from zrb.context.any_context import AnyContext
|
|
5
5
|
from zrb.context.any_shared_context import AnySharedContext
|
|
@@ -26,25 +26,33 @@ def build_task_context(task: AnyTask, session: AnySession) -> AnyContext:
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def fill_shared_context_inputs(
|
|
29
|
-
|
|
29
|
+
shared_ctx: AnySharedContext,
|
|
30
|
+
task: AnyTask,
|
|
31
|
+
str_kwargs: dict[str, str] | None = None,
|
|
32
|
+
kwargs: dict[str, Any] | None = None,
|
|
30
33
|
):
|
|
31
34
|
"""
|
|
32
|
-
Populates the shared context with input values provided via
|
|
35
|
+
Populates the shared context with input values provided via str_kwargs.
|
|
33
36
|
"""
|
|
37
|
+
str_kwarg_dict = str_kwargs if str_kwargs is not None else {}
|
|
38
|
+
kwarg_dict = kwargs if kwargs is not None else {}
|
|
34
39
|
for task_input in task.inputs:
|
|
35
|
-
if task_input.name not in
|
|
36
|
-
|
|
37
|
-
|
|
40
|
+
if task_input.name not in shared_ctx.input:
|
|
41
|
+
task_input.update_shared_context(
|
|
42
|
+
shared_ctx,
|
|
43
|
+
value=kwarg_dict.get(task_input.name, None),
|
|
44
|
+
str_value=str_kwarg_dict.get(task_input.name, None),
|
|
45
|
+
)
|
|
38
46
|
|
|
39
47
|
|
|
40
|
-
def fill_shared_context_envs(
|
|
48
|
+
def fill_shared_context_envs(shared_ctx: AnySharedContext):
|
|
41
49
|
"""
|
|
42
50
|
Injects OS environment variables into the shared context if they don't already exist.
|
|
43
51
|
"""
|
|
44
52
|
os_env_map = {
|
|
45
|
-
key: val for key, val in os.environ.items() if key not in
|
|
53
|
+
key: val for key, val in os.environ.items() if key not in shared_ctx.env
|
|
46
54
|
}
|
|
47
|
-
|
|
55
|
+
shared_ctx.env.update(os_env_map)
|
|
48
56
|
|
|
49
57
|
|
|
50
58
|
def combine_inputs(
|