zrb 1.0.0a5__py3-none-any.whl → 1.0.0a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zrb/task/base_task.py CHANGED
@@ -38,6 +38,7 @@ class BaseTask(AnyTask):
38
38
  monitor_readiness: bool = False,
39
39
  upstream: list[AnyTask] | AnyTask | None = None,
40
40
  fallback: list[AnyTask] | AnyTask | None = None,
41
+ successor: list[AnyTask] | AnyTask | None = None,
41
42
  ):
42
43
  self._name = name
43
44
  self._color = color
@@ -50,6 +51,7 @@ class BaseTask(AnyTask):
50
51
  self._retry_period = retry_period
51
52
  self._upstreams = upstream
52
53
  self._fallbacks = fallback
54
+ self._successors = successor
53
55
  self._readiness_checks = readiness_check
54
56
  self._readiness_check_delay = readiness_check_delay
55
57
  self._readiness_check_period = readiness_check_period
@@ -65,17 +67,17 @@ class BaseTask(AnyTask):
65
67
  def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
66
68
  try:
67
69
  if isinstance(other, AnyTask):
68
- other.append_upstreams(self)
70
+ other.append_upstream(self)
69
71
  elif isinstance(other, list):
70
72
  for task in other:
71
- task.append_upstreams(self)
73
+ task.append_upstream(self)
72
74
  return other
73
75
  except Exception as e:
74
76
  raise ValueError(f"Invalid operation {self} >> {other}: {e}")
75
77
 
76
78
  def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
77
79
  try:
78
- self.append_upstreams(other)
80
+ self.append_upstream(other)
79
81
  return self
80
82
  except Exception as e:
81
83
  raise ValueError(f"Invalid operation {self} << {other}: {e}")
@@ -142,6 +144,44 @@ class BaseTask(AnyTask):
142
144
  return [self._fallbacks]
143
145
  return self._fallbacks
144
146
 
147
def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
    """Register one or more fallback tasks, skipping those already present."""
    if isinstance(fallbacks, AnyTask):
        fallbacks = [fallbacks]
    for task in fallbacks:
        self.__append_fallback(task)
151
+
152
def __append_fallback(self, fallback: AnyTask):
    """Add ``fallback`` to this task's fallbacks unless it is already there.

    ``self._fallbacks`` may still be exactly what the constructor received:
    ``None``, a single task, or the caller's own list. A fresh list is built
    on every call. This means a list object shared with the caller, or with
    other tasks constructed from the same list, is never mutated in place.
    """
    if self._fallbacks is None:
        updated = []
    elif isinstance(self._fallbacks, AnyTask):
        updated = [self._fallbacks]
    else:
        updated = list(self._fallbacks)
    if fallback not in updated:
        updated.append(fallback)
    self._fallbacks = updated
161
+
162
@property
def successors(self) -> list[AnyTask]:
    """Tasks to run after this one succeeds, always exposed as a list."""
    raw = self._successors
    if raw is None:
        return []
    if isinstance(raw, AnyTask):
        return [raw]
    return raw
169
+
170
def append_successor(self, successors: AnyTask | list[AnyTask]):
    """Register one or more tasks to run after this one succeeds."""
    if isinstance(successors, AnyTask):
        successors = [successors]
    for task in successors:
        self.__append_successor(task)
174
+
175
def __append_successor(self, successor: AnyTask):
    """Add ``successor`` to this task's successors unless it is already there.

    ``self._successors`` may still be exactly what the constructor received:
    ``None``, a single task, or the caller's own list. A fresh list is built
    on every call. This means a list object shared with the caller, or with
    other tasks constructed from the same list, is never mutated in place.
    """
    if self._successors is None:
        updated = []
    elif isinstance(self._successors, AnyTask):
        updated = [self._successors]
    else:
        updated = list(self._successors)
    if successor not in updated:
        updated.append(successor)
    self._successors = updated
184
+
145
185
  @property
146
186
  def readiness_checks(self) -> list[AnyTask]:
147
187
  if self._readiness_checks is None:
@@ -150,6 +190,25 @@ class BaseTask(AnyTask):
150
190
  return [self._readiness_checks]
151
191
  return self._readiness_checks
152
192
 
193
def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
    """Register one or more readiness-check tasks, ignoring duplicates."""
    if isinstance(readiness_checks, AnyTask):
        readiness_checks = [readiness_checks]
    for check in readiness_checks:
        self.__append_readiness_check(check)
201
+
202
def __append_readiness_check(self, readiness_check: AnyTask):
    """Add ``readiness_check`` unless it is already registered.

    ``self._readiness_checks`` may still be exactly what the constructor
    received: ``None``, a single task, or the caller's own list. A fresh list
    is built on every call, so a list object shared with the caller or with
    other tasks is never mutated in place.
    """
    if self._readiness_checks is None:
        updated = []
    elif isinstance(self._readiness_checks, AnyTask):
        updated = [self._readiness_checks]
    else:
        updated = list(self._readiness_checks)
    if readiness_check not in updated:
        updated.append(readiness_check)
    self._readiness_checks = updated
211
+
153
212
  @property
154
213
  def upstreams(self) -> list[AnyTask]:
155
214
  if self._upstreams is None:
@@ -158,7 +217,7 @@ class BaseTask(AnyTask):
158
217
  return [self._upstreams]
159
218
  return self._upstreams
160
219
 
161
- def append_upstreams(self, upstreams: AnyTask | list[AnyTask]):
220
+ def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
162
221
  upstream_list = [upstreams] if isinstance(upstreams, AnyTask) else upstreams
163
222
  for upstream in upstream_list:
164
223
  self.__append_upstream(upstream)
@@ -374,6 +433,7 @@ class BaseTask(AnyTask):
374
433
  # Put result on xcom
375
434
  task_xcom: Xcom = ctx.xcom.get(self.name)
376
435
  task_xcom.push(result)
436
+ await run_async(self.__exec_successors(session))
377
437
  return result
378
438
  except (asyncio.CancelledError, KeyboardInterrupt):
379
439
  ctx.log_info("Marked as failed")
@@ -390,6 +450,13 @@ class BaseTask(AnyTask):
390
450
  await run_async(self.__exec_fallbacks(session))
391
451
  raise e
392
452
 
453
async def __exec_successors(self, session: AnySession) -> Any:
    """Run the chain of every successor concurrently and wait for all of them."""
    await asyncio.gather(
        *(run_async(task.exec_chain(session)) for task in self.successors)
    )
459
+
393
460
  async def __exec_fallbacks(self, session: AnySession) -> Any:
394
461
  fallbacks: list[AnyTask] = self.fallbacks
395
462
  fallback_coros = [
zrb/task/cmd_task.py CHANGED
@@ -5,13 +5,14 @@ import sys
5
5
  from zrb.attr.type import BoolAttr, IntAttr, StrAttr
6
6
  from zrb.cmd.cmd_result import CmdResult
7
7
  from zrb.cmd.cmd_val import AnyCmdVal, CmdVal, SingleCmdVal
8
- from zrb.config import DEFAULT_SHELL
8
+ from zrb.config import DEFAULT_SHELL, WARN_UNRECOMMENDED_COMMAND
9
9
  from zrb.context.any_context import AnyContext
10
10
  from zrb.env.any_env import AnyEnv
11
11
  from zrb.input.any_input import AnyInput
12
12
  from zrb.task.any_task import AnyTask
13
13
  from zrb.task.base_task import BaseTask
14
14
  from zrb.util.attr import get_int_attr, get_str_attr
15
+ from zrb.util.cmd.command import check_unrecommended_commands
15
16
  from zrb.util.cmd.remote import get_remote_cmd_script
16
17
 
17
18
 
@@ -32,6 +33,7 @@ class CmdTask(BaseTask):
32
33
  remote_host: StrAttr | None = None,
33
34
  render_remote_host: bool = True,
34
35
  remote_port: IntAttr | None = None,
36
+ render_remote_port: bool = True,
35
37
  remote_user: StrAttr | None = None,
36
38
  render_remote_user: bool = True,
37
39
  remote_password: StrAttr | None = None,
@@ -42,6 +44,7 @@ class CmdTask(BaseTask):
42
44
  render_cmd: bool = True,
43
45
  cwd: str | None = None,
44
46
  render_cwd: bool = True,
47
+ warn_unrecommended_command: bool | None = None,
45
48
  max_output_line: int = 1000,
46
49
  max_error_line: int = 1000,
47
50
  execute_condition: BoolAttr = True,
@@ -83,6 +86,7 @@ class CmdTask(BaseTask):
83
86
  self._remote_host = remote_host
84
87
  self._render_remote_host = render_remote_host
85
88
  self._remote_port = remote_port
89
+ self._render_remote_port = render_remote_port
86
90
  self._remote_user = remote_user
87
91
  self._render_remote_user = render_remote_user
88
92
  self._remote_password = remote_password
@@ -95,6 +99,7 @@ class CmdTask(BaseTask):
95
99
  self._render_cwd = render_cwd
96
100
  self._max_output_line = max_output_line
97
101
  self._max_error_line = max_error_line
102
+ self._should_warn_unrecommended_command = warn_unrecommended_command
98
103
 
99
104
  async def _exec_action(self, ctx: AnyContext) -> CmdResult:
100
105
  """Turn _cmd attribute into subprocess.Popen and execute it as task's action.
@@ -105,7 +110,6 @@ class CmdTask(BaseTask):
105
110
  Returns:
106
111
  Any: The result of the action execution.
107
112
  """
108
- ctx.log_info("Running script")
109
113
  cmd_script = self._get_cmd_script(ctx)
110
114
  ctx.log_debug(f"Script: {self.__get_multiline_repr(cmd_script)}")
111
115
  shell = self._get_shell(ctx)
@@ -116,7 +120,10 @@ class CmdTask(BaseTask):
116
120
  env_map = self.__get_env_map(ctx)
117
121
  ctx.log_debug(f"Environment map: {env_map}")
118
122
  cmd_process = None
123
+ if self._get_should_warn_unrecommended_commands():
124
+ self._check_unrecommended_commands(ctx, shell, cmd_script)
119
125
  try:
126
+ ctx.log_info("Running script")
120
127
  cmd_process = await asyncio.create_subprocess_exec(
121
128
  shell,
122
129
  shell_flag,
@@ -150,6 +157,21 @@ class CmdTask(BaseTask):
150
157
  if cmd_process is not None and cmd_process.returncode is None:
151
158
  cmd_process.terminate()
152
159
 
160
def _get_should_warn_unrecommended_commands(self):
    """Return the per-task warning flag, or the global default when unset."""
    explicit = self._should_warn_unrecommended_command
    return WARN_UNRECOMMENDED_COMMAND if explicit is None else explicit
164
+
165
def _check_unrecommended_commands(
    self, ctx: AnyContext, shell: str, cmd_script: str
):
    """Log one warning per unrecommended construct in a bash/zsh script.

    Shells whose path does not end in "bash" or "zsh" are not checked.
    """
    if not shell.endswith(("bash", "zsh")):
        return
    findings = check_unrecommended_commands(cmd_script)
    if not findings:
        return
    ctx.log_warning("The script contains unrecommended commands")
    for command, reason in findings.items():
        ctx.log_warning(f"- {command}: {reason}")
174
+
153
175
  def __get_env_map(self, ctx: AnyContext) -> dict[str, str]:
154
176
  envs = {key: val for key, val in ctx.env.items()}
155
177
  envs["_ZRB_SSH_PASSWORD"] = self._get_remote_password(ctx)
@@ -195,7 +217,9 @@ class CmdTask(BaseTask):
195
217
  )
196
218
 
197
219
  def _get_remote_port(self, ctx: AnyContext) -> int:
198
- return get_int_attr(ctx, self._remote_port, 22, auto_render=True)
220
+ return get_int_attr(
221
+ ctx, self._remote_port, 22, auto_render=self._render_remote_port
222
+ )
199
223
 
200
224
  def _get_remote_user(self, ctx: AnyContext) -> str:
201
225
  return get_str_attr(
zrb/task/llm_task.py CHANGED
@@ -108,8 +108,13 @@ class LLMTask(BaseTask):
108
108
  )
109
109
 
110
110
  async def _exec_action(self, ctx: AnyContext) -> Any:
111
- from litellm import acompletion
111
+ from litellm import acompletion, supports_function_calling
112
112
 
113
+ model = self._get_model(ctx)
114
+ try:
115
+ allow_function_call = supports_function_calling(model=model)
116
+ except Exception:
117
+ allow_function_call = False
113
118
  model_kwargs = self._get_model_kwargs(ctx)
114
119
  ctx.log_debug("MODEL KWARGS", model_kwargs)
115
120
  system_prompt = self._get_system_prompt(ctx)
@@ -121,27 +126,28 @@ class LLMTask(BaseTask):
121
126
  messages = history + [user_message]
122
127
  available_tools = self._get_tools(ctx)
123
128
  available_tools["scratchpad"] = scratchpad
124
- tool_schema = [
125
- callable_to_tool_schema(tool, name)
126
- for name, tool in available_tools.items()
127
- ]
128
- for additional_tool in self._additional_tools:
129
- fn = additional_tool.fn
130
- tool_name = additional_tool.name or fn.__name__
131
- tool_description = additional_tool.description
132
- available_tools[tool_name] = additional_tool.fn
133
- tool_schema.append(
134
- callable_to_tool_schema(
135
- fn, name=tool_name, description=tool_description
129
+ if allow_function_call:
130
+ tool_schema = [
131
+ callable_to_tool_schema(tool, name)
132
+ for name, tool in available_tools.items()
133
+ ]
134
+ for additional_tool in self._additional_tools:
135
+ fn = additional_tool.fn
136
+ tool_name = additional_tool.name or fn.__name__
137
+ tool_description = additional_tool.description
138
+ available_tools[tool_name] = additional_tool.fn
139
+ tool_schema.append(
140
+ callable_to_tool_schema(
141
+ fn, name=tool_name, description=tool_description
142
+ )
136
143
  )
137
- )
138
- ctx.log_debug("TOOL SCHEMA", tool_schema)
144
+ model_kwargs["tools"] = tool_schema
145
+ ctx.log_debug("TOOL SCHEMA", tool_schema)
139
146
  history_file = self._get_history_file(ctx)
140
147
  while True:
141
148
  response = await acompletion(
142
- model=self._get_model(ctx),
149
+ model=model,
143
150
  messages=[{"role": "system", "content": system_prompt}] + messages,
144
- tools=tool_schema,
145
151
  **model_kwargs,
146
152
  )
147
153
  response_message = response.choices[0].message
@@ -189,7 +195,7 @@ class LLMTask(BaseTask):
189
195
  def _get_model_kwargs(self, ctx: AnyContext) -> dict[str, Callable]:
190
196
  if callable(self._model_kwargs):
191
197
  return self._model_kwargs(ctx)
192
- return self._model_kwargs
198
+ return {**self._model_kwargs}
193
199
 
194
200
  def _get_tools(self, ctx: AnyContext) -> dict[str, Callable]:
195
201
  if callable(self._tools):
zrb/task/rsync_task.py CHANGED
@@ -24,13 +24,13 @@ class RsyncTask(CmdTask):
24
24
  remote_host: StrAttr | None = None,
25
25
  auto_render_remote_host: bool = True,
26
26
  remote_port: IntAttr | None = None,
27
- auto_render_remote_port: bool = True,
27
+ render_remote_port: bool = True,
28
28
  remote_user: StrAttr | None = None,
29
- auto_render_remote_user: bool = True,
29
+ render_remote_user: bool = True,
30
30
  remote_password: StrAttr | None = None,
31
- auto_render_remote_password: bool = True,
31
+ render_remote_password: bool = True,
32
32
  remote_ssh_key: StrAttr | None = None,
33
- auto_render_remote_ssh_key: bool = True,
33
+ render_remote_ssh_key: bool = True,
34
34
  remote_source_path: StrAttr | None = None,
35
35
  render_remote_source_path: bool = True,
36
36
  remote_destination_path: StrAttr | None = None,
@@ -63,13 +63,13 @@ class RsyncTask(CmdTask):
63
63
  remote_host=remote_host,
64
64
  render_remote_host=auto_render_remote_host,
65
65
  remote_port=remote_port,
66
- auto_render_remote_port=auto_render_remote_port,
66
+ auto_render_remote_port=render_remote_port,
67
67
  remote_user=remote_user,
68
- render_remote_user=auto_render_remote_user,
68
+ render_remote_user=render_remote_user,
69
69
  remote_password=remote_password,
70
- render_remote_password=auto_render_remote_password,
70
+ render_remote_password=render_remote_password,
71
71
  remote_ssh_key=remote_ssh_key,
72
- render_remote_ssh_key=auto_render_remote_ssh_key,
72
+ render_remote_ssh_key=render_remote_ssh_key,
73
73
  cwd=cwd,
74
74
  render_cwd=auto_render_cwd,
75
75
  max_output_line=max_output_line,
@@ -0,0 +1,33 @@
1
+ import re
2
+
3
+
4
def check_unrecommended_commands(cmd_script: str) -> dict[str, str]:
    """Scan a shell script for constructs that are discouraged for portability.

    Word-like commands (``echo``, ``source``, ``test``, ...) are matched only
    as standalone tokens. They must not be glued to letters, digits, ``_``,
    ``.``, ``/`` or ``-``, so words such as "resource", "evaluate",
    "columns", "pytest" or "echo.sh" are not reported. Non-word constructs
    (e.g. ``<(``) are matched as plain substrings. A second table of regular
    expressions catches risky option usage.

    Args:
        cmd_script: The script text to inspect.

    Returns:
        Mapping from the offending command key (or regex pattern) to the
        reason it is discouraged; empty when nothing was found.
    """
    banned_commands = {
        "<(": "Process substitution isn't POSIX compliant and causes trouble",
        "column": "Command isn't included in Ubuntu packages and is not POSIX compliant",
        "echo": "echo isn't consistent across OS; use printf instead",
        "eval": "Avoid eval as it can accidentally execute arbitrary strings",
        "realpath": "Not available by default on OSX",
        "source": "Not POSIX compliant; use '.' instead",
        " test": "Use '[' instead for consistency",
        "which": "Command in not POSIX compliant, use command -v",
    }
    banned_commands_regex = {
        r"grep.* -y": "grep -y does not work on Alpine; use grep -i",
        r"grep.* -P": "grep -P is not valid on OSX",
        r"grep[^|]+--\w{2,}": "grep long commands do not work on Alpine",
        r'readlink.+-.*f.+["$]': "readlink -f behaves differently on OSX",
        r"sort.*-V": "sort -V is not supported everywhere",
        r"sort.*--sort-versions": "sort --sort-version is not supported everywhere",
        r"\bls ": "Avoid using ls; use shell globs or find instead",
    }
    violations: dict[str, str] = {}
    # Check banned commands; keys are reported unchanged (e.g. " test").
    for cmd, reason in banned_commands.items():
        token = cmd.strip()
        if token.isalnum():
            # Standalone-token match instead of a raw substring search.
            pattern = rf"(?<![\w./-]){re.escape(token)}(?![\w./-])"
            found = re.search(pattern, cmd_script) is not None
        else:
            found = token in cmd_script
        if found:
            violations[cmd] = reason
    # Check banned regex patterns
    for pattern, reason in banned_commands_regex.items():
        if re.search(pattern, cmd_script):
            violations[pattern] = reason
    return violations
@@ -0,0 +1,38 @@
1
+ import libcst as cst
2
+
3
+
4
class ParentClassAdder(cst.CSTTransformer):
    """libcst transformer that prepends a base class to every class of a given name.

    After the tree has been visited, ``class_found`` is True if at least one
    class named ``class_name`` was encountered.
    """

    def __init__(self, class_name: str, parent_class_name: str):
        # Let CSTTransformer (MetadataDependent) initialise its own state.
        super().__init__()
        self.class_name = class_name
        self.parent_class_name = parent_class_name
        self.class_found = False

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        # Only the target class is modified.
        if original_node.name.value != self.class_name:
            return updated_node
        self.class_found = True
        # parse_expression accepts plain names and dotted ones
        # ("module.Base"); cst.Name would reject the latter.
        parent = cst.parse_expression(self.parent_class_name)
        # Adding a base that is already present would produce
        # "TypeError: duplicate base class" when the code is run.
        if any(base.value.deep_equals(parent) for base in updated_node.bases):
            return updated_node
        # Prepend so the new parent comes first in the MRO; libcst inserts
        # the separating commas when rendering.
        new_bases = (cst.Arg(value=parent), *updated_node.bases)
        return updated_node.with_changes(bases=new_bases)
23
+
24
+
25
def add_parent_to_class(
    original_code: str, class_name: str, parent_class_name: str
) -> str:
    """Return ``original_code`` with ``parent_class_name`` added as a base.

    Every class named ``class_name`` is rewritten by ``ParentClassAdder``.

    Raises:
        ValueError: If no class named ``class_name`` exists in the code.
    """
    tree = cst.parse_module(original_code)
    adder = ParentClassAdder(class_name, parent_class_name)
    rewritten = tree.visit(adder)
    if not adder.class_found:
        raise ValueError(f"Class {class_name} not found in the provided code.")
    return rewritten.code
zrb/util/git.py CHANGED
@@ -1,3 +1,4 @@
1
+ import os
1
2
  import subprocess
2
3
 
3
4
  from pydantic import BaseModel
@@ -9,14 +10,19 @@ class DiffResult(BaseModel):
9
10
  updated: list[str]
10
11
 
11
12
 
12
- def get_diff(source_commit: str, current_commit: str) -> DiffResult:
13
- # git show b176b5a main
14
- exit_status, output = subprocess.getstatusoutput(
15
- f"git diff {source_commit} {current_commit}"
16
- )
17
- if exit_status != 0:
18
- raise Exception(output)
19
- lines = output.split("\n")
13
+ def get_diff(repo_dir: str, source_commit: str, current_commit: str) -> DiffResult:
14
+ try:
15
+ result = subprocess.run(
16
+ ["git", "diff", source_commit, current_commit],
17
+ stdout=subprocess.PIPE,
18
+ stderr=subprocess.PIPE,
19
+ cwd=repo_dir,
20
+ text=True,
21
+ check=True,
22
+ )
23
+ except subprocess.CalledProcessError as e:
24
+ raise Exception(e.stderr or e.stdout)
25
+ lines = result.stdout.strip().split("\n")
20
26
  diff: dict[str, dict[str, bool]] = {}
21
27
  for line in lines:
22
28
  if not line.startswith("---") and not line.startswith("+++"):
@@ -55,17 +61,18 @@ def get_repo_dir() -> str:
55
61
  check=True,
56
62
  )
57
63
  # Return the directory path
58
- return result.stdout.strip()
64
+ return os.path.abspath(result.stdout.strip())
59
65
  except subprocess.CalledProcessError as e:
60
66
  raise Exception(e.stderr or e.stdout)
61
67
 
62
68
 
63
- def get_current_branch() -> str:
69
+ def get_current_branch(repo_dir: str) -> str:
64
70
  try:
65
71
  result = subprocess.run(
66
72
  ["git", "rev-parse", "--abbrev-ref", "HEAD"],
67
73
  stdout=subprocess.PIPE,
68
74
  stderr=subprocess.PIPE,
75
+ cwd=repo_dir,
69
76
  text=True,
70
77
  check=True,
71
78
  )
@@ -74,12 +81,13 @@ def get_current_branch() -> str:
74
81
  raise Exception(e.stderr or e.stdout)
75
82
 
76
83
 
77
- def get_branches() -> list[str]:
84
+ def get_branches(repo_dir: str) -> list[str]:
78
85
  try:
79
86
  result = subprocess.run(
80
87
  ["git", "branch"],
81
88
  stdout=subprocess.PIPE,
82
89
  stderr=subprocess.PIPE,
90
+ cwd=repo_dir,
83
91
  text=True,
84
92
  check=True,
85
93
  )
@@ -90,12 +98,13 @@ def get_branches() -> list[str]:
90
98
  raise Exception(e.stderr or e.stdout)
91
99
 
92
100
 
93
- def delete_branch(branch_name: str) -> str:
101
+ def delete_branch(repo_dir: str, branch_name: str) -> str:
94
102
  try:
95
103
  result = subprocess.run(
96
104
  ["git", "branch", "-D", branch_name],
97
105
  stdout=subprocess.PIPE,
98
106
  stderr=subprocess.PIPE,
107
+ cwd=repo_dir,
99
108
  text=True,
100
109
  check=True,
101
110
  )
@@ -104,12 +113,13 @@ def delete_branch(branch_name: str) -> str:
104
113
  raise Exception(e.stderr or e.stdout)
105
114
 
106
115
 
107
- def add() -> str:
116
+ def add(repo_dir: str) -> str:
108
117
  try:
109
118
  subprocess.run(
110
119
  ["git", "add", ".", "-A"],
111
120
  stdout=subprocess.PIPE,
112
121
  stderr=subprocess.PIPE,
122
+ cwd=repo_dir,
113
123
  text=True,
114
124
  check=True,
115
125
  )
@@ -117,25 +127,32 @@ def add() -> str:
117
127
  raise Exception(e.stderr or e.stdout)
118
128
 
119
129
 
120
def commit(repo_dir: str, message: str) -> str:
    """Run ``git commit -m <message>`` inside ``repo_dir``.

    If git reports "nothing to commit, working tree clean" on either output
    stream, the failure is ignored and treated as a no-op.

    Raises:
        Exception: For any other git failure. The message is git's stderr
            (or stdout when stderr is empty), and the original
            ``CalledProcessError`` is chained as the cause.
    """
    try:
        subprocess.run(
            ["git", "commit", "-m", message],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=repo_dir,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        ignored_error_message = "nothing to commit, working tree clean"
        # Join both streams with a newline so the message cannot straddle them.
        combined_output = f"{e.stdout or ''}\n{e.stderr or ''}"
        if ignored_error_message not in combined_output:
            raise Exception(e.stderr or e.stdout) from e
131
147
 
132
148
 
133
- def pull(remote: str, branch: str) -> str:
149
+ def pull(repo_dir: str, remote: str, branch: str) -> str:
134
150
  try:
135
151
  subprocess.run(
136
152
  ["git", "pull", remote, branch],
137
153
  stdout=subprocess.PIPE,
138
154
  stderr=subprocess.PIPE,
155
+ cwd=repo_dir,
139
156
  text=True,
140
157
  check=True,
141
158
  )
@@ -143,12 +160,13 @@ def pull(remote: str, branch: str) -> str:
143
160
  raise Exception(e.stderr or e.stdout)
144
161
 
145
162
 
146
- def push(remote: str, branch: str) -> str:
163
+ def push(repo_dir: str, remote: str, branch: str) -> str:
147
164
  try:
148
165
  subprocess.run(
149
166
  ["git", "push", "-u", remote, branch],
150
167
  stdout=subprocess.PIPE,
151
168
  stderr=subprocess.PIPE,
169
+ cwd=repo_dir,
152
170
  text=True,
153
171
  check=True,
154
172
  )
zrb/util/git_subtree.py CHANGED
@@ -3,8 +3,6 @@ import subprocess
3
3
 
4
4
  from pydantic import BaseModel
5
5
 
6
- from zrb.util.git import get_repo_dir
7
-
8
6
 
9
7
  class SingleSubTreeConfig(BaseModel):
10
8
  repo_url: str
@@ -16,21 +14,21 @@ class SubTreeConfig(BaseModel):
16
14
  data: dict[str, SingleSubTreeConfig]
17
15
 
18
16
 
19
def load_config(repo_dir: str) -> SubTreeConfig:
    """Load ``subtrees.json`` from ``repo_dir``.

    Returns an empty ``SubTreeConfig`` when the file does not exist. The
    file is read as UTF-8 (the JSON standard encoding) rather than the
    platform's locale-dependent default.
    """
    file_path = os.path.join(repo_dir, "subtrees.json")
    if not os.path.exists(file_path):
        return SubTreeConfig(data={})
    with open(file_path, "r", encoding="utf-8") as f:
        return SubTreeConfig.model_validate_json(f.read())
25
23
 
26
24
 
27
def save_config(repo_dir: str, config: SubTreeConfig):
    """Write ``config`` as indented JSON to ``subtrees.json`` in ``repo_dir``.

    The file is written as UTF-8 explicitly, so non-ASCII names or URLs do
    not depend on the platform's default encoding.
    """
    file_path = os.path.join(repo_dir, "subtrees.json")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(config.model_dump_json(indent=2))
31
29
 
32
30
 
33
- def add_subtree(name: str, repo_url: str, branch: str, prefix: str):
31
+ def add_subtree(repo_dir: str, name: str, repo_url: str, branch: str, prefix: str):
34
32
  config = load_config()
35
33
  if os.path.isdir(prefix):
36
34
  raise ValueError(f"Directory exists: {prefix}")
@@ -41,6 +39,7 @@ def add_subtree(name: str, repo_url: str, branch: str, prefix: str):
41
39
  ["git", "subtree", "add", "--prefix", prefix, repo_url, branch],
42
40
  stdout=subprocess.PIPE,
43
41
  stderr=subprocess.PIPE,
42
+ cwd=repo_dir,
44
43
  text=True,
45
44
  check=True,
46
45
  )
@@ -49,10 +48,10 @@ def add_subtree(name: str, repo_url: str, branch: str, prefix: str):
49
48
  config.data[name] = SingleSubTreeConfig(
50
49
  repo_url=repo_url, branch=branch, prefix=prefix
51
50
  )
52
- save_config(config)
51
+ save_config(repo_dir, config)
53
52
 
54
53
 
55
- def pull_subtree(prefix: str, repo_url: str, branch: str):
54
+ def pull_subtree(repo_dir: str, prefix: str, repo_url: str, branch: str):
56
55
  try:
57
56
  subprocess.run(
58
57
  [
@@ -66,6 +65,7 @@ def pull_subtree(prefix: str, repo_url: str, branch: str):
66
65
  ],
67
66
  stdout=subprocess.PIPE,
68
67
  stderr=subprocess.PIPE,
68
+ cwd=repo_dir,
69
69
  text=True,
70
70
  check=True,
71
71
  )
@@ -73,7 +73,7 @@ def pull_subtree(prefix: str, repo_url: str, branch: str):
73
73
  raise Exception(e.stderr or e.stdout)
74
74
 
75
75
 
76
- def push_subtree(prefix: str, repo_url: str, branch: str):
76
+ def push_subtree(repo_dir: str, prefix: str, repo_url: str, branch: str):
77
77
  try:
78
78
  subprocess.run(
79
79
  [
@@ -87,6 +87,7 @@ def push_subtree(prefix: str, repo_url: str, branch: str):
87
87
  ],
88
88
  stdout=subprocess.PIPE,
89
89
  stderr=subprocess.PIPE,
90
+ cwd=repo_dir,
90
91
  text=True,
91
92
  check=True,
92
93
  )
zrb/util/string/format.py CHANGED
@@ -1,9 +1,19 @@
1
+ import re
1
2
  from typing import Any
2
3
 
3
4
 
4
5
  def fstring_format(template: str, data: dict[str, Any]) -> str:
5
- # Safely evaluate the template as a Python expression
6
+ def replace_expr(match):
7
+ expr = match.group(1)
8
+ try:
9
+ result = eval(expr, {}, data)
10
+ return str(result)
11
+ except Exception as e:
12
+ raise ValueError(f"Failed to evaluate expression: {expr}: {e}")
13
+
14
+ # Use regex to find and replace all expressions in curly braces
15
+ pattern = r"\{([^}]+)\}"
6
16
  try:
7
- return eval(f'f"""{template}"""', {}, data)
17
+ return re.sub(pattern, replace_expr, template)
8
18
  except Exception as e:
9
19
  raise ValueError(f"Failed to parse template: {template}: {e}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: zrb
3
- Version: 1.0.0a5
3
+ Version: 1.0.0a10
4
4
  Summary: Your Automation Powerhouse
5
5
  Home-page: https://github.com/state-alchemists/zrb
6
6
  License: AGPL-3.0-or-later
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3.10
14
14
  Classifier: Programming Language :: Python :: 3.11
15
15
  Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
16
17
  Provides-Extra: rag
17
18
  Requires-Dist: autopep8 (>=2.0.4,<3.0.0)
18
19
  Requires-Dist: beautifulsoup4 (>=4.12.3,<5.0.0)