zrb 1.13.1__py3-none-any.whl → 1.21.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105):
  1. zrb/__init__.py +2 -6
  2. zrb/attr/type.py +8 -8
  3. zrb/builtin/__init__.py +2 -0
  4. zrb/builtin/group.py +31 -15
  5. zrb/builtin/http.py +7 -8
  6. zrb/builtin/llm/attachment.py +40 -0
  7. zrb/builtin/llm/chat_session.py +130 -144
  8. zrb/builtin/llm/chat_session_cmd.py +226 -0
  9. zrb/builtin/llm/chat_trigger.py +73 -0
  10. zrb/builtin/llm/history.py +4 -4
  11. zrb/builtin/llm/llm_ask.py +218 -110
  12. zrb/builtin/llm/tool/api.py +74 -62
  13. zrb/builtin/llm/tool/cli.py +35 -16
  14. zrb/builtin/llm/tool/code.py +49 -47
  15. zrb/builtin/llm/tool/file.py +262 -251
  16. zrb/builtin/llm/tool/note.py +84 -0
  17. zrb/builtin/llm/tool/rag.py +25 -18
  18. zrb/builtin/llm/tool/sub_agent.py +29 -22
  19. zrb/builtin/llm/tool/web.py +135 -143
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
  22. zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
  23. zrb/builtin/searxng/config/settings.yml +5671 -0
  24. zrb/builtin/searxng/start.py +21 -0
  25. zrb/builtin/setup/latex/ubuntu.py +1 -0
  26. zrb/builtin/setup/ubuntu.py +1 -1
  27. zrb/builtin/shell/autocomplete/bash.py +4 -3
  28. zrb/builtin/shell/autocomplete/zsh.py +4 -3
  29. zrb/config/config.py +255 -78
  30. zrb/config/default_prompt/file_extractor_system_prompt.md +109 -9
  31. zrb/config/default_prompt/interactive_system_prompt.md +24 -30
  32. zrb/config/default_prompt/persona.md +1 -1
  33. zrb/config/default_prompt/repo_extractor_system_prompt.md +31 -31
  34. zrb/config/default_prompt/repo_summarizer_system_prompt.md +27 -8
  35. zrb/config/default_prompt/summarization_prompt.md +8 -13
  36. zrb/config/default_prompt/system_prompt.md +36 -30
  37. zrb/config/llm_config.py +129 -24
  38. zrb/config/llm_context/config.py +127 -90
  39. zrb/config/llm_context/config_parser.py +1 -7
  40. zrb/config/llm_context/workflow.py +81 -0
  41. zrb/config/llm_rate_limitter.py +89 -45
  42. zrb/context/any_shared_context.py +7 -1
  43. zrb/context/context.py +8 -2
  44. zrb/context/shared_context.py +6 -8
  45. zrb/group/any_group.py +12 -5
  46. zrb/group/group.py +67 -3
  47. zrb/input/any_input.py +5 -1
  48. zrb/input/base_input.py +18 -6
  49. zrb/input/text_input.py +7 -24
  50. zrb/runner/cli.py +21 -20
  51. zrb/runner/common_util.py +24 -19
  52. zrb/runner/web_route/task_input_api_route.py +5 -5
  53. zrb/runner/web_route/task_session_api_route.py +1 -4
  54. zrb/runner/web_util/user.py +7 -3
  55. zrb/session/any_session.py +12 -6
  56. zrb/session/session.py +39 -18
  57. zrb/task/any_task.py +24 -3
  58. zrb/task/base/context.py +17 -9
  59. zrb/task/base/execution.py +15 -8
  60. zrb/task/base/lifecycle.py +8 -4
  61. zrb/task/base/monitoring.py +12 -7
  62. zrb/task/base_task.py +69 -5
  63. zrb/task/base_trigger.py +12 -5
  64. zrb/task/llm/agent.py +138 -52
  65. zrb/task/llm/config.py +45 -13
  66. zrb/task/llm/conversation_history.py +76 -6
  67. zrb/task/llm/conversation_history_model.py +0 -168
  68. zrb/task/llm/default_workflow/coding/workflow.md +41 -0
  69. zrb/task/llm/default_workflow/copywriting/workflow.md +68 -0
  70. zrb/task/llm/default_workflow/git/workflow.md +118 -0
  71. zrb/task/llm/default_workflow/golang/workflow.md +128 -0
  72. zrb/task/llm/default_workflow/html-css/workflow.md +135 -0
  73. zrb/task/llm/default_workflow/java/workflow.md +146 -0
  74. zrb/task/llm/default_workflow/javascript/workflow.md +158 -0
  75. zrb/task/llm/default_workflow/python/workflow.md +160 -0
  76. zrb/task/llm/default_workflow/researching/workflow.md +153 -0
  77. zrb/task/llm/default_workflow/rust/workflow.md +162 -0
  78. zrb/task/llm/default_workflow/shell/workflow.md +299 -0
  79. zrb/task/llm/file_replacement.py +206 -0
  80. zrb/task/llm/file_tool_model.py +57 -0
  81. zrb/task/llm/history_summarization.py +22 -35
  82. zrb/task/llm/history_summarization_tool.py +24 -0
  83. zrb/task/llm/print_node.py +182 -63
  84. zrb/task/llm/prompt.py +213 -153
  85. zrb/task/llm/tool_wrapper.py +210 -53
  86. zrb/task/llm/workflow.py +76 -0
  87. zrb/task/llm_task.py +98 -47
  88. zrb/task/make_task.py +2 -3
  89. zrb/task/rsync_task.py +25 -10
  90. zrb/task/scheduler.py +4 -4
  91. zrb/util/attr.py +50 -40
  92. zrb/util/cli/markdown.py +12 -0
  93. zrb/util/cli/text.py +30 -0
  94. zrb/util/file.py +27 -11
  95. zrb/util/{llm/prompt.py → markdown.py} +2 -3
  96. zrb/util/string/conversion.py +1 -1
  97. zrb/util/truncate.py +23 -0
  98. zrb/util/yaml.py +204 -0
  99. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/METADATA +40 -20
  100. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/RECORD +102 -79
  101. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/WHEEL +1 -1
  102. zrb/task/llm/default_workflow/coding.md +0 -24
  103. zrb/task/llm/default_workflow/copywriting.md +0 -17
  104. zrb/task/llm/default_workflow/researching.md +0 -18
  105. {zrb-1.13.1.dist-info → zrb-1.21.17.dist-info}/entry_points.txt +0 -0
zrb/config/llm_rate_limitter.py CHANGED
@@ -1,29 +1,12 @@
1
1
  import asyncio
2
+ import json
2
3
  import time
3
4
  from collections import deque
4
- from typing import Callable
5
-
6
- import tiktoken
5
+ from typing import Any, Callable
7
6
 
8
7
  from zrb.config.config import CFG
9
8
 
10
9
 
11
- def _estimate_token(text: str) -> int:
12
- """
13
- Estimates the number of tokens in a given text.
14
- Tries to use the 'gpt-4o' model's tokenizer for an accurate count.
15
- If the tokenizer is unavailable (e.g., due to network issues),
16
- it falls back to a heuristic of 4 characters per token.
17
- """
18
- try:
19
- # Primary method: Use tiktoken for an accurate count
20
- enc = tiktoken.encoding_for_model("gpt-4o")
21
- return len(enc.encode(text))
22
- except Exception:
23
- # Fallback method: Heuristic (4 characters per token)
24
- return len(text) // 4
25
-
26
-
27
10
  class LLMRateLimiter:
28
11
  """
29
12
  Helper class to enforce LLM API rate limits and throttling.
@@ -35,14 +18,18 @@ class LLMRateLimiter:
35
18
  max_requests_per_minute: int | None = None,
36
19
  max_tokens_per_minute: int | None = None,
37
20
  max_tokens_per_request: int | None = None,
21
+ max_tokens_per_tool_call_result: int | None = None,
38
22
  throttle_sleep: float | None = None,
39
- token_counter_fn: Callable[[str], int] | None = None,
23
+ use_tiktoken: bool | None = None,
24
+ tiktoken_encoding_name: str | None = None,
40
25
  ):
41
26
  self._max_requests_per_minute = max_requests_per_minute
42
27
  self._max_tokens_per_minute = max_tokens_per_minute
43
28
  self._max_tokens_per_request = max_tokens_per_request
29
+ self._max_tokens_per_tool_call_result = max_tokens_per_tool_call_result
44
30
  self._throttle_sleep = throttle_sleep
45
- self._token_counter_fn = token_counter_fn
31
+ self._use_tiktoken = use_tiktoken
32
+ self._tiktoken_encoding_name = tiktoken_encoding_name
46
33
  self.request_times = deque()
47
34
  self.token_times = deque()
48
35
 
@@ -64,6 +51,12 @@ class LLMRateLimiter:
64
51
  return self._max_tokens_per_request
65
52
  return CFG.LLM_MAX_TOKENS_PER_REQUEST
66
53
 
54
+ @property
55
+ def max_tokens_per_tool_call_result(self) -> int:
56
+ if self._max_tokens_per_tool_call_result is not None:
57
+ return self._max_tokens_per_tool_call_result
58
+ return CFG.LLM_MAX_TOKENS_PER_TOOL_CALL_RESULT
59
+
67
60
  @property
68
61
  def throttle_sleep(self) -> float:
69
62
  if self._throttle_sleep is not None:
@@ -71,10 +64,16 @@ class LLMRateLimiter:
71
64
  return CFG.LLM_THROTTLE_SLEEP
72
65
 
73
66
  @property
74
- def count_token(self) -> Callable[[str], int]:
75
- if self._token_counter_fn is not None:
76
- return self._token_counter_fn
77
- return _estimate_token
67
+ def use_tiktoken(self) -> bool:
68
+ if self._use_tiktoken is not None:
69
+ return self._use_tiktoken
70
+ return CFG.USE_TIKTOKEN
71
+
72
+ @property
73
+ def tiktoken_encoding_name(self) -> str:
74
+ if self._tiktoken_encoding_name is not None:
75
+ return self._tiktoken_encoding_name
76
+ return CFG.TIKTOKEN_ENCODING_NAME
78
77
 
79
78
  def set_max_requests_per_minute(self, value: int):
80
79
  self._max_requests_per_minute = value
@@ -85,29 +84,56 @@ class LLMRateLimiter:
85
84
  def set_max_tokens_per_request(self, value: int):
86
85
  self._max_tokens_per_request = value
87
86
 
87
+ def set_max_tokens_per_tool_call_result(self, value: int):
88
+ self._max_tokens_per_tool_call_result = value
89
+
88
90
  def set_throttle_sleep(self, value: float):
89
91
  self._throttle_sleep = value
90
92
 
91
- def set_token_counter_fn(self, fn: Callable[[str], int]):
92
- self._token_counter_fn = fn
93
-
94
- def clip_prompt(self, prompt: str, limit: int) -> str:
95
- token_count = self.count_token(prompt)
96
- if token_count <= limit:
97
- return prompt
98
- while token_count > limit:
99
- prompt_parts = prompt.split(" ")
100
- last_part_index = len(prompt_parts) - 2
101
- clipped_prompt = " ".join(prompt_parts[:last_part_index])
102
- clipped_prompt += "(Content clipped...)"
103
- token_count = self.count_token(clipped_prompt)
104
- if token_count < limit:
105
- return clipped_prompt
106
- return prompt[:limit]
107
-
108
- async def throttle(self, prompt: str):
93
+ def count_token(self, prompt: Any) -> int:
94
+ str_prompt = self._prompt_to_str(prompt)
95
+ if not self.use_tiktoken:
96
+ return self._fallback_count_token(str_prompt)
97
+ try:
98
+ import tiktoken
99
+
100
+ enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
101
+ return len(enc.encode(str_prompt))
102
+ except Exception:
103
+ return self._fallback_count_token(str_prompt)
104
+
105
+ def _fallback_count_token(self, str_prompt: str) -> int:
106
+ return len(str_prompt) // 4
107
+
108
+ def clip_prompt(self, prompt: Any, limit: int) -> str:
109
+ str_prompt = self._prompt_to_str(prompt)
110
+ if not self.use_tiktoken:
111
+ return self._fallback_clip_prompt(str_prompt, limit)
112
+ try:
113
+ import tiktoken
114
+
115
+ enc = tiktoken.get_encoding(self.tiktoken_encoding_name)
116
+ tokens = enc.encode(str_prompt)
117
+ if len(tokens) <= limit:
118
+ return str_prompt
119
+ truncated = tokens[: limit - 3]
120
+ clipped_text = enc.decode(truncated)
121
+ return clipped_text + "..."
122
+ except Exception:
123
+ return self._fallback_clip_prompt(str_prompt, limit)
124
+
125
+ def _fallback_clip_prompt(self, str_prompt: str, limit: int) -> str:
126
+ char_limit = limit * 4 if limit * 4 <= 10 else limit * 4 - 10
127
+ return str_prompt[:char_limit] + "..."
128
+
129
+ async def throttle(
130
+ self,
131
+ prompt: Any,
132
+ throttle_notif_callback: Callable | None = None,
133
+ ):
109
134
  now = time.time()
110
- tokens = self.count_token(prompt)
135
+ str_prompt = self._prompt_to_str(prompt)
136
+ tokens = self.count_token(str_prompt)
111
137
  # Clean up old entries
112
138
  while self.request_times and now - self.request_times[0] > 60:
113
139
  self.request_times.popleft()
@@ -116,13 +142,25 @@ class LLMRateLimiter:
116
142
  # Check per-request token limit
117
143
  if tokens > self.max_tokens_per_request:
118
144
  raise ValueError(
119
- f"Request exceeds max_tokens_per_request ({self.max_tokens_per_request})."
145
+ (
146
+ "Request exceeds max_tokens_per_request "
147
+ "({tokens} > {self.max_tokens_per_request})."
148
+ )
149
+ )
150
+ if tokens > self.max_tokens_per_minute:
151
+ raise ValueError(
152
+ (
153
+ "Request exceeds max_tokens_per_minute "
154
+ "({tokens} > {self.max_tokens_per_minute})."
155
+ )
120
156
  )
121
157
  # Wait if over per-minute request or token limit
122
158
  while (
123
159
  len(self.request_times) >= self.max_requests_per_minute
124
160
  or sum(t for _, t in self.token_times) + tokens > self.max_tokens_per_minute
125
161
  ):
162
+ if throttle_notif_callback is not None:
163
+ throttle_notif_callback()
126
164
  await asyncio.sleep(self.throttle_sleep)
127
165
  now = time.time()
128
166
  while self.request_times and now - self.request_times[0] > 60:
@@ -133,5 +171,11 @@ class LLMRateLimiter:
133
171
  self.request_times.append(now)
134
172
  self.token_times.append((now, tokens))
135
173
 
174
+ def _prompt_to_str(self, prompt: Any) -> str:
175
+ try:
176
+ return json.dumps(prompt)
177
+ except Exception:
178
+ return f"{prompt}"
179
+
136
180
 
137
181
  llm_rate_limitter = LLMRateLimiter()
zrb/context/any_shared_context.py CHANGED
@@ -29,26 +29,32 @@ class AnySharedContext(ABC):
29
29
  pass
30
30
 
31
31
  @property
32
+ @abstractmethod
32
33
  def input(self) -> DotDict:
33
34
  pass
34
35
 
35
36
  @property
37
+ @abstractmethod
36
38
  def env(self) -> DotDict:
37
39
  pass
38
40
 
39
41
  @property
42
+ @abstractmethod
40
43
  def args(self) -> list[Any]:
41
44
  pass
42
45
 
43
46
  @property
44
- def xcom(self) -> DotDict[str, Xcom]:
47
+ @abstractmethod
48
+ def xcom(self) -> DotDict:
45
49
  pass
46
50
 
47
51
  @property
52
+ @abstractmethod
48
53
  def shared_log(self) -> list[str]:
49
54
  pass
50
55
 
51
56
  @property
57
+ @abstractmethod
52
58
  def session(self) -> any_session.AnySession | None:
53
59
  pass
54
60
 
zrb/context/context.py CHANGED
@@ -63,7 +63,7 @@ class Context(AnyContext):
63
63
 
64
64
  @property
65
65
  def session(self) -> AnySession | None:
66
- return self._shared_ctx._session
66
+ return self._shared_ctx.session
67
67
 
68
68
  def update_task_env(self, task_env: dict[str, str]):
69
69
  self._env.update(task_env)
@@ -119,7 +119,13 @@ class Context(AnyContext):
119
119
  return
120
120
  color = self._color
121
121
  icon = self._icon
122
- max_name_length = max(len(name) + len(icon) for name in self.session.task_names)
122
+ # Handle case where session is None (e.g., in tests)
123
+ if self.session is None:
124
+ max_name_length = len(self._task_name) + len(icon)
125
+ else:
126
+ max_name_length = max(
127
+ len(name) + len(icon) for name in self.session.task_names
128
+ )
123
129
  styled_task_name = f"{icon} {self._task_name}"
124
130
  padded_styled_task_name = styled_task_name.rjust(max_name_length + 1)
125
131
  if self._attempt == 0:
zrb/context/shared_context.py CHANGED
@@ -27,6 +27,7 @@ class SharedContext(AnySharedContext):
27
27
  env: dict[str, str] = {},
28
28
  xcom: dict[str, Xcom] = {},
29
29
  logging_level: int | None = None,
30
+ is_web_mode: bool = False,
30
31
  ):
31
32
  self.__logging_level = logging_level
32
33
  self._input = DotDict(input)
@@ -35,18 +36,15 @@ class SharedContext(AnySharedContext):
35
36
  self._xcom = DotDict(xcom)
36
37
  self._session: AnySession | None = None
37
38
  self._log = []
39
+ self._is_web_mode = is_web_mode
38
40
 
39
41
  def __repr__(self):
40
42
  class_name = self.__class__.__name__
41
- input = self._input
42
- args = self._args
43
- env = self._env
44
- xcom = self._xcom
45
- return f"<{class_name} input={input} args={args} xcom={xcom} env={env}>"
43
+ return f"<{class_name}>"
46
44
 
47
45
  @property
48
46
  def is_web_mode(self) -> bool:
49
- return self.env.get("_ZRB_IS_WEB_MODE", "0") == "1"
47
+ return self._is_web_mode
50
48
 
51
49
  @property
52
50
  def is_tty(self) -> bool:
@@ -68,7 +66,7 @@ class SharedContext(AnySharedContext):
68
66
  return self._args
69
67
 
70
68
  @property
71
- def xcom(self) -> DotDict[str, Xcom]:
69
+ def xcom(self) -> DotDict:
72
70
  return self._xcom
73
71
 
74
72
  @property
@@ -83,7 +81,7 @@ class SharedContext(AnySharedContext):
83
81
  self._log.append(message)
84
82
  session = self.session
85
83
  if session is not None:
86
- session_parent: AnySession = session.parent
84
+ session_parent: AnySession | None = session.parent
87
85
  if session_parent is not None:
88
86
  session_parent.shared_ctx.append_to_shared_log(message)
89
87
 
zrb/group/any_group.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Optional, Union
3
2
 
4
3
  from zrb.task.any_task import AnyTask
5
4
 
@@ -31,16 +30,24 @@ class AnyGroup(ABC):
31
30
 
32
31
  @property
33
32
  @abstractmethod
34
- def subgroups(self) -> dict[str, "AnyGroup"]:
33
+ def subgroups(self) -> "dict[str, AnyGroup]":
35
34
  """Group subgroups"""
36
35
  pass
37
36
 
38
37
  @abstractmethod
39
- def add_group(self, group: Union["AnyGroup", str]) -> "AnyGroup":
38
+ def add_group(self, group: "AnyGroup", alias: str | None = None) -> "AnyGroup":
40
39
  pass
41
40
 
42
41
  @abstractmethod
43
- def add_task(self, task: AnyTask, alias: str | None = None) -> AnyTask:
42
+ def add_task(self, task: "AnyTask", alias: str | None = None) -> "AnyTask":
43
+ pass
44
+
45
+ @abstractmethod
46
+ def remove_group(self, group: "AnyGroup | str"):
47
+ pass
48
+
49
+ @abstractmethod
50
+ def remove_task(self, task: "AnyTask | str"):
44
51
  pass
45
52
 
46
53
  @abstractmethod
@@ -48,5 +55,5 @@ class AnyGroup(ABC):
48
55
  pass
49
56
 
50
57
  @abstractmethod
51
- def get_group_by_alias(self, name: str) -> Optional["AnyGroup"]:
58
+ def get_group_by_alias(self, alias: str) -> "AnyGroup | None":
52
59
  pass
zrb/group/group.py CHANGED
@@ -33,15 +33,15 @@ class Group(AnyGroup):
33
33
  def subgroups(self) -> dict[str, AnyGroup]:
34
34
  names = list(self._groups.keys())
35
35
  names.sort()
36
- return {name: self._groups.get(name) for name in names}
36
+ return {name: self._groups[name] for name in names}
37
37
 
38
38
  @property
39
39
  def subtasks(self) -> dict[str, AnyTask]:
40
40
  alias = list(self._tasks.keys())
41
41
  alias.sort()
42
- return {name: self._tasks.get(name) for name in alias}
42
+ return {name: self._tasks[name] for name in alias}
43
43
 
44
- def add_group(self, group: AnyGroup | str, alias: str | None = None) -> AnyGroup:
44
+ def add_group(self, group: AnyGroup, alias: str | None = None) -> AnyGroup:
45
45
  real_group = Group(group) if isinstance(group, str) else group
46
46
  alias = alias if alias is not None else real_group.name
47
47
  self._groups[alias] = real_group
@@ -52,6 +52,70 @@ class Group(AnyGroup):
52
52
  self._tasks[alias] = task
53
53
  return task
54
54
 
55
+ def remove_group(self, group: "AnyGroup | str"):
56
+ original_groups_len = len(self._groups)
57
+ if isinstance(group, AnyGroup):
58
+ new_groups = {
59
+ alias: existing_group
60
+ for alias, existing_group in self._groups.items()
61
+ if group != existing_group
62
+ }
63
+ if len(new_groups) == original_groups_len:
64
+ raise ValueError(f"Cannot remove group {group} from {self}")
65
+ self._groups = new_groups
66
+ return
67
+ # group is string, try to remove by alias
68
+ new_groups = {
69
+ alias: existing_group
70
+ for alias, existing_group in self._groups.items()
71
+ if alias != group
72
+ }
73
+ if len(new_groups) < original_groups_len:
74
+ self._groups = new_groups
75
+ return
76
+ # if alias removal didn't work, try to remove by name
77
+ new_groups = {
78
+ alias: existing_group
79
+ for alias, existing_group in self._groups.items()
80
+ if existing_group.name != group
81
+ }
82
+ if len(new_groups) < original_groups_len:
83
+ self._groups = new_groups
84
+ return
85
+ raise ValueError(f"Cannot remove group {group} from {self}")
86
+
87
+ def remove_task(self, task: "AnyTask | str"):
88
+ original_tasks_len = len(self._tasks)
89
+ if isinstance(task, AnyTask):
90
+ new_tasks = {
91
+ alias: existing_task
92
+ for alias, existing_task in self._tasks.items()
93
+ if task != existing_task
94
+ }
95
+ if len(new_tasks) == original_tasks_len:
96
+ raise ValueError(f"Cannot remove task {task} from {self}")
97
+ self._tasks = new_tasks
98
+ return
99
+ # task is string, try to remove by alias
100
+ new_tasks = {
101
+ alias: existing_task
102
+ for alias, existing_task in self._tasks.items()
103
+ if alias != task
104
+ }
105
+ if len(new_tasks) < original_tasks_len:
106
+ self._tasks = new_tasks
107
+ return
108
+ # if alias removal didn't work, try to remove by name
109
+ new_tasks = {
110
+ alias: existing_task
111
+ for alias, existing_task in self._tasks.items()
112
+ if existing_task.name != task
113
+ }
114
+ if len(new_tasks) < original_tasks_len:
115
+ self._tasks = new_tasks
116
+ return
117
+ raise ValueError(f"Cannot remove task {task} from {self}")
118
+
55
119
  def get_task_by_alias(self, alias: str) -> AnyTask | None:
56
120
  return self._tasks.get(alias)
57
121
 
zrb/input/any_input.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
+ from typing import Any
2
3
 
3
4
  from zrb.context.any_shared_context import AnySharedContext
4
5
 
@@ -35,7 +36,10 @@ class AnyInput(ABC):
35
36
 
36
37
  @abstractmethod
37
38
  def update_shared_context(
38
- self, shared_ctx: AnySharedContext, str_value: str | None = None
39
+ self,
40
+ shared_ctx: AnySharedContext,
41
+ str_value: str | None = None,
42
+ value: Any = None,
39
43
  ):
40
44
  pass
41
45
 
zrb/input/base_input.py CHANGED
@@ -58,11 +58,15 @@ class BaseInput(AnyInput):
58
58
  return f'<input name="{name}" placeholder="{description}" value="{default}" />'
59
59
 
60
60
  def update_shared_context(
61
- self, shared_ctx: AnySharedContext, str_value: str | None = None
61
+ self,
62
+ shared_ctx: AnySharedContext,
63
+ str_value: str | None = None,
64
+ value: Any = None,
62
65
  ):
63
- if str_value is None:
64
- str_value = self.get_default_str(shared_ctx)
65
- value = self._parse_str_value(str_value)
66
+ if value is None:
67
+ if str_value is None:
68
+ str_value = self.get_default_str(shared_ctx)
69
+ value = self._parse_str_value(str_value)
66
70
  if self.name in shared_ctx.input:
67
71
  raise ValueError(f"Input already defined in the context: {self.name}")
68
72
  shared_ctx.input[self.name] = value
@@ -91,12 +95,20 @@ class BaseInput(AnyInput):
91
95
  default_str = self.get_default_str(shared_ctx)
92
96
  if default_str != "":
93
97
  prompt_message = f"{prompt_message} [{default_str}]"
94
- print(f"{prompt_message}: ", end="")
95
- value = input()
98
+ value = self._read_line(shared_ctx, prompt_message)
96
99
  if value.strip() == "":
97
100
  value = default_str
98
101
  return value
99
102
 
103
+ def _read_line(self, shared_ctx: AnySharedContext, prompt_message: str) -> str:
104
+ if not shared_ctx.is_tty:
105
+ print(f"{prompt_message}: ", end="")
106
+ return input()
107
+ from prompt_toolkit import PromptSession
108
+
109
+ reader = PromptSession()
110
+ return reader.prompt(f"{prompt_message}: ")
111
+
100
112
  def get_default_str(self, shared_ctx: AnySharedContext) -> str:
101
113
  """Get default value as str"""
102
114
  default_value = get_attr(
zrb/input/text_input.py CHANGED
@@ -1,12 +1,9 @@
1
- import os
2
- import subprocess
3
- import tempfile
4
1
  from collections.abc import Callable
5
2
 
6
3
  from zrb.config.config import CFG
7
4
  from zrb.context.any_shared_context import AnySharedContext
8
5
  from zrb.input.base_input import BaseInput
9
- from zrb.util.file import read_file
6
+ from zrb.util.cli.text import edit_text
10
7
 
11
8
 
12
9
  class TextInput(BaseInput):
@@ -85,24 +82,10 @@ class TextInput(BaseInput):
85
82
  comment_prompt_message = (
86
83
  f"{self.comment_start}{prompt_message}{self.comment_end}"
87
84
  )
88
- comment_prompt_message_eol = f"{comment_prompt_message}\n"
89
85
  default_value = self.get_default_str(shared_ctx)
90
- with tempfile.NamedTemporaryFile(
91
- delete=False, suffix=self._extension
92
- ) as temp_file:
93
- temp_file_name = temp_file.name
94
- temp_file.write(comment_prompt_message_eol.encode())
95
- # Pre-fill with default content
96
- if default_value:
97
- temp_file.write(default_value.encode())
98
- temp_file.flush()
99
- subprocess.call([self.editor_cmd, temp_file_name])
100
- # Read the edited content
101
- edited_content = read_file(temp_file_name)
102
- parts = [
103
- text.strip() for text in edited_content.split(comment_prompt_message, 1)
104
- ]
105
- edited_content = "\n".join(parts).lstrip()
106
- os.remove(temp_file_name)
107
- print(f"{prompt_message}: {edited_content}")
108
- return edited_content
86
+ return edit_text(
87
+ prompt_message=comment_prompt_message,
88
+ value=default_value,
89
+ editor=self.editor_cmd,
90
+ extension=self._extension,
91
+ )
zrb/runner/cli.py CHANGED
@@ -7,7 +7,7 @@ from zrb.context.any_context import AnyContext
7
7
  from zrb.context.shared_context import SharedContext
8
8
  from zrb.group.any_group import AnyGroup
9
9
  from zrb.group.group import Group
10
- from zrb.runner.common_util import get_run_kwargs
10
+ from zrb.runner.common_util import get_task_str_kwargs
11
11
  from zrb.session.session import Session
12
12
  from zrb.session_state_logger.session_state_logger_factory import session_state_logger
13
13
  from zrb.task.any_task import AnyTask
@@ -38,23 +38,25 @@ class Cli(Group):
38
38
  def banner(self) -> str:
39
39
  return CFG.BANNER
40
40
 
41
- def run(self, args: list[str] = []):
42
- kwargs, args = self._extract_kwargs_from_args(args)
43
- node, node_path, args = extract_node_from_args(self, args)
41
+ def run(self, str_args: list[str] = []):
42
+ str_kwargs, str_args = self._extract_kwargs_from_args(str_args)
43
+ node, node_path, str_args = extract_node_from_args(self, str_args)
44
44
  if isinstance(node, AnyGroup):
45
45
  self._show_group_info(node)
46
46
  return
47
- if "h" in kwargs or "help" in kwargs:
47
+ if "h" in str_kwargs or "help" in str_kwargs:
48
48
  self._show_task_info(node)
49
49
  return
50
- run_kwargs = get_run_kwargs(task=node, args=args, kwargs=kwargs, cli_mode=True)
50
+ task_str_kwargs = get_task_str_kwargs(
51
+ task=node, str_args=str_args, str_kwargs=str_kwargs, cli_mode=True
52
+ )
51
53
  try:
52
- result = self._run_task(node, args, run_kwargs)
54
+ result = self._run_task(node, str_args, task_str_kwargs)
53
55
  if result is not None:
54
56
  print(result)
55
57
  return result
56
58
  finally:
57
- run_command = self._get_run_command(node_path, run_kwargs)
59
+ run_command = self._get_run_command(node_path, task_str_kwargs)
58
60
  self._print_run_command(run_command)
59
61
 
60
62
  def _print_run_command(self, run_command: str):
@@ -64,11 +66,14 @@ class Cli(Group):
64
66
  file=sys.stderr,
65
67
  )
66
68
 
67
- def _get_run_command(self, node_path: list[str], run_kwargs: dict[str, str]) -> str:
69
+ def _get_run_command(
70
+ self, node_path: list[str], task_str_kwargs: dict[str, str]
71
+ ) -> str:
68
72
  parts = [self.name] + node_path
69
- if len(run_kwargs) > 0:
73
+ if len(task_str_kwargs) > 0:
70
74
  parts += [
71
- self._get_run_command_param(key, val) for key, val in run_kwargs.items()
75
+ self._get_run_command_param(key, val)
76
+ for key, val in task_str_kwargs.items()
72
77
  ]
73
78
  return " ".join(parts)
74
79
 
@@ -81,13 +86,9 @@ class Cli(Group):
81
86
  self, task: AnyTask, args: list[str], run_kwargs: dict[str, str]
82
87
  ) -> tuple[Any]:
83
88
  shared_ctx = SharedContext(args=args)
84
- for task_input in task.inputs:
85
- if task_input.name in run_kwargs:
86
- task_input.update_shared_context(
87
- shared_ctx, run_kwargs[task_input.name]
88
- )
89
- continue
90
- return task.run(Session(shared_ctx=shared_ctx, root_group=self))
89
+ return task.run(
90
+ Session(shared_ctx=shared_ctx, root_group=self), str_kwargs=run_kwargs
91
+ )
91
92
 
92
93
  def _show_task_info(self, task: AnyTask):
93
94
  description = task.description
@@ -150,11 +151,11 @@ class Cli(Group):
150
151
  kwargs[key] = args[i + 1]
151
152
  i += 1 # Skip the next argument as it's a value
152
153
  else:
153
- kwargs[key] = True
154
+ kwargs[key] = "true"
154
155
  elif arg.startswith("-"):
155
156
  # Handle short flags like -t or -n
156
157
  key = arg[1:]
157
- kwargs[key] = True
158
+ kwargs[key] = "true"
158
159
  else:
159
160
  # Anything else is considered a positional argument
160
161
  residual_args.append(arg)