zrb 1.5.11__py3-none-any.whl → 1.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. zrb/builtin/llm/llm_chat.py +1 -1
  2. zrb/builtin/llm/tool/__init__.py +0 -0
  3. zrb/builtin/llm/tool/sub_agent.py +125 -0
  4. zrb/builtin/llm/tool/web.py +0 -2
  5. zrb/config.py +0 -3
  6. zrb/llm_config.py +16 -2
  7. zrb/task/base_task.py +20 -0
  8. zrb/task/llm/agent.py +5 -8
  9. zrb/task/llm/context.py +17 -8
  10. zrb/task/llm/context_enrichment.py +52 -13
  11. zrb/task/llm/history_summarization.py +3 -5
  12. zrb/task/llm/prompt.py +7 -4
  13. zrb/task/llm/tool_wrapper.py +115 -53
  14. zrb/task/llm_task.py +16 -1
  15. zrb/util/attr.py +84 -1
  16. zrb/util/cli/style.py +147 -0
  17. zrb/util/cli/subcommand.py +22 -1
  18. zrb/util/cmd/command.py +18 -0
  19. zrb/util/cmd/remote.py +15 -0
  20. zrb/util/codemod/modification_mode.py +4 -0
  21. zrb/util/codemod/modify_class.py +72 -0
  22. zrb/util/codemod/modify_class_parent.py +68 -0
  23. zrb/util/codemod/modify_class_property.py +67 -0
  24. zrb/util/codemod/modify_dict.py +62 -0
  25. zrb/util/codemod/modify_function.py +75 -3
  26. zrb/util/codemod/modify_function_call.py +72 -0
  27. zrb/util/codemod/modify_method.py +77 -0
  28. zrb/util/codemod/modify_module.py +10 -0
  29. zrb/util/cron.py +37 -3
  30. zrb/util/file.py +32 -0
  31. zrb/util/git.py +113 -0
  32. zrb/util/git_subtree.py +58 -0
  33. zrb/util/group.py +64 -2
  34. zrb/util/load.py +29 -0
  35. zrb/util/run.py +9 -0
  36. zrb/util/string/conversion.py +86 -0
  37. zrb/util/string/format.py +20 -0
  38. zrb/util/string/name.py +12 -0
  39. zrb/util/todo.py +165 -4
  40. {zrb-1.5.11.dist-info → zrb-1.5.12.dist-info}/METADATA +3 -3
  41. {zrb-1.5.11.dist-info → zrb-1.5.12.dist-info}/RECORD +43 -41
  42. {zrb-1.5.11.dist-info → zrb-1.5.12.dist-info}/WHEEL +0 -0
  43. {zrb-1.5.11.dist-info → zrb-1.5.12.dist-info}/entry_points.txt +0 -0
zrb/util/codemod/modify_method.py CHANGED
@@ -6,24 +6,85 @@ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
6
6
  def replace_method_code(
7
7
  original_code: str, class_name: str, method_name: str, new_code: str
8
8
  ) -> str:
9
+ """
10
+ Replace the entire code body of a specified method within a class.
11
+
12
+ Args:
13
+ original_code (str): The original Python code as a string.
14
+ class_name (str): The name of the class containing the method.
15
+ method_name (str): The name of the method to modify.
16
+ new_code (str): The new code body for the method as a string.
17
+
18
+ Returns:
19
+ str: The modified Python code as a string.
20
+
21
+ Raises:
22
+ ValueError: If the specified class or method is not found in the code.
23
+ """
9
24
  return _modify_method(original_code, class_name, method_name, new_code, REPLACE)
10
25
 
11
26
 
12
27
  def prepend_code_to_method(
13
28
  original_code: str, class_name: str, method_name: str, new_code: str
14
29
  ) -> str:
30
+ """
31
+ Prepend code to the body of a specified method within a class.
32
+
33
+ Args:
34
+ original_code (str): The original Python code as a string.
35
+ class_name (str): The name of the class containing the method.
36
+ method_name (str): The name of the method to modify.
37
+ new_code (str): The code to prepend as a string.
38
+
39
+ Returns:
40
+ str: The modified Python code as a string.
41
+
42
+ Raises:
43
+ ValueError: If the specified class or method is not found in the code.
44
+ """
15
45
  return _modify_method(original_code, class_name, method_name, new_code, PREPEND)
16
46
 
17
47
 
18
48
  def append_code_to_method(
19
49
  original_code: str, class_name: str, method_name: str, new_code: str
20
50
  ) -> str:
51
+ """
52
+ Append code to the body of a specified method within a class.
53
+
54
+ Args:
55
+ original_code (str): The original Python code as a string.
56
+ class_name (str): The name of the class containing the method.
57
+ method_name (str): The name of the method to modify.
58
+ new_code (str): The code to append as a string.
59
+
60
+ Returns:
61
+ str: The modified Python code as a string.
62
+
63
+ Raises:
64
+ ValueError: If the specified class or method is not found in the code.
65
+ """
21
66
  return _modify_method(original_code, class_name, method_name, new_code, APPEND)
22
67
 
23
68
 
24
69
  def _modify_method(
25
70
  original_code: str, class_name: str, method_name: str, new_code: str, mode: int
26
71
  ) -> str:
72
+ """
73
+ Modify the code body of a specified method within a class.
74
+
75
+ Args:
76
+ original_code (str): The original Python code as a string.
77
+ class_name (str): The name of the class containing the method.
78
+ method_name (str): The name of the method to modify.
79
+ new_code (str): The code to add/replace as a string.
80
+ mode (int): The modification mode (PREPEND, APPEND, or REPLACE).
81
+
82
+ Returns:
83
+ str: The modified Python code as a string.
84
+
85
+ Raises:
86
+ ValueError: If the specified class or method is not found in the code.
87
+ """
27
88
  # Parse the original code into a module
28
89
  module = cst.parse_module(original_code)
29
90
  # Initialize the transformer with the necessary information
@@ -40,7 +101,20 @@ def _modify_method(
40
101
 
41
102
 
42
103
  class _MethodModifier(cst.CSTTransformer):
104
+ """
105
+ A LibCST transformer to modify the code body of a method within a ClassDef node.
106
+ """
107
+
43
108
  def __init__(self, class_name: str, method_name: str, new_code: str, mode: int):
109
+ """
110
+ Initialize the transformer.
111
+
112
+ Args:
113
+ class_name (str): The name of the target class.
114
+ method_name (str): The name of the target method.
115
+ new_code (str): The new code body as a string.
116
+ mode (int): The modification mode (PREPEND, APPEND, or REPLACE).
117
+ """
44
118
  self.class_name = class_name
45
119
  self.method_name = method_name
46
120
  # Use parse_module to handle multiple statements
@@ -52,6 +126,9 @@ class _MethodModifier(cst.CSTTransformer):
52
126
  def leave_ClassDef(
53
127
  self, original_node: cst.ClassDef, updated_node: cst.ClassDef
54
128
  ) -> cst.ClassDef:
129
+ """
130
+ Called when leaving a ClassDef node. Modifies the body of the target method.
131
+ """
55
132
  # Check if the class matches the target class
56
133
  if original_node.name.value == self.class_name:
57
134
  self.class_found = True
zrb/util/codemod/modify_module.py CHANGED
@@ -1,4 +1,14 @@
1
1
  def prepend_code_to_module(original_code: str, new_code: str) -> str:
2
+ """
3
+ Prepend code to a module after the last import statement.
4
+
5
+ Args:
6
+ original_code (str): The original Python code as a string.
7
+ new_code (str): The code to prepend as a string.
8
+
9
+ Returns:
10
+ str: The modified Python code as a string.
11
+ """
2
12
  lines = original_code.splitlines()
3
13
  last_import_index = -1
4
14
  for i, line in enumerate(lines):
zrb/util/cron.py CHANGED
@@ -2,7 +2,20 @@ import datetime
2
2
 
3
3
 
4
4
  def parse_cron_field(field: str, min_value: int, max_value: int):
5
- """Parse a cron field with support for wildcards (*), ranges, lists, and steps."""
5
+ """
6
+ Parse a cron field string into a set of integer values.
7
+
8
+ Supports wildcards (*), ranges (e.g., 1-5), lists (e.g., 1,3,5),
9
+ and step values (e.g., */5, 1-10/2).
10
+
11
+ Args:
12
+ field (str): The cron field string (e.g., "*", "1-5", "*/10").
13
+ min_value (int): The minimum allowed value for the field.
14
+ max_value (int): The maximum allowed value for the field.
15
+
16
+ Returns:
17
+ set[int]: A set of integer values represented by the cron field.
18
+ """
6
19
  values = set()
7
20
  if field == "*":
8
21
  return set(range(min_value, max_value + 1))
@@ -34,7 +47,16 @@ def parse_cron_field(field: str, min_value: int, max_value: int):
34
47
 
35
48
 
36
49
  def handle_special_cron_patterns(pattern: str, dt: datetime.datetime):
37
- """Handle special cron patterns like @yearly, @monthly, etc."""
50
+ """
51
+ Check if a datetime object matches a special cron pattern.
52
+
53
+ Args:
54
+ pattern (str): The special cron pattern (e.g., "@yearly", "@monthly").
55
+ dt (datetime.datetime): The datetime object to check.
56
+
57
+ Returns:
58
+ bool: True if the datetime matches the pattern, False otherwise.
59
+ """
38
60
  if pattern == "@yearly" or pattern == "@annually":
39
61
  return dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0
40
62
  elif pattern == "@monthly":
@@ -53,7 +75,19 @@ def handle_special_cron_patterns(pattern: str, dt: datetime.datetime):
53
75
 
54
76
 
55
77
  def match_cron(cron_pattern: str, dt: datetime.datetime):
56
- """Check if a datetime object matches a cron pattern, including special cases."""
78
+ """
79
+ Check if a datetime object matches a cron pattern.
80
+
81
+ Supports standard cron format (minute hour day month day_of_week)
82
+ and special patterns (e.g., "@yearly", "@monthly").
83
+
84
+ Args:
85
+ cron_pattern (str): The cron pattern string.
86
+ dt (datetime.datetime): The datetime object to check.
87
+
88
+ Returns:
89
+ bool: True if the datetime matches the cron pattern, False otherwise.
90
+ """
57
91
  # Handle special cron patterns
58
92
  if cron_pattern.startswith("@"):
59
93
  return handle_special_cron_patterns(cron_pattern, dt)
zrb/util/file.py CHANGED
@@ -3,6 +3,15 @@ import re
3
3
 
4
4
 
5
5
  def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
6
+ """Reads a file and optionally replaces content based on a map.
7
+
8
+ Args:
9
+ file_path: The path to the file.
10
+ replace_map: A dictionary of strings to replace.
11
+
12
+ Returns:
13
+ The content of the file with replacements applied.
14
+ """
6
15
  with open(
7
16
  os.path.abspath(os.path.expanduser(file_path)), "r", encoding="utf-8"
8
17
  ) as f:
@@ -15,6 +24,15 @@ def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
15
24
  def read_file_with_line_numbers(
16
25
  file_path: str, replace_map: dict[str, str] = {}
17
26
  ) -> str:
27
+ """Reads a file and returns content with line numbers.
28
+
29
+ Args:
30
+ file_path: The path to the file.
31
+ replace_map: A dictionary of strings to replace.
32
+
33
+ Returns:
34
+ The content of the file with line numbers and replacements applied.
35
+ """
18
36
  content = read_file(file_path, replace_map)
19
37
  lines = content.splitlines()
20
38
  numbered_lines = [f"{i + 1} | {line}" for i, line in enumerate(lines)]
@@ -22,10 +40,24 @@ def read_file_with_line_numbers(
22
40
 
23
41
 
24
42
  def read_dir(dir_path: str) -> list[str]:
43
+ """Reads a directory and returns a list of file names.
44
+
45
+ Args:
46
+ dir_path: The path to the directory.
47
+
48
+ Returns:
49
+ A list of file names in the directory.
50
+ """
25
51
  return [f for f in os.listdir(os.path.abspath(os.path.expanduser(dir_path)))]
26
52
 
27
53
 
28
54
  def write_file(file_path: str, content: str | list[str]):
55
+ """Writes content to a file.
56
+
57
+ Args:
58
+ file_path: The path to the file.
59
+ content: The content to write, either a string or a list of strings.
60
+ """
29
61
  if isinstance(content, list):
30
62
  content = "\n".join([line for line in content if line is not None])
31
63
  dir_path = os.path.dirname(file_path)
zrb/util/git.py CHANGED
@@ -19,6 +19,21 @@ async def get_diff(
19
19
  current_commit: str,
20
20
  print_method: Callable[..., Any] = print,
21
21
  ) -> DiffResult:
22
+ """
23
+ Get the difference between two commits in a Git repository.
24
+
25
+ Args:
26
+ repo_dir (str): The path to the Git repository.
27
+ source_commit (str): The source commit hash or reference.
28
+ current_commit (str): The current commit hash or reference.
29
+ print_method (Callable[..., Any]): Method to print command output.
30
+
31
+ Returns:
32
+ DiffResult: An object containing lists of created, removed, and updated files.
33
+
34
+ Raises:
35
+ Exception: If the git command returns a non-zero exit code.
36
+ """
22
37
  cmd_result, exit_code = await run_command(
23
38
  cmd=["git", "diff", source_commit, current_commit],
24
39
  cwd=repo_dir,
@@ -55,6 +70,18 @@ async def get_diff(
55
70
 
56
71
 
57
72
  async def get_repo_dir(print_method: Callable[..., Any] = print) -> str:
73
+ """
74
+ Get the top-level directory of the Git repository.
75
+
76
+ Args:
77
+ print_method (Callable[..., Any]): Method to print command output.
78
+
79
+ Returns:
80
+ str: The absolute path to the repository's top-level directory.
81
+
82
+ Raises:
83
+ Exception: If the git command returns a non-zero exit code.
84
+ """
58
85
  cmd_result, exit_code = await run_command(
59
86
  cmd=["git", "rev-parse", "--show-toplevel"],
60
87
  print_method=print_method,
@@ -67,6 +94,19 @@ async def get_repo_dir(print_method: Callable[..., Any] = print) -> str:
67
94
  async def get_current_branch(
68
95
  repo_dir: str, print_method: Callable[..., Any] = print
69
96
  ) -> str:
97
+ """
98
+ Get the current branch name of the Git repository.
99
+
100
+ Args:
101
+ repo_dir (str): The path to the Git repository.
102
+ print_method (Callable[..., Any]): Method to print command output.
103
+
104
+ Returns:
105
+ str: The name of the current branch.
106
+
107
+ Raises:
108
+ Exception: If the git command returns a non-zero exit code.
109
+ """
70
110
  cmd_result, exit_code = await run_command(
71
111
  cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"],
72
112
  cwd=repo_dir,
@@ -80,6 +120,19 @@ async def get_current_branch(
80
120
  async def get_branches(
81
121
  repo_dir: str, print_method: Callable[..., Any] = print
82
122
  ) -> list[str]:
123
+ """
124
+ Get a list of all branches in the Git repository.
125
+
126
+ Args:
127
+ repo_dir (str): The path to the Git repository.
128
+ print_method (Callable[..., Any]): Method to print command output.
129
+
130
+ Returns:
131
+ list[str]: A list of branch names.
132
+
133
+ Raises:
134
+ Exception: If the git command returns a non-zero exit code.
135
+ """
83
136
  cmd_result, exit_code = await run_command(
84
137
  cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"],
85
138
  cwd=repo_dir,
@@ -95,6 +148,20 @@ async def get_branches(
95
148
  async def delete_branch(
96
149
  repo_dir: str, branch_name: str, print_method: Callable[..., Any] = print
97
150
  ) -> str:
151
+ """
152
+ Delete a branch in the Git repository.
153
+
154
+ Args:
155
+ repo_dir (str): The path to the Git repository.
156
+ branch_name (str): The name of the branch to delete.
157
+ print_method (Callable[..., Any]): Method to print command output.
158
+
159
+ Returns:
160
+ str: The output of the git command.
161
+
162
+ Raises:
163
+ Exception: If the git command returns a non-zero exit code.
164
+ """
98
165
  cmd_result, exit_code = await run_command(
99
166
  cmd=["git", "branch", "-D", branch_name],
100
167
  cwd=repo_dir,
@@ -106,6 +173,16 @@ async def delete_branch(
106
173
 
107
174
 
108
175
  async def add(repo_dir: str, print_method: Callable[..., Any] = print):
176
+ """
177
+ Add all changes to the Git staging area.
178
+
179
+ Args:
180
+ repo_dir (str): The path to the Git repository.
181
+ print_method (Callable[..., Any]): Method to print command output.
182
+
183
+ Raises:
184
+ Exception: If the git command returns a non-zero exit code.
185
+ """
109
186
  _, exit_code = await run_command(
110
187
  cmd=["git", "add", ".", "-A"],
111
188
  cwd=repo_dir,
@@ -118,6 +195,18 @@ async def add(repo_dir: str, print_method: Callable[..., Any] = print):
118
195
  async def commit(
119
196
  repo_dir: str, message: str, print_method: Callable[..., Any] = print
120
197
  ) -> str:
198
+ """
199
+ Commit changes in the Git repository.
200
+
201
+ Args:
202
+ repo_dir (str): The path to the Git repository.
203
+ message (str): The commit message.
204
+ print_method (Callable[..., Any]): Method to print command output.
205
+
206
+ Raises:
207
+ Exception: If the git command returns a non-zero exit code, unless it's
208
+ the "nothing to commit" message.
209
+ """
121
210
  cmd_result, exit_code = await run_command(
122
211
  cmd=["git", "commit", "-m", message],
123
212
  cwd=repo_dir,
@@ -135,6 +224,18 @@ async def commit(
135
224
  async def pull(
136
225
  repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print
137
226
  ) -> str:
227
+ """
228
+ Pull changes from a remote repository and branch.
229
+
230
+ Args:
231
+ repo_dir (str): The path to the Git repository.
232
+ remote (str): The name of the remote.
233
+ branch (str): The name of the branch.
234
+ print_method (Callable[..., Any]): Method to print command output.
235
+
236
+ Raises:
237
+ Exception: If the git command returns a non-zero exit code.
238
+ """
138
239
  _, exit_code = await run_command(
139
240
  cmd=["git", "pull", remote, branch],
140
241
  cwd=repo_dir,
@@ -147,6 +248,18 @@ async def pull(
147
248
  async def push(
148
249
  repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print
149
250
  ) -> str:
251
+ """
252
+ Push changes to a remote repository and branch.
253
+
254
+ Args:
255
+ repo_dir (str): The path to the Git repository.
256
+ remote (str): The name of the remote.
257
+ branch (str): The name of the branch.
258
+ print_method (Callable[..., Any]): Method to print command output.
259
+
260
+ Raises:
261
+ Exception: If the git command returns a non-zero exit code.
262
+ """
150
263
  _, exit_code = await run_command(
151
264
  cmd=["git", "push", "-u", remote, branch],
152
265
  cwd=repo_dir,
zrb/util/git_subtree.py CHANGED
@@ -19,6 +19,15 @@ class SubTreeConfig(BaseModel):
19
19
 
20
20
 
21
21
  def load_config(repo_dir: str) -> SubTreeConfig:
22
+ """
23
+ Load the subtree configuration from subtrees.json.
24
+
25
+ Args:
26
+ repo_dir (str): The path to the Git repository.
27
+
28
+ Returns:
29
+ SubTreeConfig: The loaded subtree configuration.
30
+ """
22
31
  file_path = os.path.join(repo_dir, "subtrees.json")
23
32
  if not os.path.exists(file_path):
24
33
  return SubTreeConfig(data={})
@@ -26,6 +35,13 @@ def load_config(repo_dir: str) -> SubTreeConfig:
26
35
 
27
36
 
28
37
  def save_config(repo_dir: str, config: SubTreeConfig):
38
+ """
39
+ Save the subtree configuration to subtrees.json.
40
+
41
+ Args:
42
+ repo_dir (str): The path to the Git repository.
43
+ config (SubTreeConfig): The subtree configuration to save.
44
+ """
29
45
  file_path = os.path.join(repo_dir, "subtrees.json")
30
46
  write_file(file_path, config.model_dump_json(indent=2))
31
47
 
@@ -38,6 +54,22 @@ async def add_subtree(
38
54
  prefix: str,
39
55
  print_method: Callable[..., Any] = print,
40
56
  ):
57
+ """
58
+ Add a Git subtree to the repository.
59
+
60
+ Args:
61
+ repo_dir (str): The path to the Git repository.
62
+ name (str): The name for the subtree configuration.
63
+ repo_url (str): The URL of the subtree repository.
64
+ branch (str): The branch of the subtree repository.
65
+ prefix (str): The local path where the subtree will be added.
66
+ print_method (Callable[..., Any]): Method to print command output.
67
+
68
+ Raises:
69
+ ValueError: If the prefix directory already exists or subtree config
70
+ name already exists.
71
+ Exception: If the git command returns a non-zero exit code.
72
+ """
41
73
  config = load_config()
42
74
  if os.path.isdir(prefix):
43
75
  raise ValueError(f"Directory exists: {prefix}")
@@ -71,6 +103,19 @@ async def pull_subtree(
71
103
  branch: str,
72
104
  print_method: Callable[..., Any] = print,
73
105
  ):
106
+ """
107
+ Pull changes from a Git subtree.
108
+
109
+ Args:
110
+ repo_dir (str): The path to the Git repository.
111
+ prefix (str): The local path of the subtree.
112
+ repo_url (str): The URL of the subtree repository.
113
+ branch (str): The branch of the subtree repository.
114
+ print_method (Callable[..., Any]): Method to print command output.
115
+
116
+ Raises:
117
+ Exception: If the git command returns a non-zero exit code.
118
+ """
74
119
  _, exit_code = await run_command(
75
120
  cmd=[
76
121
  "git",
@@ -95,6 +140,19 @@ async def push_subtree(
95
140
  branch: str,
96
141
  print_method: Callable[..., Any] = print,
97
142
  ):
143
+ """
144
+ Push changes to a Git subtree.
145
+
146
+ Args:
147
+ repo_dir (str): The path to the Git repository.
148
+ prefix (str): The local path of the subtree.
149
+ repo_url (str): The URL of the subtree repository.
150
+ branch (str): The branch of the subtree repository.
151
+ print_method (Callable[..., Any]): Method to print command output.
152
+
153
+ Raises:
154
+ Exception: If the git command returns a non-zero exit code.
155
+ """
98
156
  _, exit_code = await run_command(
99
157
  cmd=[
100
158
  "git",
zrb/util/group.py CHANGED
@@ -9,6 +9,21 @@ class NodeNotFoundError(ValueError):
9
9
  def extract_node_from_args(
10
10
  root_group: AnyGroup, args: list[str], web_only: bool = False
11
11
  ) -> tuple[AnyGroup | AnyTask, list[str], list[str]]:
12
+ """
13
+ Extract a node (Group or Task) from a list of command-line arguments.
14
+
15
+ Args:
16
+ root_group (AnyGroup): The root group to start the search from.
17
+ args (list[str]): The list of command-line arguments.
18
+ web_only (bool): If True, only consider tasks that are not CLI-only.
19
+
20
+ Returns:
21
+ tuple[AnyGroup | AnyTask, list[str], list[str]]: A tuple containing the
22
+ extracted node, the path to the node, and any residual arguments.
23
+
24
+ Raises:
25
+ NodeNotFoundError: If no matching task or group is found for a given argument.
26
+ """
12
27
  node = root_group
13
28
  node_path = []
14
29
  residual_args = []
@@ -17,8 +32,12 @@ def extract_node_from_args(
17
32
  if web_only and task is not None and task.cli_only:
18
33
  task = None
19
34
  group = node.get_group_by_alias(name)
20
- if group is not None and len(get_all_subtasks(group, web_only)) == 0:
21
- # If group doesn't contain any task, then ignore its existence
35
+ # Only ignore empty groups if web_only is True
36
+ if (
37
+ group is not None
38
+ and web_only
39
+ and len(get_all_subtasks(group, web_only)) == 0
40
+ ):
22
41
  group = None
23
42
  if task is None and group is None:
24
43
  raise NodeNotFoundError(
@@ -40,8 +59,21 @@ def extract_node_from_args(
40
59
 
41
60
 
42
61
  def get_node_path(group: AnyGroup, node: AnyGroup | AnyTask) -> list[str] | None:
62
+ """
63
+ Get the path (aliases) to a specific node within a group hierarchy.
64
+
65
+ Args:
66
+ group (AnyGroup): The group to search within.
67
+ node (AnyGroup | AnyTask): The target node.
68
+
69
+ Returns:
70
+ list[str] | None: A list of aliases representing the path to the node,
71
+ or None if the node is not found.
72
+ """
43
73
  if group is None:
44
74
  return []
75
+ if group == node: # Handle the case where the target is the starting group
76
+ return [group.name]
45
77
  if isinstance(node, AnyTask):
46
78
  for alias, subtask in group.subtasks.items():
47
79
  if subtask == node:
@@ -60,6 +92,16 @@ def get_node_path(group: AnyGroup, node: AnyGroup | AnyTask) -> list[str] | None
60
92
  def get_non_empty_subgroups(
61
93
  group: AnyGroup, web_only: bool = False
62
94
  ) -> dict[str, AnyGroup]:
95
+ """
96
+ Get subgroups that contain at least one task.
97
+
98
+ Args:
99
+ group (AnyGroup): The group to search within.
100
+ web_only (bool): If True, only consider tasks that are not CLI-only.
101
+
102
+ Returns:
103
+ dict[str, AnyGroup]: A dictionary of subgroups that are not empty.
104
+ """
63
105
  return {
64
106
  alias: subgroup
65
107
  for alias, subgroup in group.subgroups.items()
@@ -68,6 +110,16 @@ def get_non_empty_subgroups(
68
110
 
69
111
 
70
112
  def get_subtasks(group: AnyGroup, web_only: bool = False) -> dict[str, AnyTask]:
113
+ """
114
+ Get the direct subtasks of a group.
115
+
116
+ Args:
117
+ group (AnyGroup): The group to search within.
118
+ web_only (bool): If True, only include tasks that are not CLI-only.
119
+
120
+ Returns:
121
+ dict[str, AnyTask]: A dictionary of subtasks.
122
+ """
71
123
  return {
72
124
  alias: subtask
73
125
  for alias, subtask in group.subtasks.items()
@@ -76,6 +128,16 @@ def get_subtasks(group: AnyGroup, web_only: bool = False) -> dict[str, AnyTask]:
76
128
 
77
129
 
78
130
  def get_all_subtasks(group: AnyGroup, web_only: bool = False) -> list[AnyTask]:
131
+ """
132
+ Get all subtasks (including nested ones) within a group hierarchy.
133
+
134
+ Args:
135
+ group (AnyGroup): The group to search within.
136
+ web_only (bool): If True, only include tasks that are not CLI-only.
137
+
138
+ Returns:
139
+ list[AnyTask]: A list of all subtasks.
140
+ """
79
141
  subtasks = [
80
142
  subtask
81
143
  for subtask in group.subtasks.values()
zrb/util/load.py CHANGED
@@ -11,6 +11,17 @@ pattern = re.compile("[^a-zA-Z0-9]")
11
11
 
12
12
  @lru_cache
13
13
  def load_file(script_path: str, sys_path_index: int = 0) -> Any | None:
14
+ """
15
+ Load a Python module from a file path.
16
+
17
+ Args:
18
+ script_path (str): The path to the Python script.
19
+ sys_path_index (int): The index to insert the script directory into sys.path.
20
+
21
+ Returns:
22
+ Any | None: The loaded module object, or None if the file does not
23
+ exist or cannot be loaded.
24
+ """
14
25
  if not os.path.isfile(script_path):
15
26
  return None
16
27
  module_name = pattern.sub("", script_path)
@@ -31,6 +42,15 @@ def load_file(script_path: str, sys_path_index: int = 0) -> Any | None:
31
42
 
32
43
 
33
44
  def _get_new_python_path(dir_path: str) -> str:
45
+ """
46
+ Helper function to update the PYTHONPATH environment variable.
47
+
48
+ Args:
49
+ dir_path (str): The directory path to add to PYTHONPATH.
50
+
51
+ Returns:
52
+ str: The new value for the PYTHONPATH environment variable.
53
+ """
34
54
  current_python_path = os.getenv("PYTHONPATH")
35
55
  if current_python_path is None or current_python_path == "":
36
56
  return dir_path
@@ -40,5 +60,14 @@ def _get_new_python_path(dir_path: str) -> str:
40
60
 
41
61
 
42
62
  def load_module(module_name: str) -> Any:
63
+ """
64
+ Load a Python module by its name.
65
+
66
+ Args:
67
+ module_name (str): The name of the module to load.
68
+
69
+ Returns:
70
+ Any: The loaded module object.
71
+ """
43
72
  module = importlib.import_module(module_name)
44
73
  return module
zrb/util/run.py CHANGED
@@ -4,6 +4,15 @@ from typing import Any
4
4
 
5
5
 
6
6
  async def run_async(value: Any) -> Any:
7
+ """
8
+ Run a value asynchronously, awaiting if it's awaitable or running in a thread if not.
9
+
10
+ Args:
11
+ value (Any): The value to run. Can be awaitable or not.
12
+
13
+ Returns:
14
+ Any: The result of the awaited value or the value itself if not awaitable.
15
+ """
7
16
  if inspect.isawaitable(value):
8
17
  return await value
9
18
  return await asyncio.to_thread(lambda: value)