zrb 1.5.10__py3-none-any.whl → 1.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/llm_chat.py +1 -1
- zrb/builtin/llm/tool/__init__.py +0 -0
- zrb/builtin/llm/tool/file.py +22 -88
- zrb/builtin/llm/tool/sub_agent.py +125 -0
- zrb/builtin/llm/tool/web.py +0 -2
- zrb/config.py +0 -3
- zrb/llm_config.py +16 -2
- zrb/task/base_task.py +20 -0
- zrb/task/cmd_task.py +2 -2
- zrb/task/llm/agent.py +5 -8
- zrb/task/llm/context.py +17 -8
- zrb/task/llm/context_enrichment.py +52 -13
- zrb/task/llm/history_summarization.py +3 -5
- zrb/task/llm/prompt.py +7 -4
- zrb/task/llm/tool_wrapper.py +115 -53
- zrb/task/llm_task.py +16 -1
- zrb/util/attr.py +84 -1
- zrb/util/cli/style.py +147 -0
- zrb/util/cli/subcommand.py +22 -1
- zrb/util/cmd/command.py +18 -0
- zrb/util/cmd/remote.py +15 -0
- zrb/util/codemod/modification_mode.py +4 -0
- zrb/util/codemod/modify_class.py +72 -0
- zrb/util/codemod/modify_class_parent.py +68 -0
- zrb/util/codemod/modify_class_property.py +67 -0
- zrb/util/codemod/modify_dict.py +62 -0
- zrb/util/codemod/modify_function.py +75 -3
- zrb/util/codemod/modify_function_call.py +72 -0
- zrb/util/codemod/modify_method.py +77 -0
- zrb/util/codemod/modify_module.py +10 -0
- zrb/util/cron.py +37 -3
- zrb/util/file.py +32 -0
- zrb/util/git.py +113 -0
- zrb/util/git_subtree.py +58 -0
- zrb/util/group.py +64 -2
- zrb/util/load.py +29 -0
- zrb/util/run.py +9 -0
- zrb/util/string/conversion.py +86 -0
- zrb/util/string/format.py +20 -0
- zrb/util/string/name.py +12 -0
- zrb/util/todo.py +165 -4
- {zrb-1.5.10.dist-info → zrb-1.5.12.dist-info}/METADATA +3 -3
- {zrb-1.5.10.dist-info → zrb-1.5.12.dist-info}/RECORD +45 -44
- zrb/task/base/dependencies.py +0 -57
- {zrb-1.5.10.dist-info → zrb-1.5.12.dist-info}/WHEEL +0 -0
- {zrb-1.5.10.dist-info → zrb-1.5.12.dist-info}/entry_points.txt +0 -0
@@ -6,24 +6,85 @@ from zrb.util.codemod.modification_mode import APPEND, PREPEND, REPLACE
|
|
6
6
|
def replace_method_code(
|
7
7
|
original_code: str, class_name: str, method_name: str, new_code: str
|
8
8
|
) -> str:
|
9
|
+
"""
|
10
|
+
Replace the entire code body of a specified method within a class.
|
11
|
+
|
12
|
+
Args:
|
13
|
+
original_code (str): The original Python code as a string.
|
14
|
+
class_name (str): The name of the class containing the method.
|
15
|
+
method_name (str): The name of the method to modify.
|
16
|
+
new_code (str): The new code body for the method as a string.
|
17
|
+
|
18
|
+
Returns:
|
19
|
+
str: The modified Python code as a string.
|
20
|
+
|
21
|
+
Raises:
|
22
|
+
ValueError: If the specified class or method is not found in the code.
|
23
|
+
"""
|
9
24
|
return _modify_method(original_code, class_name, method_name, new_code, REPLACE)
|
10
25
|
|
11
26
|
|
12
27
|
def prepend_code_to_method(
|
13
28
|
original_code: str, class_name: str, method_name: str, new_code: str
|
14
29
|
) -> str:
|
30
|
+
"""
|
31
|
+
Prepend code to the body of a specified method within a class.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
original_code (str): The original Python code as a string.
|
35
|
+
class_name (str): The name of the class containing the method.
|
36
|
+
method_name (str): The name of the method to modify.
|
37
|
+
new_code (str): The code to prepend as a string.
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
str: The modified Python code as a string.
|
41
|
+
|
42
|
+
Raises:
|
43
|
+
ValueError: If the specified class or method is not found in the code.
|
44
|
+
"""
|
15
45
|
return _modify_method(original_code, class_name, method_name, new_code, PREPEND)
|
16
46
|
|
17
47
|
|
18
48
|
def append_code_to_method(
|
19
49
|
original_code: str, class_name: str, method_name: str, new_code: str
|
20
50
|
) -> str:
|
51
|
+
"""
|
52
|
+
Append code to the body of a specified method within a class.
|
53
|
+
|
54
|
+
Args:
|
55
|
+
original_code (str): The original Python code as a string.
|
56
|
+
class_name (str): The name of the class containing the method.
|
57
|
+
method_name (str): The name of the method to modify.
|
58
|
+
new_code (str): The code to append as a string.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
str: The modified Python code as a string.
|
62
|
+
|
63
|
+
Raises:
|
64
|
+
ValueError: If the specified class or method is not found in the code.
|
65
|
+
"""
|
21
66
|
return _modify_method(original_code, class_name, method_name, new_code, APPEND)
|
22
67
|
|
23
68
|
|
24
69
|
def _modify_method(
|
25
70
|
original_code: str, class_name: str, method_name: str, new_code: str, mode: int
|
26
71
|
) -> str:
|
72
|
+
"""
|
73
|
+
Modify the code body of a specified method within a class.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
original_code (str): The original Python code as a string.
|
77
|
+
class_name (str): The name of the class containing the method.
|
78
|
+
method_name (str): The name of the method to modify.
|
79
|
+
new_code (str): The code to add/replace as a string.
|
80
|
+
mode (int): The modification mode (PREPEND, APPEND, or REPLACE).
|
81
|
+
|
82
|
+
Returns:
|
83
|
+
str: The modified Python code as a string.
|
84
|
+
|
85
|
+
Raises:
|
86
|
+
ValueError: If the specified class or method is not found in the code.
|
87
|
+
"""
|
27
88
|
# Parse the original code into a module
|
28
89
|
module = cst.parse_module(original_code)
|
29
90
|
# Initialize the transformer with the necessary information
|
@@ -40,7 +101,20 @@ def _modify_method(
|
|
40
101
|
|
41
102
|
|
42
103
|
class _MethodModifier(cst.CSTTransformer):
|
104
|
+
"""
|
105
|
+
A LibCST transformer to modify the code body of a method within a ClassDef node.
|
106
|
+
"""
|
107
|
+
|
43
108
|
def __init__(self, class_name: str, method_name: str, new_code: str, mode: int):
|
109
|
+
"""
|
110
|
+
Initialize the transformer.
|
111
|
+
|
112
|
+
Args:
|
113
|
+
class_name (str): The name of the target class.
|
114
|
+
method_name (str): The name of the target method.
|
115
|
+
new_code (str): The new code body as a string.
|
116
|
+
mode (int): The modification mode (PREPEND, APPEND, or REPLACE).
|
117
|
+
"""
|
44
118
|
self.class_name = class_name
|
45
119
|
self.method_name = method_name
|
46
120
|
# Use parse_module to handle multiple statements
|
@@ -52,6 +126,9 @@ class _MethodModifier(cst.CSTTransformer):
|
|
52
126
|
def leave_ClassDef(
|
53
127
|
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
|
54
128
|
) -> cst.ClassDef:
|
129
|
+
"""
|
130
|
+
Called when leaving a ClassDef node. Modifies the body of the target method.
|
131
|
+
"""
|
55
132
|
# Check if the class matches the target class
|
56
133
|
if original_node.name.value == self.class_name:
|
57
134
|
self.class_found = True
|
@@ -1,4 +1,14 @@
|
|
1
1
|
def prepend_code_to_module(original_code: str, new_code: str) -> str:
|
2
|
+
"""
|
3
|
+
Prepend code to a module after the last import statement.
|
4
|
+
|
5
|
+
Args:
|
6
|
+
original_code (str): The original Python code as a string.
|
7
|
+
new_code (str): The code to prepend as a string.
|
8
|
+
|
9
|
+
Returns:
|
10
|
+
str: The modified Python code as a string.
|
11
|
+
"""
|
2
12
|
lines = original_code.splitlines()
|
3
13
|
last_import_index = -1
|
4
14
|
for i, line in enumerate(lines):
|
zrb/util/cron.py
CHANGED
@@ -2,7 +2,20 @@ import datetime
|
|
2
2
|
|
3
3
|
|
4
4
|
def parse_cron_field(field: str, min_value: int, max_value: int):
|
5
|
-
"""
|
5
|
+
"""
|
6
|
+
Parse a cron field string into a set of integer values.
|
7
|
+
|
8
|
+
Supports wildcards (*), ranges (e.g., 1-5), lists (e.g., 1,3,5),
|
9
|
+
and step values (e.g., */5, 1-10/2).
|
10
|
+
|
11
|
+
Args:
|
12
|
+
field (str): The cron field string (e.g., "*", "1-5", "*/10").
|
13
|
+
min_value (int): The minimum allowed value for the field.
|
14
|
+
max_value (int): The maximum allowed value for the field.
|
15
|
+
|
16
|
+
Returns:
|
17
|
+
set[int]: A set of integer values represented by the cron field.
|
18
|
+
"""
|
6
19
|
values = set()
|
7
20
|
if field == "*":
|
8
21
|
return set(range(min_value, max_value + 1))
|
@@ -34,7 +47,16 @@ def parse_cron_field(field: str, min_value: int, max_value: int):
|
|
34
47
|
|
35
48
|
|
36
49
|
def handle_special_cron_patterns(pattern: str, dt: datetime.datetime):
|
37
|
-
"""
|
50
|
+
"""
|
51
|
+
Check if a datetime object matches a special cron pattern.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
pattern (str): The special cron pattern (e.g., "@yearly", "@monthly").
|
55
|
+
dt (datetime.datetime): The datetime object to check.
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
bool: True if the datetime matches the pattern, False otherwise.
|
59
|
+
"""
|
38
60
|
if pattern == "@yearly" or pattern == "@annually":
|
39
61
|
return dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0
|
40
62
|
elif pattern == "@monthly":
|
@@ -53,7 +75,19 @@ def handle_special_cron_patterns(pattern: str, dt: datetime.datetime):
|
|
53
75
|
|
54
76
|
|
55
77
|
def match_cron(cron_pattern: str, dt: datetime.datetime):
|
56
|
-
"""
|
78
|
+
"""
|
79
|
+
Check if a datetime object matches a cron pattern.
|
80
|
+
|
81
|
+
Supports standard cron format (minute hour day month day_of_week)
|
82
|
+
and special patterns (e.g., "@yearly", "@monthly").
|
83
|
+
|
84
|
+
Args:
|
85
|
+
cron_pattern (str): The cron pattern string.
|
86
|
+
dt (datetime.datetime): The datetime object to check.
|
87
|
+
|
88
|
+
Returns:
|
89
|
+
bool: True if the datetime matches the cron pattern, False otherwise.
|
90
|
+
"""
|
57
91
|
# Handle special cron patterns
|
58
92
|
if cron_pattern.startswith("@"):
|
59
93
|
return handle_special_cron_patterns(cron_pattern, dt)
|
zrb/util/file.py
CHANGED
@@ -3,6 +3,15 @@ import re
|
|
3
3
|
|
4
4
|
|
5
5
|
def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
|
6
|
+
"""Reads a file and optionally replaces content based on a map.
|
7
|
+
|
8
|
+
Args:
|
9
|
+
file_path: The path to the file.
|
10
|
+
replace_map: A dictionary of strings to replace.
|
11
|
+
|
12
|
+
Returns:
|
13
|
+
The content of the file with replacements applied.
|
14
|
+
"""
|
6
15
|
with open(
|
7
16
|
os.path.abspath(os.path.expanduser(file_path)), "r", encoding="utf-8"
|
8
17
|
) as f:
|
@@ -15,6 +24,15 @@ def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
|
|
15
24
|
def read_file_with_line_numbers(
|
16
25
|
file_path: str, replace_map: dict[str, str] = {}
|
17
26
|
) -> str:
|
27
|
+
"""Reads a file and returns content with line numbers.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
file_path: The path to the file.
|
31
|
+
replace_map: A dictionary of strings to replace.
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
The content of the file with line numbers and replacements applied.
|
35
|
+
"""
|
18
36
|
content = read_file(file_path, replace_map)
|
19
37
|
lines = content.splitlines()
|
20
38
|
numbered_lines = [f"{i + 1} | {line}" for i, line in enumerate(lines)]
|
@@ -22,10 +40,24 @@ def read_file_with_line_numbers(
|
|
22
40
|
|
23
41
|
|
24
42
|
def read_dir(dir_path: str) -> list[str]:
|
43
|
+
"""Reads a directory and returns a list of file names.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
dir_path: The path to the directory.
|
47
|
+
|
48
|
+
Returns:
|
49
|
+
A list of file names in the directory.
|
50
|
+
"""
|
25
51
|
return [f for f in os.listdir(os.path.abspath(os.path.expanduser(dir_path)))]
|
26
52
|
|
27
53
|
|
28
54
|
def write_file(file_path: str, content: str | list[str]):
|
55
|
+
"""Writes content to a file.
|
56
|
+
|
57
|
+
Args:
|
58
|
+
file_path: The path to the file.
|
59
|
+
content: The content to write, either a string or a list of strings.
|
60
|
+
"""
|
29
61
|
if isinstance(content, list):
|
30
62
|
content = "\n".join([line for line in content if line is not None])
|
31
63
|
dir_path = os.path.dirname(file_path)
|
zrb/util/git.py
CHANGED
@@ -19,6 +19,21 @@ async def get_diff(
|
|
19
19
|
current_commit: str,
|
20
20
|
print_method: Callable[..., Any] = print,
|
21
21
|
) -> DiffResult:
|
22
|
+
"""
|
23
|
+
Get the difference between two commits in a Git repository.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
repo_dir (str): The path to the Git repository.
|
27
|
+
source_commit (str): The source commit hash or reference.
|
28
|
+
current_commit (str): The current commit hash or reference.
|
29
|
+
print_method (Callable[..., Any]): Method to print command output.
|
30
|
+
|
31
|
+
Returns:
|
32
|
+
DiffResult: An object containing lists of created, removed, and updated files.
|
33
|
+
|
34
|
+
Raises:
|
35
|
+
Exception: If the git command returns a non-zero exit code.
|
36
|
+
"""
|
22
37
|
cmd_result, exit_code = await run_command(
|
23
38
|
cmd=["git", "diff", source_commit, current_commit],
|
24
39
|
cwd=repo_dir,
|
@@ -55,6 +70,18 @@ async def get_diff(
|
|
55
70
|
|
56
71
|
|
57
72
|
async def get_repo_dir(print_method: Callable[..., Any] = print) -> str:
|
73
|
+
"""
|
74
|
+
Get the top-level directory of the Git repository.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
print_method (Callable[..., Any]): Method to print command output.
|
78
|
+
|
79
|
+
Returns:
|
80
|
+
str: The absolute path to the repository's top-level directory.
|
81
|
+
|
82
|
+
Raises:
|
83
|
+
Exception: If the git command returns a non-zero exit code.
|
84
|
+
"""
|
58
85
|
cmd_result, exit_code = await run_command(
|
59
86
|
cmd=["git", "rev-parse", "--show-toplevel"],
|
60
87
|
print_method=print_method,
|
@@ -67,6 +94,19 @@ async def get_repo_dir(print_method: Callable[..., Any] = print) -> str:
|
|
67
94
|
async def get_current_branch(
|
68
95
|
repo_dir: str, print_method: Callable[..., Any] = print
|
69
96
|
) -> str:
|
97
|
+
"""
|
98
|
+
Get the current branch name of the Git repository.
|
99
|
+
|
100
|
+
Args:
|
101
|
+
repo_dir (str): The path to the Git repository.
|
102
|
+
print_method (Callable[..., Any]): Method to print command output.
|
103
|
+
|
104
|
+
Returns:
|
105
|
+
str: The name of the current branch.
|
106
|
+
|
107
|
+
Raises:
|
108
|
+
Exception: If the git command returns a non-zero exit code.
|
109
|
+
"""
|
70
110
|
cmd_result, exit_code = await run_command(
|
71
111
|
cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
72
112
|
cwd=repo_dir,
|
@@ -80,6 +120,19 @@ async def get_current_branch(
|
|
80
120
|
async def get_branches(
|
81
121
|
repo_dir: str, print_method: Callable[..., Any] = print
|
82
122
|
) -> list[str]:
|
123
|
+
"""
|
124
|
+
Get a list of all branches in the Git repository.
|
125
|
+
|
126
|
+
Args:
|
127
|
+
repo_dir (str): The path to the Git repository.
|
128
|
+
print_method (Callable[..., Any]): Method to print command output.
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
list[str]: A list of branch names.
|
132
|
+
|
133
|
+
Raises:
|
134
|
+
Exception: If the git command returns a non-zero exit code.
|
135
|
+
"""
|
83
136
|
cmd_result, exit_code = await run_command(
|
84
137
|
cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
85
138
|
cwd=repo_dir,
|
@@ -95,6 +148,20 @@ async def get_branches(
|
|
95
148
|
async def delete_branch(
|
96
149
|
repo_dir: str, branch_name: str, print_method: Callable[..., Any] = print
|
97
150
|
) -> str:
|
151
|
+
"""
|
152
|
+
Delete a branch in the Git repository.
|
153
|
+
|
154
|
+
Args:
|
155
|
+
repo_dir (str): The path to the Git repository.
|
156
|
+
branch_name (str): The name of the branch to delete.
|
157
|
+
print_method (Callable[..., Any]): Method to print command output.
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
str: The output of the git command.
|
161
|
+
|
162
|
+
Raises:
|
163
|
+
Exception: If the git command returns a non-zero exit code.
|
164
|
+
"""
|
98
165
|
cmd_result, exit_code = await run_command(
|
99
166
|
cmd=["git", "branch", "-D", branch_name],
|
100
167
|
cwd=repo_dir,
|
@@ -106,6 +173,16 @@ async def delete_branch(
|
|
106
173
|
|
107
174
|
|
108
175
|
async def add(repo_dir: str, print_method: Callable[..., Any] = print):
|
176
|
+
"""
|
177
|
+
Add all changes to the Git staging area.
|
178
|
+
|
179
|
+
Args:
|
180
|
+
repo_dir (str): The path to the Git repository.
|
181
|
+
print_method (Callable[..., Any]): Method to print command output.
|
182
|
+
|
183
|
+
Raises:
|
184
|
+
Exception: If the git command returns a non-zero exit code.
|
185
|
+
"""
|
109
186
|
_, exit_code = await run_command(
|
110
187
|
cmd=["git", "add", ".", "-A"],
|
111
188
|
cwd=repo_dir,
|
@@ -118,6 +195,18 @@ async def add(repo_dir: str, print_method: Callable[..., Any] = print):
|
|
118
195
|
async def commit(
|
119
196
|
repo_dir: str, message: str, print_method: Callable[..., Any] = print
|
120
197
|
) -> str:
|
198
|
+
"""
|
199
|
+
Commit changes in the Git repository.
|
200
|
+
|
201
|
+
Args:
|
202
|
+
repo_dir (str): The path to the Git repository.
|
203
|
+
message (str): The commit message.
|
204
|
+
print_method (Callable[..., Any]): Method to print command output.
|
205
|
+
|
206
|
+
Raises:
|
207
|
+
Exception: If the git command returns a non-zero exit code, unless it's
|
208
|
+
the "nothing to commit" message.
|
209
|
+
"""
|
121
210
|
cmd_result, exit_code = await run_command(
|
122
211
|
cmd=["git", "commit", "-m", message],
|
123
212
|
cwd=repo_dir,
|
@@ -135,6 +224,18 @@ async def commit(
|
|
135
224
|
async def pull(
|
136
225
|
repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print
|
137
226
|
) -> str:
|
227
|
+
"""
|
228
|
+
Pull changes from a remote repository and branch.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
repo_dir (str): The path to the Git repository.
|
232
|
+
remote (str): The name of the remote.
|
233
|
+
branch (str): The name of the branch.
|
234
|
+
print_method (Callable[..., Any]): Method to print command output.
|
235
|
+
|
236
|
+
Raises:
|
237
|
+
Exception: If the git command returns a non-zero exit code.
|
238
|
+
"""
|
138
239
|
_, exit_code = await run_command(
|
139
240
|
cmd=["git", "pull", remote, branch],
|
140
241
|
cwd=repo_dir,
|
@@ -147,6 +248,18 @@ async def pull(
|
|
147
248
|
async def push(
|
148
249
|
repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print
|
149
250
|
) -> str:
|
251
|
+
"""
|
252
|
+
Push changes to a remote repository and branch.
|
253
|
+
|
254
|
+
Args:
|
255
|
+
repo_dir (str): The path to the Git repository.
|
256
|
+
remote (str): The name of the remote.
|
257
|
+
branch (str): The name of the branch.
|
258
|
+
print_method (Callable[..., Any]): Method to print command output.
|
259
|
+
|
260
|
+
Raises:
|
261
|
+
Exception: If the git command returns a non-zero exit code.
|
262
|
+
"""
|
150
263
|
_, exit_code = await run_command(
|
151
264
|
cmd=["git", "push", "-u", remote, branch],
|
152
265
|
cwd=repo_dir,
|
zrb/util/git_subtree.py
CHANGED
@@ -19,6 +19,15 @@ class SubTreeConfig(BaseModel):
|
|
19
19
|
|
20
20
|
|
21
21
|
def load_config(repo_dir: str) -> SubTreeConfig:
|
22
|
+
"""
|
23
|
+
Load the subtree configuration from subtrees.json.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
repo_dir (str): The path to the Git repository.
|
27
|
+
|
28
|
+
Returns:
|
29
|
+
SubTreeConfig: The loaded subtree configuration.
|
30
|
+
"""
|
22
31
|
file_path = os.path.join(repo_dir, "subtrees.json")
|
23
32
|
if not os.path.exists(file_path):
|
24
33
|
return SubTreeConfig(data={})
|
@@ -26,6 +35,13 @@ def load_config(repo_dir: str) -> SubTreeConfig:
|
|
26
35
|
|
27
36
|
|
28
37
|
def save_config(repo_dir: str, config: SubTreeConfig):
|
38
|
+
"""
|
39
|
+
Save the subtree configuration to subtrees.json.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
repo_dir (str): The path to the Git repository.
|
43
|
+
config (SubTreeConfig): The subtree configuration to save.
|
44
|
+
"""
|
29
45
|
file_path = os.path.join(repo_dir, "subtrees.json")
|
30
46
|
write_file(file_path, config.model_dump_json(indent=2))
|
31
47
|
|
@@ -38,6 +54,22 @@ async def add_subtree(
|
|
38
54
|
prefix: str,
|
39
55
|
print_method: Callable[..., Any] = print,
|
40
56
|
):
|
57
|
+
"""
|
58
|
+
Add a Git subtree to the repository.
|
59
|
+
|
60
|
+
Args:
|
61
|
+
repo_dir (str): The path to the Git repository.
|
62
|
+
name (str): The name for the subtree configuration.
|
63
|
+
repo_url (str): The URL of the subtree repository.
|
64
|
+
branch (str): The branch of the subtree repository.
|
65
|
+
prefix (str): The local path where the subtree will be added.
|
66
|
+
print_method (Callable[..., Any]): Method to print command output.
|
67
|
+
|
68
|
+
Raises:
|
69
|
+
ValueError: If the prefix directory already exists or subtree config
|
70
|
+
name already exists.
|
71
|
+
Exception: If the git command returns a non-zero exit code.
|
72
|
+
"""
|
41
73
|
config = load_config()
|
42
74
|
if os.path.isdir(prefix):
|
43
75
|
raise ValueError(f"Directory exists: {prefix}")
|
@@ -71,6 +103,19 @@ async def pull_subtree(
|
|
71
103
|
branch: str,
|
72
104
|
print_method: Callable[..., Any] = print,
|
73
105
|
):
|
106
|
+
"""
|
107
|
+
Pull changes from a Git subtree.
|
108
|
+
|
109
|
+
Args:
|
110
|
+
repo_dir (str): The path to the Git repository.
|
111
|
+
prefix (str): The local path of the subtree.
|
112
|
+
repo_url (str): The URL of the subtree repository.
|
113
|
+
branch (str): The branch of the subtree repository.
|
114
|
+
print_method (Callable[..., Any]): Method to print command output.
|
115
|
+
|
116
|
+
Raises:
|
117
|
+
Exception: If the git command returns a non-zero exit code.
|
118
|
+
"""
|
74
119
|
_, exit_code = await run_command(
|
75
120
|
cmd=[
|
76
121
|
"git",
|
@@ -95,6 +140,19 @@ async def push_subtree(
|
|
95
140
|
branch: str,
|
96
141
|
print_method: Callable[..., Any] = print,
|
97
142
|
):
|
143
|
+
"""
|
144
|
+
Push changes to a Git subtree.
|
145
|
+
|
146
|
+
Args:
|
147
|
+
repo_dir (str): The path to the Git repository.
|
148
|
+
prefix (str): The local path of the subtree.
|
149
|
+
repo_url (str): The URL of the subtree repository.
|
150
|
+
branch (str): The branch of the subtree repository.
|
151
|
+
print_method (Callable[..., Any]): Method to print command output.
|
152
|
+
|
153
|
+
Raises:
|
154
|
+
Exception: If the git command returns a non-zero exit code.
|
155
|
+
"""
|
98
156
|
_, exit_code = await run_command(
|
99
157
|
cmd=[
|
100
158
|
"git",
|
zrb/util/group.py
CHANGED
@@ -9,6 +9,21 @@ class NodeNotFoundError(ValueError):
|
|
9
9
|
def extract_node_from_args(
|
10
10
|
root_group: AnyGroup, args: list[str], web_only: bool = False
|
11
11
|
) -> tuple[AnyGroup | AnyTask, list[str], list[str]]:
|
12
|
+
"""
|
13
|
+
Extract a node (Group or Task) from a list of command-line arguments.
|
14
|
+
|
15
|
+
Args:
|
16
|
+
root_group (AnyGroup): The root group to start the search from.
|
17
|
+
args (list[str]): The list of command-line arguments.
|
18
|
+
web_only (bool): If True, only consider tasks that are not CLI-only.
|
19
|
+
|
20
|
+
Returns:
|
21
|
+
tuple[AnyGroup | AnyTask, list[str], list[str]]: A tuple containing the
|
22
|
+
extracted node, the path to the node, and any residual arguments.
|
23
|
+
|
24
|
+
Raises:
|
25
|
+
NodeNotFoundError: If no matching task or group is found for a given argument.
|
26
|
+
"""
|
12
27
|
node = root_group
|
13
28
|
node_path = []
|
14
29
|
residual_args = []
|
@@ -17,8 +32,12 @@ def extract_node_from_args(
|
|
17
32
|
if web_only and task is not None and task.cli_only:
|
18
33
|
task = None
|
19
34
|
group = node.get_group_by_alias(name)
|
20
|
-
|
21
|
-
|
35
|
+
# Only ignore empty groups if web_only is True
|
36
|
+
if (
|
37
|
+
group is not None
|
38
|
+
and web_only
|
39
|
+
and len(get_all_subtasks(group, web_only)) == 0
|
40
|
+
):
|
22
41
|
group = None
|
23
42
|
if task is None and group is None:
|
24
43
|
raise NodeNotFoundError(
|
@@ -40,8 +59,21 @@ def extract_node_from_args(
|
|
40
59
|
|
41
60
|
|
42
61
|
def get_node_path(group: AnyGroup, node: AnyGroup | AnyTask) -> list[str] | None:
|
62
|
+
"""
|
63
|
+
Get the path (aliases) to a specific node within a group hierarchy.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
group (AnyGroup): The group to search within.
|
67
|
+
node (AnyGroup | AnyTask): The target node.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
list[str] | None: A list of aliases representing the path to the node,
|
71
|
+
or None if the node is not found.
|
72
|
+
"""
|
43
73
|
if group is None:
|
44
74
|
return []
|
75
|
+
if group == node: # Handle the case where the target is the starting group
|
76
|
+
return [group.name]
|
45
77
|
if isinstance(node, AnyTask):
|
46
78
|
for alias, subtask in group.subtasks.items():
|
47
79
|
if subtask == node:
|
@@ -60,6 +92,16 @@ def get_node_path(group: AnyGroup, node: AnyGroup | AnyTask) -> list[str] | None
|
|
60
92
|
def get_non_empty_subgroups(
|
61
93
|
group: AnyGroup, web_only: bool = False
|
62
94
|
) -> dict[str, AnyGroup]:
|
95
|
+
"""
|
96
|
+
Get subgroups that contain at least one task.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
group (AnyGroup): The group to search within.
|
100
|
+
web_only (bool): If True, only consider tasks that are not CLI-only.
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
dict[str, AnyGroup]: A dictionary of subgroups that are not empty.
|
104
|
+
"""
|
63
105
|
return {
|
64
106
|
alias: subgroup
|
65
107
|
for alias, subgroup in group.subgroups.items()
|
@@ -68,6 +110,16 @@ def get_non_empty_subgroups(
|
|
68
110
|
|
69
111
|
|
70
112
|
def get_subtasks(group: AnyGroup, web_only: bool = False) -> dict[str, AnyTask]:
|
113
|
+
"""
|
114
|
+
Get the direct subtasks of a group.
|
115
|
+
|
116
|
+
Args:
|
117
|
+
group (AnyGroup): The group to search within.
|
118
|
+
web_only (bool): If True, only include tasks that are not CLI-only.
|
119
|
+
|
120
|
+
Returns:
|
121
|
+
dict[str, AnyTask]: A dictionary of subtasks.
|
122
|
+
"""
|
71
123
|
return {
|
72
124
|
alias: subtask
|
73
125
|
for alias, subtask in group.subtasks.items()
|
@@ -76,6 +128,16 @@ def get_subtasks(group: AnyGroup, web_only: bool = False) -> dict[str, AnyTask]:
|
|
76
128
|
|
77
129
|
|
78
130
|
def get_all_subtasks(group: AnyGroup, web_only: bool = False) -> list[AnyTask]:
|
131
|
+
"""
|
132
|
+
Get all subtasks (including nested ones) within a group hierarchy.
|
133
|
+
|
134
|
+
Args:
|
135
|
+
group (AnyGroup): The group to search within.
|
136
|
+
web_only (bool): If True, only include tasks that are not CLI-only.
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
list[AnyTask]: A list of all subtasks.
|
140
|
+
"""
|
79
141
|
subtasks = [
|
80
142
|
subtask
|
81
143
|
for subtask in group.subtasks.values()
|
zrb/util/load.py
CHANGED
@@ -11,6 +11,17 @@ pattern = re.compile("[^a-zA-Z0-9]")
|
|
11
11
|
|
12
12
|
@lru_cache
|
13
13
|
def load_file(script_path: str, sys_path_index: int = 0) -> Any | None:
|
14
|
+
"""
|
15
|
+
Load a Python module from a file path.
|
16
|
+
|
17
|
+
Args:
|
18
|
+
script_path (str): The path to the Python script.
|
19
|
+
sys_path_index (int): The index to insert the script directory into sys.path.
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
Any | None: The loaded module object, or None if the file does not
|
23
|
+
exist or cannot be loaded.
|
24
|
+
"""
|
14
25
|
if not os.path.isfile(script_path):
|
15
26
|
return None
|
16
27
|
module_name = pattern.sub("", script_path)
|
@@ -31,6 +42,15 @@ def load_file(script_path: str, sys_path_index: int = 0) -> Any | None:
|
|
31
42
|
|
32
43
|
|
33
44
|
def _get_new_python_path(dir_path: str) -> str:
|
45
|
+
"""
|
46
|
+
Helper function to update the PYTHONPATH environment variable.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
dir_path (str): The directory path to add to PYTHONPATH.
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
str: The new value for the PYTHONPATH environment variable.
|
53
|
+
"""
|
34
54
|
current_python_path = os.getenv("PYTHONPATH")
|
35
55
|
if current_python_path is None or current_python_path == "":
|
36
56
|
return dir_path
|
@@ -40,5 +60,14 @@ def _get_new_python_path(dir_path: str) -> str:
|
|
40
60
|
|
41
61
|
|
42
62
|
def load_module(module_name: str) -> Any:
|
63
|
+
"""
|
64
|
+
Load a Python module by its name.
|
65
|
+
|
66
|
+
Args:
|
67
|
+
module_name (str): The name of the module to load.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
Any: The loaded module object.
|
71
|
+
"""
|
43
72
|
module = importlib.import_module(module_name)
|
44
73
|
return module
|
zrb/util/run.py
CHANGED
@@ -4,6 +4,15 @@ from typing import Any
|
|
4
4
|
|
5
5
|
|
6
6
|
async def run_async(value: Any) -> Any:
|
7
|
+
"""
|
8
|
+
Run a value asynchronously, awaiting if it's awaitable or running in a thread if not.
|
9
|
+
|
10
|
+
Args:
|
11
|
+
value (Any): The value to run. Can be awaitable or not.
|
12
|
+
|
13
|
+
Returns:
|
14
|
+
Any: The result of the awaited value or the value itself if not awaitable.
|
15
|
+
"""
|
7
16
|
if inspect.isawaitable(value):
|
8
17
|
return await value
|
9
18
|
return await asyncio.to_thread(lambda: value)
|