zrb 1.13.1__py3-none-any.whl → 1.21.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__init__.py +2 -6
- zrb/attr/type.py +8 -8
- zrb/builtin/__init__.py +2 -0
- zrb/builtin/group.py +31 -15
- zrb/builtin/http.py +7 -8
- zrb/builtin/llm/attachment.py +40 -0
- zrb/builtin/llm/chat_session.py +130 -144
- zrb/builtin/llm/chat_session_cmd.py +226 -0
- zrb/builtin/llm/chat_trigger.py +73 -0
- zrb/builtin/llm/history.py +4 -4
- zrb/builtin/llm/llm_ask.py +218 -110
- zrb/builtin/llm/tool/api.py +74 -62
- zrb/builtin/llm/tool/cli.py +35 -16
- zrb/builtin/llm/tool/code.py +49 -47
- zrb/builtin/llm/tool/file.py +262 -251
- zrb/builtin/llm/tool/note.py +84 -0
- zrb/builtin/llm/tool/rag.py +25 -18
- zrb/builtin/llm/tool/sub_agent.py +29 -22
- zrb/builtin/llm/tool/web.py +135 -143
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/setup/latex/ubuntu.py +1 -0
- zrb/builtin/setup/ubuntu.py +1 -1
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/config/config.py +255 -78
- zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
- zrb/config/default_prompt/interactive_system_prompt.md +24 -30
- zrb/config/default_prompt/persona.md +1 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
- zrb/config/default_prompt/summarization_prompt.md +8 -13
- zrb/config/default_prompt/system_prompt.md +36 -30
- zrb/config/llm_config.py +129 -24
- zrb/config/llm_context/config.py +127 -90
- zrb/config/llm_context/config_parser.py +1 -7
- zrb/config/llm_context/workflow.py +81 -0
- zrb/config/llm_rate_limitter.py +89 -45
- zrb/context/any_shared_context.py +7 -1
- zrb/context/context.py +8 -2
- zrb/context/shared_context.py +6 -8
- zrb/group/any_group.py +12 -5
- zrb/group/group.py +67 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/text_input.py +7 -24
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_route/task_session_api_route.py +1 -4
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -6
- zrb/session/session.py +39 -18
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +17 -9
- zrb/task/base/execution.py +15 -8
- zrb/task/base/lifecycle.py +8 -4
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +69 -5
- zrb/task/base_trigger.py +12 -5
- zrb/task/llm/agent.py +138 -52
- zrb/task/llm/config.py +45 -13
- zrb/task/llm/conversation_history.py +76 -6
- zrb/task/llm/conversation_history_model.py +0 -168
- zrb/task/llm/default_workflow/coding/workflow.md +41 -0
- zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
- zrb/task/llm/default_workflow/git/workflow.md +118 -0
- zrb/task/llm/default_workflow/golang/workflow.md +128 -0
- zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
- zrb/task/llm/default_workflow/java/workflow.md +146 -0
- zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
- zrb/task/llm/default_workflow/python/workflow.md +160 -0
- zrb/task/llm/default_workflow/researching/workflow.md +153 -0
- zrb/task/llm/default_workflow/rust/workflow.md +162 -0
- zrb/task/llm/default_workflow/shell/workflow.md +299 -0
- zrb/task/llm/file_replacement.py +206 -0
- zrb/task/llm/file_tool_model.py +57 -0
- zrb/task/llm/history_summarization.py +22 -35
- zrb/task/llm/history_summarization_tool.py +24 -0
- zrb/task/llm/print_node.py +182 -63
- zrb/task/llm/prompt.py +213 -153
- zrb/task/llm/tool_wrapper.py +210 -53
- zrb/task/llm/workflow.py +76 -0
- zrb/task/llm_task.py +98 -47
- zrb/task/make_task.py +2 -3
- zrb/task/rsync_task.py +25 -10
- zrb/task/scheduler.py +4 -4
- zrb/util/attr.py +50 -40
- zrb/util/cli/markdown.py +12 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/file.py +27 -11
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/METADATA +40 -20
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/RECORD +102 -79
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/WHEEL +1 -1
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/entry_points.txt +0 -0
zrb/config/llm_rate_limitter.py
CHANGED
|
@@ -1,29 +1,12 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import json
|
|
2
3
|
import time
|
|
3
4
|
from collections import deque
|
|
4
|
-
from typing import Callable
|
|
5
|
-
|
|
6
|
-
import tiktoken
|
|
5
|
+
from typing import Any, Callable
|
|
7
6
|
|
|
8
7
|
from zrb.config.config import CFG
|
|
9
8
|
|
|
10
9
|
|
|
11
|
-
def _estimate_token(text: str) -> int:
|
|
12
|
-
"""
|
|
13
|
-
Estimates the number of tokens in a given text.
|
|
14
|
-
Tries to use the 'gpt-4o' model's tokenizer for an accurate count.
|
|
15
|
-
If the tokenizer is unavailable (e.g., due to network issues),
|
|
16
|
-
it falls back to a heuristic of 4 characters per token.
|
|
17
|
-
"""
|
|
18
|
-
try:
|
|
19
|
-
# Primary method: Use tiktoken for an accurate count
|
|
20
|
-
enc = tiktoken.encoding_for_model("gpt-4o")
|
|
21
|
-
return len(enc.encode(text))
|
|
22
|
-
except Exception:
|
|
23
|
-
# Fallback method: Heuristic (4 characters per token)
|
|
24
|
-
return len(text) // 4
|
|
25
|
-
|
|
26
|
-
|
|
27
10
|
class LLMRateLimiter:
|
|
28
11
|
"""
|
|
29
12
|
Helper class to enforce LLM API rate limits and throttling.
|
|
@@ -35,14 +18,18 @@ class LLMRateLimiter:
|
|
|
35
18
|
max_requests_per_minute: int | None = None,
|
|
36
19
|
max_tokens_per_minute: int | None = None,
|
|
37
20
|
max_tokens_per_request: int | None = None,
|
|
21
|
+
max_tokens_per_tool_call_result: int | None = None,
|
|
38
22
|
throttle_sleep: float | None = None,
|
|
39
|
-
|
|
23
|
+
use_tiktoken: bool | None = None,
|
|
24
|
+
tiktoken_encoding_name: str | None = None,
|
|
40
25
|
):
|
|
41
26
|
self._max_requests_per_minute = max_requests_per_minute
|
|
42
27
|
self._max_tokens_per_minute = max_tokens_per_minute
|
|
43
28
|
self._max_tokens_per_request = max_tokens_per_request
|
|
29
|
+
self._max_tokens_per_tool_call_result = max_tokens_per_tool_call_result
|
|
44
30
|
self._throttle_sleep = throttle_sleep
|
|
45
|
-
self.
|
|
31
|
+
self._use_tiktoken = use_tiktoken
|
|
32
|
+
self._tiktoken_encoding_name = tiktoken_encoding_name
|
|
46
33
|
self.request_times = deque()
|
|
47
34
|
self.token_times = deque()
|
|
48
35
|
|
|
@@ -64,6 +51,12 @@ class LLMRateLimiter:
|
|
|
64
51
|
return self._max_tokens_per_request
|
|
65
52
|
return CFG.LLM_MAX_TOKENS_PER_REQUEST
|
|
66
53
|
|
|
54
|
+
@property
|
|
55
|
+
def max_tokens_per_tool_call_result(self) -> int:
|
|
56
|
+
if self._max_tokens_per_tool_call_result is not None:
|
|
57
|
+
return self._max_tokens_per_tool_call_result
|
|
58
|
+
return CFG.LLM_MAX_TOKENS_PER_TOOL_CALL_RESULT
|
|
59
|
+
|
|
67
60
|
@property
|
|
68
61
|
def throttle_sleep(self) -> float:
|
|
69
62
|
if self._throttle_sleep is not None:
|
|
@@ -71,10 +64,16 @@ class LLMRateLimiter:
|
|
|
71
64
|
return CFG.LLM_THROTTLE_SLEEP
|
|
72
65
|
|
|
73
66
|
@property
|
|
74
|
-
def
|
|
75
|
-
if self.
|
|
76
|
-
return self.
|
|
77
|
-
return
|
|
67
|
+
def use_tiktoken(self) -> bool:
|
|
68
|
+
if self._use_tiktoken is not None:
|
|
69
|
+
return self._use_tiktoken
|
|
70
|
+
return CFG.USE_TIKTOKEN
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def tiktoken_encoding_name(self) -> str:
|
|
74
|
+
if self._tiktoken_encoding_name is not None:
|
|
75
|
+
return self._tiktoken_encoding_name
|
|
76
|
+
return CFG.TIKTOKEN_ENCODING_NAME
|
|
78
77
|
|
|
79
78
|
def set_max_requests_per_minute(self, value: int):
|
|
80
79
|
self._max_requests_per_minute = value
|
|
@@ -85,29 +84,56 @@ class LLMRateLimiter:
|
|
|
85
84
|
def set_max_tokens_per_request(self, value: int):
|
|
86
85
|
self._max_tokens_per_request = value
|
|
87
86
|
|
|
87
|
+
def set_max_tokens_per_tool_call_result(self, value: int):
|
|
88
|
+
self._max_tokens_per_tool_call_result = value
|
|
89
|
+
|
|
88
90
|
def set_throttle_sleep(self, value: float):
|
|
89
91
|
self._throttle_sleep = value
|
|
90
92
|
|
|
91
|
-
def
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
93
|
+
def count_token(self, prompt: Any) -> int:
|
|
94
|
+
str_prompt = self._prompt_to_str(prompt)
|
|
95
|
+
if not self.use_tiktoken:
|
|
96
|
+
return self._fallback_count_token(str_prompt)
|
|
97
|
+
try:
|
|
98
|
+
import tiktoken
|
|
99
|
+
|
|
100
|
+
enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
|
|
101
|
+
return len(enc.encode(str_prompt))
|
|
102
|
+
except Exception:
|
|
103
|
+
return self._fallback_count_token(str_prompt)
|
|
104
|
+
|
|
105
|
+
def _fallback_count_token(self, str_prompt: str) -> int:
|
|
106
|
+
return len(str_prompt) // 4
|
|
107
|
+
|
|
108
|
+
def clip_prompt(self, prompt: Any, limit: int) -> str:
|
|
109
|
+
str_prompt = self._prompt_to_str(prompt)
|
|
110
|
+
if not self.use_tiktoken:
|
|
111
|
+
return self._fallback_clip_prompt(str_prompt, limit)
|
|
112
|
+
try:
|
|
113
|
+
import tiktoken
|
|
114
|
+
|
|
115
|
+
enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
|
|
116
|
+
tokens = enc.encode(str_prompt)
|
|
117
|
+
if len(tokens) <= limit:
|
|
118
|
+
return str_prompt
|
|
119
|
+
truncated = tokens[: limit - 3]
|
|
120
|
+
clipped_text = enc.decode(truncated)
|
|
121
|
+
return clipped_text + "..."
|
|
122
|
+
except Exception:
|
|
123
|
+
return self._fallback_clip_prompt(str_prompt, limit)
|
|
124
|
+
|
|
125
|
+
def _fallback_clip_prompt(self, str_prompt: str, limit: int) -> str:
|
|
126
|
+
char_limit = limit * 4 if limit * 4 <= 10 else limit * 4 - 10
|
|
127
|
+
return str_prompt[:char_limit] + "..."
|
|
128
|
+
|
|
129
|
+
async def throttle(
|
|
130
|
+
self,
|
|
131
|
+
prompt: Any,
|
|
132
|
+
throttle_notif_callback: Callable | None = None,
|
|
133
|
+
):
|
|
109
134
|
now = time.time()
|
|
110
|
-
|
|
135
|
+
str_prompt = self._prompt_to_str(prompt)
|
|
136
|
+
tokens = self.count_token(str_prompt)
|
|
111
137
|
# Clean up old entries
|
|
112
138
|
while self.request_times and now - self.request_times[0] > 60:
|
|
113
139
|
self.request_times.popleft()
|
|
@@ -116,13 +142,25 @@ class LLMRateLimiter:
|
|
|
116
142
|
# Check per-request token limit
|
|
117
143
|
if tokens > self.max_tokens_per_request:
|
|
118
144
|
raise ValueError(
|
|
119
|
-
|
|
145
|
+
(
|
|
146
|
+
"Request exceeds max_tokens_per_request "
|
|
147
|
+
"({tokens} > {self.max_tokens_per_request})."
|
|
148
|
+
)
|
|
149
|
+
)
|
|
150
|
+
if tokens > self.max_tokens_per_minute:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
(
|
|
153
|
+
"Request exceeds max_tokens_per_minute "
|
|
154
|
+
"({tokens} > {self.max_tokens_per_minute})."
|
|
155
|
+
)
|
|
120
156
|
)
|
|
121
157
|
# Wait if over per-minute request or token limit
|
|
122
158
|
while (
|
|
123
159
|
len(self.request_times) >= self.max_requests_per_minute
|
|
124
160
|
or sum(t for _, t in self.token_times) + tokens > self.max_tokens_per_minute
|
|
125
161
|
):
|
|
162
|
+
if throttle_notif_callback is not None:
|
|
163
|
+
throttle_notif_callback()
|
|
126
164
|
await asyncio.sleep(self.throttle_sleep)
|
|
127
165
|
now = time.time()
|
|
128
166
|
while self.request_times and now - self.request_times[0] > 60:
|
|
@@ -133,5 +171,11 @@ class LLMRateLimiter:
|
|
|
133
171
|
self.request_times.append(now)
|
|
134
172
|
self.token_times.append((now, tokens))
|
|
135
173
|
|
|
174
|
+
def _prompt_to_str(self, prompt: Any) -> str:
|
|
175
|
+
try:
|
|
176
|
+
return json.dumps(prompt)
|
|
177
|
+
except Exception:
|
|
178
|
+
return f"{prompt}"
|
|
179
|
+
|
|
136
180
|
|
|
137
181
|
llm_rate_limitter = LLMRateLimiter()
|
|
@@ -29,26 +29,32 @@ class AnySharedContext(ABC):
|
|
|
29
29
|
pass
|
|
30
30
|
|
|
31
31
|
@property
|
|
32
|
+
@abstractmethod
|
|
32
33
|
def input(self) -> DotDict:
|
|
33
34
|
pass
|
|
34
35
|
|
|
35
36
|
@property
|
|
37
|
+
@abstractmethod
|
|
36
38
|
def env(self) -> DotDict:
|
|
37
39
|
pass
|
|
38
40
|
|
|
39
41
|
@property
|
|
42
|
+
@abstractmethod
|
|
40
43
|
def args(self) -> list[Any]:
|
|
41
44
|
pass
|
|
42
45
|
|
|
43
46
|
@property
|
|
44
|
-
|
|
47
|
+
@abstractmethod
|
|
48
|
+
def xcom(self) -> DotDict:
|
|
45
49
|
pass
|
|
46
50
|
|
|
47
51
|
@property
|
|
52
|
+
@abstractmethod
|
|
48
53
|
def shared_log(self) -> list[str]:
|
|
49
54
|
pass
|
|
50
55
|
|
|
51
56
|
@property
|
|
57
|
+
@abstractmethod
|
|
52
58
|
def session(self) -> any_session.AnySession | None:
|
|
53
59
|
pass
|
|
54
60
|
|
zrb/context/context.py
CHANGED
|
@@ -63,7 +63,7 @@ class Context(AnyContext):
|
|
|
63
63
|
|
|
64
64
|
@property
|
|
65
65
|
def session(self) -> AnySession | None:
|
|
66
|
-
return self._shared_ctx.
|
|
66
|
+
return self._shared_ctx.session
|
|
67
67
|
|
|
68
68
|
def update_task_env(self, task_env: dict[str, str]):
|
|
69
69
|
self._env.update(task_env)
|
|
@@ -119,7 +119,13 @@ class Context(AnyContext):
|
|
|
119
119
|
return
|
|
120
120
|
color = self._color
|
|
121
121
|
icon = self._icon
|
|
122
|
-
|
|
122
|
+
# Handle case where session is None (e.g., in tests)
|
|
123
|
+
if self.session is None:
|
|
124
|
+
max_name_length = len(self._task_name) + len(icon)
|
|
125
|
+
else:
|
|
126
|
+
max_name_length = max(
|
|
127
|
+
len(name) + len(icon) for name in self.session.task_names
|
|
128
|
+
)
|
|
123
129
|
styled_task_name = f"{icon} {self._task_name}"
|
|
124
130
|
padded_styled_task_name = styled_task_name.rjust(max_name_length + 1)
|
|
125
131
|
if self._attempt == 0:
|
zrb/context/shared_context.py
CHANGED
|
@@ -27,6 +27,7 @@ class SharedContext(AnySharedContext):
|
|
|
27
27
|
env: dict[str, str] = {},
|
|
28
28
|
xcom: dict[str, Xcom] = {},
|
|
29
29
|
logging_level: int | None = None,
|
|
30
|
+
is_web_mode: bool = False,
|
|
30
31
|
):
|
|
31
32
|
self.__logging_level = logging_level
|
|
32
33
|
self._input = DotDict(input)
|
|
@@ -35,18 +36,15 @@ class SharedContext(AnySharedContext):
|
|
|
35
36
|
self._xcom = DotDict(xcom)
|
|
36
37
|
self._session: AnySession | None = None
|
|
37
38
|
self._log = []
|
|
39
|
+
self._is_web_mode = is_web_mode
|
|
38
40
|
|
|
39
41
|
def __repr__(self):
|
|
40
42
|
class_name = self.__class__.__name__
|
|
41
|
-
|
|
42
|
-
args = self._args
|
|
43
|
-
env = self._env
|
|
44
|
-
xcom = self._xcom
|
|
45
|
-
return f"<{class_name} input={input} args={args} xcom={xcom} env={env}>"
|
|
43
|
+
return f"<{class_name}>"
|
|
46
44
|
|
|
47
45
|
@property
|
|
48
46
|
def is_web_mode(self) -> bool:
|
|
49
|
-
return self.
|
|
47
|
+
return self._is_web_mode
|
|
50
48
|
|
|
51
49
|
@property
|
|
52
50
|
def is_tty(self) -> bool:
|
|
@@ -68,7 +66,7 @@ class SharedContext(AnySharedContext):
|
|
|
68
66
|
return self._args
|
|
69
67
|
|
|
70
68
|
@property
|
|
71
|
-
def xcom(self) -> DotDict
|
|
69
|
+
def xcom(self) -> DotDict:
|
|
72
70
|
return self._xcom
|
|
73
71
|
|
|
74
72
|
@property
|
|
@@ -83,7 +81,7 @@ class SharedContext(AnySharedContext):
|
|
|
83
81
|
self._log.append(message)
|
|
84
82
|
session = self.session
|
|
85
83
|
if session is not None:
|
|
86
|
-
session_parent: AnySession = session.parent
|
|
84
|
+
session_parent: AnySession | None = session.parent
|
|
87
85
|
if session_parent is not None:
|
|
88
86
|
session_parent.shared_ctx.append_to_shared_log(message)
|
|
89
87
|
|
zrb/group/any_group.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Optional, Union
|
|
3
2
|
|
|
4
3
|
from zrb.task.any_task import AnyTask
|
|
5
4
|
|
|
@@ -31,16 +30,24 @@ class AnyGroup(ABC):
|
|
|
31
30
|
|
|
32
31
|
@property
|
|
33
32
|
@abstractmethod
|
|
34
|
-
def subgroups(self) -> dict[str,
|
|
33
|
+
def subgroups(self) -> "dict[str, AnyGroup]":
|
|
35
34
|
"""Group subgroups"""
|
|
36
35
|
pass
|
|
37
36
|
|
|
38
37
|
@abstractmethod
|
|
39
|
-
def add_group(self, group:
|
|
38
|
+
def add_group(self, group: "AnyGroup", alias: str | None = None) -> "AnyGroup":
|
|
40
39
|
pass
|
|
41
40
|
|
|
42
41
|
@abstractmethod
|
|
43
|
-
def add_task(self, task: AnyTask, alias: str | None = None) -> AnyTask:
|
|
42
|
+
def add_task(self, task: "AnyTask", alias: str | None = None) -> "AnyTask":
|
|
43
|
+
pass
|
|
44
|
+
|
|
45
|
+
@abstractmethod
|
|
46
|
+
def remove_group(self, group: "AnyGroup | str"):
|
|
47
|
+
pass
|
|
48
|
+
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def remove_task(self, task: "AnyTask | str"):
|
|
44
51
|
pass
|
|
45
52
|
|
|
46
53
|
@abstractmethod
|
|
@@ -48,5 +55,5 @@ class AnyGroup(ABC):
|
|
|
48
55
|
pass
|
|
49
56
|
|
|
50
57
|
@abstractmethod
|
|
51
|
-
def get_group_by_alias(self,
|
|
58
|
+
def get_group_by_alias(self, alias: str) -> "AnyGroup | None":
|
|
52
59
|
pass
|
zrb/group/group.py
CHANGED
|
@@ -33,15 +33,15 @@ class Group(AnyGroup):
|
|
|
33
33
|
def subgroups(self) -> dict[str, AnyGroup]:
|
|
34
34
|
names = list(self._groups.keys())
|
|
35
35
|
names.sort()
|
|
36
|
-
return {name: self._groups
|
|
36
|
+
return {name: self._groups[name] for name in names}
|
|
37
37
|
|
|
38
38
|
@property
|
|
39
39
|
def subtasks(self) -> dict[str, AnyTask]:
|
|
40
40
|
alias = list(self._tasks.keys())
|
|
41
41
|
alias.sort()
|
|
42
|
-
return {name: self._tasks
|
|
42
|
+
return {name: self._tasks[name] for name in alias}
|
|
43
43
|
|
|
44
|
-
def add_group(self, group: AnyGroup
|
|
44
|
+
def add_group(self, group: AnyGroup, alias: str | None = None) -> AnyGroup:
|
|
45
45
|
real_group = Group(group) if isinstance(group, str) else group
|
|
46
46
|
alias = alias if alias is not None else real_group.name
|
|
47
47
|
self._groups[alias] = real_group
|
|
@@ -52,6 +52,70 @@ class Group(AnyGroup):
|
|
|
52
52
|
self._tasks[alias] = task
|
|
53
53
|
return task
|
|
54
54
|
|
|
55
|
+
def remove_group(self, group: "AnyGroup | str"):
|
|
56
|
+
original_groups_len = len(self._groups)
|
|
57
|
+
if isinstance(group, AnyGroup):
|
|
58
|
+
new_groups = {
|
|
59
|
+
alias: existing_group
|
|
60
|
+
for alias, existing_group in self._groups.items()
|
|
61
|
+
if group != existing_group
|
|
62
|
+
}
|
|
63
|
+
if len(new_groups) == original_groups_len:
|
|
64
|
+
raise ValueError(f"Cannot remove group {group} from {self}")
|
|
65
|
+
self._groups = new_groups
|
|
66
|
+
return
|
|
67
|
+
# group is string, try to remove by alias
|
|
68
|
+
new_groups = {
|
|
69
|
+
alias: existing_group
|
|
70
|
+
for alias, existing_group in self._groups.items()
|
|
71
|
+
if alias != group
|
|
72
|
+
}
|
|
73
|
+
if len(new_groups) < original_groups_len:
|
|
74
|
+
self._groups = new_groups
|
|
75
|
+
return
|
|
76
|
+
# if alias removal didn't work, try to remove by name
|
|
77
|
+
new_groups = {
|
|
78
|
+
alias: existing_group
|
|
79
|
+
for alias, existing_group in self._groups.items()
|
|
80
|
+
if existing_group.name != group
|
|
81
|
+
}
|
|
82
|
+
if len(new_groups) < original_groups_len:
|
|
83
|
+
self._groups = new_groups
|
|
84
|
+
return
|
|
85
|
+
raise ValueError(f"Cannot remove group {group} from {self}")
|
|
86
|
+
|
|
87
|
+
def remove_task(self, task: "AnyTask | str"):
|
|
88
|
+
original_tasks_len = len(self._tasks)
|
|
89
|
+
if isinstance(task, AnyTask):
|
|
90
|
+
new_tasks = {
|
|
91
|
+
alias: existing_task
|
|
92
|
+
for alias, existing_task in self._tasks.items()
|
|
93
|
+
if task != existing_task
|
|
94
|
+
}
|
|
95
|
+
if len(new_tasks) == original_tasks_len:
|
|
96
|
+
raise ValueError(f"Cannot remove task {task} from {self}")
|
|
97
|
+
self._tasks = new_tasks
|
|
98
|
+
return
|
|
99
|
+
# task is string, try to remove by alias
|
|
100
|
+
new_tasks = {
|
|
101
|
+
alias: existing_task
|
|
102
|
+
for alias, existing_task in self._tasks.items()
|
|
103
|
+
if alias != task
|
|
104
|
+
}
|
|
105
|
+
if len(new_tasks) < original_tasks_len:
|
|
106
|
+
self._tasks = new_tasks
|
|
107
|
+
return
|
|
108
|
+
# if alias removal didn't work, try to remove by name
|
|
109
|
+
new_tasks = {
|
|
110
|
+
alias: existing_task
|
|
111
|
+
for alias, existing_task in self._tasks.items()
|
|
112
|
+
if existing_task.name != task
|
|
113
|
+
}
|
|
114
|
+
if len(new_tasks) < original_tasks_len:
|
|
115
|
+
self._tasks = new_tasks
|
|
116
|
+
return
|
|
117
|
+
raise ValueError(f"Cannot remove task {task} from {self}")
|
|
118
|
+
|
|
55
119
|
def get_task_by_alias(self, alias: str) -> AnyTask | None:
|
|
56
120
|
return self._tasks.get(alias)
|
|
57
121
|
|
zrb/input/any_input.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
2
3
|
|
|
3
4
|
from zrb.context.any_shared_context import AnySharedContext
|
|
4
5
|
|
|
@@ -35,7 +36,10 @@ class AnyInput(ABC):
|
|
|
35
36
|
|
|
36
37
|
@abstractmethod
|
|
37
38
|
def update_shared_context(
|
|
38
|
-
self,
|
|
39
|
+
self,
|
|
40
|
+
shared_ctx: AnySharedContext,
|
|
41
|
+
str_value: str | None = None,
|
|
42
|
+
value: Any = None,
|
|
39
43
|
):
|
|
40
44
|
pass
|
|
41
45
|
|
zrb/input/base_input.py
CHANGED
|
@@ -58,11 +58,15 @@ class BaseInput(AnyInput):
|
|
|
58
58
|
return f'<input name="{name}" placeholder="{description}" value="{default}" />'
|
|
59
59
|
|
|
60
60
|
def update_shared_context(
|
|
61
|
-
self,
|
|
61
|
+
self,
|
|
62
|
+
shared_ctx: AnySharedContext,
|
|
63
|
+
str_value: str | None = None,
|
|
64
|
+
value: Any = None,
|
|
62
65
|
):
|
|
63
|
-
if
|
|
64
|
-
str_value
|
|
65
|
-
|
|
66
|
+
if value is None:
|
|
67
|
+
if str_value is None:
|
|
68
|
+
str_value = self.get_default_str(shared_ctx)
|
|
69
|
+
value = self._parse_str_value(str_value)
|
|
66
70
|
if self.name in shared_ctx.input:
|
|
67
71
|
raise ValueError(f"Input already defined in the context: {self.name}")
|
|
68
72
|
shared_ctx.input[self.name] = value
|
|
@@ -91,12 +95,20 @@ class BaseInput(AnyInput):
|
|
|
91
95
|
default_str = self.get_default_str(shared_ctx)
|
|
92
96
|
if default_str != "":
|
|
93
97
|
prompt_message = f"{prompt_message} [{default_str}]"
|
|
94
|
-
|
|
95
|
-
value = input()
|
|
98
|
+
value = self._read_line(shared_ctx, prompt_message)
|
|
96
99
|
if value.strip() == "":
|
|
97
100
|
value = default_str
|
|
98
101
|
return value
|
|
99
102
|
|
|
103
|
+
def _read_line(self, shared_ctx: AnySharedContext, prompt_message: str) -> str:
|
|
104
|
+
if not shared_ctx.is_tty:
|
|
105
|
+
print(f"{prompt_message}: ", end="")
|
|
106
|
+
return input()
|
|
107
|
+
from prompt_toolkit import PromptSession
|
|
108
|
+
|
|
109
|
+
reader = PromptSession()
|
|
110
|
+
return reader.prompt(f"{prompt_message}: ")
|
|
111
|
+
|
|
100
112
|
def get_default_str(self, shared_ctx: AnySharedContext) -> str:
|
|
101
113
|
"""Get default value as str"""
|
|
102
114
|
default_value = get_attr(
|
zrb/input/text_input.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import subprocess
|
|
3
|
-
import tempfile
|
|
4
1
|
from collections.abc import Callable
|
|
5
2
|
|
|
6
3
|
from zrb.config.config import CFG
|
|
7
4
|
from zrb.context.any_shared_context import AnySharedContext
|
|
8
5
|
from zrb.input.base_input import BaseInput
|
|
9
|
-
from zrb.util.
|
|
6
|
+
from zrb.util.cli.text import edit_text
|
|
10
7
|
|
|
11
8
|
|
|
12
9
|
class TextInput(BaseInput):
|
|
@@ -85,24 +82,10 @@ class TextInput(BaseInput):
|
|
|
85
82
|
comment_prompt_message = (
|
|
86
83
|
f"{self.comment_start}{prompt_message}{self.comment_end}"
|
|
87
84
|
)
|
|
88
|
-
comment_prompt_message_eol = f"{comment_prompt_message}\n"
|
|
89
85
|
default_value = self.get_default_str(shared_ctx)
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
if default_value:
|
|
97
|
-
temp_file.write(default_value.encode())
|
|
98
|
-
temp_file.flush()
|
|
99
|
-
subprocess.call([self.editor_cmd, temp_file_name])
|
|
100
|
-
# Read the edited content
|
|
101
|
-
edited_content = read_file(temp_file_name)
|
|
102
|
-
parts = [
|
|
103
|
-
text.strip() for text in edited_content.split(comment_prompt_message, 1)
|
|
104
|
-
]
|
|
105
|
-
edited_content = "\n".join(parts).lstrip()
|
|
106
|
-
os.remove(temp_file_name)
|
|
107
|
-
print(f"{prompt_message}: {edited_content}")
|
|
108
|
-
return edited_content
|
|
86
|
+
return edit_text(
|
|
87
|
+
prompt_message=comment_prompt_message,
|
|
88
|
+
value=default_value,
|
|
89
|
+
editor=self.editor_cmd,
|
|
90
|
+
extension=self._extension,
|
|
91
|
+
)
|
zrb/runner/cli.py
CHANGED
|
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
|
|
|
7
7
|
from zrb.context.shared_context import SharedContext
|
|
8
8
|
from zrb.group.any_group import AnyGroup
|
|
9
9
|
from zrb.group.group import Group
|
|
10
|
-
from zrb.runner.common_util import
|
|
10
|
+
from zrb.runner.common_util import get_task_str_kwargs
|
|
11
11
|
from zrb.session.session import Session
|
|
12
12
|
from zrb.session_state_logger.session_state_logger_factory import session_state_logger
|
|
13
13
|
from zrb.task.any_task import AnyTask
|
|
@@ -38,23 +38,25 @@ class Cli(Group):
|
|
|
38
38
|
def banner(self) -> str:
|
|
39
39
|
return CFG.BANNER
|
|
40
40
|
|
|
41
|
-
def run(self,
|
|
42
|
-
|
|
43
|
-
node, node_path,
|
|
41
|
+
def run(self, str_args: list[str] = []):
|
|
42
|
+
str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
|
|
43
|
+
node, node_path, str_args = extract_node_from_args(self, str_args)
|
|
44
44
|
if isinstance(node, AnyGroup):
|
|
45
45
|
self._show_group_info(node)
|
|
46
46
|
return
|
|
47
|
-
if "h" in
|
|
47
|
+
if "h" in str_kwargs or "help" in str_kwargs:
|
|
48
48
|
self._show_task_info(node)
|
|
49
49
|
return
|
|
50
|
-
|
|
50
|
+
task_str_kwargs = get_task_str_kwargs(
|
|
51
|
+
task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
|
|
52
|
+
)
|
|
51
53
|
try:
|
|
52
|
-
result = self._run_task(node,
|
|
54
|
+
result = self._run_task(node, str_args, task_str_kwargs)
|
|
53
55
|
if result is not None:
|
|
54
56
|
print(result)
|
|
55
57
|
return result
|
|
56
58
|
finally:
|
|
57
|
-
run_command = self._get_run_command(node_path,
|
|
59
|
+
run_command = self._get_run_command(node_path, task_str_kwargs)
|
|
58
60
|
self._print_run_command(run_command)
|
|
59
61
|
|
|
60
62
|
def _print_run_command(self, run_command: str):
|
|
@@ -64,11 +66,14 @@ class Cli(Group):
|
|
|
64
66
|
file=sys.stderr,
|
|
65
67
|
)
|
|
66
68
|
|
|
67
|
-
def _get_run_command(
|
|
69
|
+
def _get_run_command(
|
|
70
|
+
self, node_path: list[str], task_str_kwargs: dict[str, str]
|
|
71
|
+
) -> str:
|
|
68
72
|
parts = [self.name] + node_path
|
|
69
|
-
if len(
|
|
73
|
+
if len(task_str_kwargs) > 0:
|
|
70
74
|
parts += [
|
|
71
|
-
self._get_run_command_param(key, val)
|
|
75
|
+
self._get_run_command_param(key, val)
|
|
76
|
+
for key, val in task_str_kwargs.items()
|
|
72
77
|
]
|
|
73
78
|
return " ".join(parts)
|
|
74
79
|
|
|
@@ -81,13 +86,9 @@ class Cli(Group):
|
|
|
81
86
|
self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
|
|
82
87
|
) -> tuple[Any]:
|
|
83
88
|
shared_ctx = SharedContext(args=args)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
shared_ctx, run_kwargs[task_input.name]
|
|
88
|
-
)
|
|
89
|
-
continue
|
|
90
|
-
return task.run(Session(shared_ctx=shared_ctx, root_group=self))
|
|
89
|
+
return task.run(
|
|
90
|
+
Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
|
|
91
|
+
)
|
|
91
92
|
|
|
92
93
|
def _show_task_info(self, task: AnyTask):
|
|
93
94
|
description = task.description
|
|
@@ -150,11 +151,11 @@ class Cli(Group):
|
|
|
150
151
|
kwargs[key] = args[i + 1]
|
|
151
152
|
i += 1 # Skip the next argument as it's a value
|
|
152
153
|
else:
|
|
153
|
-
kwargs[key] =
|
|
154
|
+
kwargs[key] = "true"
|
|
154
155
|
elif arg.startswith("-"):
|
|
155
156
|
# Handle short flags like -t or -n
|
|
156
157
|
key = arg[1:]
|
|
157
|
-
kwargs[key] =
|
|
158
|
+
kwargs[key] = "true"
|
|
158
159
|
else:
|
|
159
160
|
# Anything else is considered a positional argument
|
|
160
161
|
residual_args.append(arg)
|