starbash 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of starbash might be problematic. Click here for more details.
- starbash/__init__.py +6 -0
- starbash/app.py +228 -0
- starbash/appdefaults.sb.toml +89 -0
- starbash/commands/__init__.py +0 -0
- starbash/commands/repo.py +51 -0
- starbash/database.py +75 -0
- starbash/main.py +27 -0
- starbash/repo/__init__.py +7 -0
- starbash/repo/manager.py +248 -0
- starbash/tool.py +260 -0
- starbash-0.1.0.dist-info/LICENSE +674 -0
- starbash-0.1.0.dist-info/METADATA +82 -0
- starbash-0.1.0.dist-info/RECORD +15 -0
- starbash-0.1.0.dist-info/WHEEL +4 -0
- starbash-0.1.0.dist-info/entry_points.txt +4 -0
starbash/repo/manager.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Manages the repository of processing recipes and configurations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import tomlkit
|
|
10
|
+
from tomlkit.items import AoT
|
|
11
|
+
from multidict import MultiDict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# File name of each repository's configuration file, read relative to the repo's
# root directory by Repo._load_config. Despite the name this is a complete file
# name, not a suffix.
repo_suffix: str = "starbash.toml"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Repo:
|
|
18
|
+
"""
|
|
19
|
+
Represents a single starbash repository."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, manager: RepoManager, url: str, config: str | None = None):
|
|
22
|
+
"""
|
|
23
|
+
Initializes a Repo instance.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
url: The URL to the repository (file or general http/https urls are acceptable).
|
|
27
|
+
"""
|
|
28
|
+
self.manager = manager
|
|
29
|
+
self.url = url
|
|
30
|
+
self.config = tomlkit.parse(config) if config else self._load_config()
|
|
31
|
+
self.manager.add_all_repos(self.config, self.get_path())
|
|
32
|
+
|
|
33
|
+
def __str__(self) -> str:
|
|
34
|
+
"""Return a concise one-line description of this repo.
|
|
35
|
+
|
|
36
|
+
Example: "Repo(kind=recipe, local=True, url=file:///path/to/repo)"
|
|
37
|
+
"""
|
|
38
|
+
return f"Repo(kind={self.kind}, local={self.is_local}, url={self.url})"
|
|
39
|
+
|
|
40
|
+
__repr__ = __str__
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def kind(self) -> str:
|
|
44
|
+
"""
|
|
45
|
+
Read-only attribute for the repository kind (e.g., "recipe", "data", etc.).
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
The kind of the repository as a string.
|
|
49
|
+
"""
|
|
50
|
+
return str(self.get("repo.kind", "unknown"))
|
|
51
|
+
|
|
52
|
+
@property
|
|
53
|
+
def is_local(self) -> bool:
|
|
54
|
+
"""
|
|
55
|
+
Read-only attribute indicating whether the repository URL points to a
|
|
56
|
+
local file system path (file:// scheme).
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
bool: True if the URL is a local file path, False otherwise.
|
|
60
|
+
"""
|
|
61
|
+
return self.url.startswith("file://")
|
|
62
|
+
|
|
63
|
+
def get_path(self) -> Path | None:
|
|
64
|
+
"""
|
|
65
|
+
Resolves the URL to a local file system path if it's a file URI.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
url: The repository URL.
|
|
69
|
+
|
|
70
|
+
Returns:
|
|
71
|
+
A Path object if the URL is a local file, otherwise None.
|
|
72
|
+
"""
|
|
73
|
+
if self.is_local:
|
|
74
|
+
return Path(self.url[len("file://") :])
|
|
75
|
+
|
|
76
|
+
return None
|
|
77
|
+
|
|
78
|
+
def read(self, filepath: str) -> str:
|
|
79
|
+
"""
|
|
80
|
+
Read a filepath relative to the base of this repo. Return the contents in a string.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
filepath: The path to the file, relative to the repository root.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The content of the file as a string.
|
|
87
|
+
"""
|
|
88
|
+
base_path = self.get_path()
|
|
89
|
+
if base_path is None:
|
|
90
|
+
raise ValueError("Cannot read files from non-local repositories")
|
|
91
|
+
target_path = (base_path / filepath).resolve()
|
|
92
|
+
|
|
93
|
+
# Security check to prevent reading files outside the repo directory
|
|
94
|
+
if base_path not in target_path.parents and target_path != base_path:
|
|
95
|
+
raise PermissionError("Attempted to read file outside of repository")
|
|
96
|
+
|
|
97
|
+
return target_path.read_text()
|
|
98
|
+
|
|
99
|
+
def _load_config(self) -> dict:
|
|
100
|
+
"""
|
|
101
|
+
Loads the repository's configuration file (e.g., repo.sb.toml).
|
|
102
|
+
|
|
103
|
+
If the config file does not exist, it logs a warning and returns an empty dict.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
A dictionary containing the parsed configuration.
|
|
107
|
+
"""
|
|
108
|
+
try:
|
|
109
|
+
config_content = self.read(repo_suffix)
|
|
110
|
+
logging.debug(f"Loading repo config from {repo_suffix}")
|
|
111
|
+
return tomlkit.parse(config_content)
|
|
112
|
+
except FileNotFoundError:
|
|
113
|
+
logging.warning(f"No {repo_suffix} found")
|
|
114
|
+
return {}
|
|
115
|
+
|
|
116
|
+
def get(self, key: str, default=None):
|
|
117
|
+
"""
|
|
118
|
+
Gets a value from this repo's config for a given key.
|
|
119
|
+
The key can be a dot-separated string for nested values.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
key: The dot-separated key to search for (e.g., "repo.kind").
|
|
123
|
+
default: The value to return if the key is not found.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
The found value or the default.
|
|
127
|
+
"""
|
|
128
|
+
value = self.config
|
|
129
|
+
for k in key.split("."):
|
|
130
|
+
if not isinstance(value, dict):
|
|
131
|
+
return default
|
|
132
|
+
value = value.get(k)
|
|
133
|
+
return value if value is not None else default
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class RepoManager:
|
|
137
|
+
"""
|
|
138
|
+
Manages the collection of starbash repositories.
|
|
139
|
+
|
|
140
|
+
This class is responsible for finding, loading, and providing an API
|
|
141
|
+
for searching through known repositories defined in TOML configuration
|
|
142
|
+
files (like appdefaults.sb.toml).
|
|
143
|
+
"""
|
|
144
|
+
|
|
145
|
+
def __init__(self, app_defaults: str):
|
|
146
|
+
"""
|
|
147
|
+
Initializes the RepoManager by loading the application default repos.
|
|
148
|
+
"""
|
|
149
|
+
self.repos = []
|
|
150
|
+
|
|
151
|
+
# We expose the app default preferences as a special root repo with a private URL
|
|
152
|
+
root_repo = Repo(self, "pkg://starbash-defaults", config=app_defaults)
|
|
153
|
+
self.repos.append(root_repo)
|
|
154
|
+
|
|
155
|
+
# Most users will just want to read from merged
|
|
156
|
+
self.merged = self._union()
|
|
157
|
+
|
|
158
|
+
def add_all_repos(self, toml: dict, base_path: Path | None = None) -> None:
|
|
159
|
+
# From appdefaults.sb.toml, repo.ref is a list of tables
|
|
160
|
+
repo_refs = toml.get("repo", {}).get("ref", [])
|
|
161
|
+
|
|
162
|
+
for ref in repo_refs:
|
|
163
|
+
if "url" in ref:
|
|
164
|
+
url = ref["url"]
|
|
165
|
+
elif "dir" in ref:
|
|
166
|
+
path = Path(ref["dir"])
|
|
167
|
+
if base_path and not path.is_absolute():
|
|
168
|
+
# Resolve relative to the current TOML file's directory
|
|
169
|
+
path = (base_path / path).resolve()
|
|
170
|
+
else:
|
|
171
|
+
# Expand ~ and resolve from CWD
|
|
172
|
+
path = path.expanduser().resolve()
|
|
173
|
+
url = f"file://{path}"
|
|
174
|
+
else:
|
|
175
|
+
raise ValueError(f"Invalid repo reference: {ref}")
|
|
176
|
+
self.add_repo(url)
|
|
177
|
+
|
|
178
|
+
def add_repo(self, url: str) -> None:
|
|
179
|
+
logging.debug(f"Adding repo: {url}")
|
|
180
|
+
self.repos.append(Repo(self, url))
|
|
181
|
+
|
|
182
|
+
def get(self, key: str, default=None):
|
|
183
|
+
"""
|
|
184
|
+
Searches for a key across all repositories and returns the first value found.
|
|
185
|
+
The search is performed in reverse order of repository loading, so the
|
|
186
|
+
most recently added repositories have precedence.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
key: The dot-separated key to search for (e.g., "repo.kind").
|
|
190
|
+
default: The value to return if the key is not found in any repo.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
The found value or the default.
|
|
194
|
+
"""
|
|
195
|
+
# Iterate in reverse to give precedence to later-loaded repos
|
|
196
|
+
for repo in reversed(self.repos):
|
|
197
|
+
value = repo.get(key)
|
|
198
|
+
if value is not None:
|
|
199
|
+
return value
|
|
200
|
+
|
|
201
|
+
return default
|
|
202
|
+
|
|
203
|
+
def dump(self):
|
|
204
|
+
"""
|
|
205
|
+
Prints a detailed, multi-line description of the combined top-level keys
|
|
206
|
+
and values from all repositories, using a MultiDict for aggregation.
|
|
207
|
+
This is useful for debugging and inspecting the consolidated configuration.
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
combined_config = self.merged
|
|
211
|
+
logging.info("RepoManager Dump")
|
|
212
|
+
for key, value in combined_config.items():
|
|
213
|
+
# tomlkit.items() can return complex types (e.g., ArrayOfTables, Table)
|
|
214
|
+
# For a debug dump, a simple string representation is usually sufficient.
|
|
215
|
+
logging.info(f" %s: %s", key, value)
|
|
216
|
+
|
|
217
|
+
def _union(self) -> MultiDict:
|
|
218
|
+
"""
|
|
219
|
+
Merges the top-level keys from all repository configurations into a MultiDict.
|
|
220
|
+
|
|
221
|
+
This method iterates through all loaded repositories in their original order
|
|
222
|
+
and combines their top-level configuration keys. If a key exists in multiple
|
|
223
|
+
repositories, all of its values will be present in the returned MultiDict.
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
A MultiDict containing the union of all top-level keys.
|
|
227
|
+
"""
|
|
228
|
+
merged_dict = MultiDict()
|
|
229
|
+
for repo in self.repos:
|
|
230
|
+
for key, value in repo.config.items():
|
|
231
|
+
# if the toml object is an AoT type, monkey patch each element in the array instead
|
|
232
|
+
if isinstance(value, AoT):
|
|
233
|
+
for v in value:
|
|
234
|
+
setattr(v, "source", repo)
|
|
235
|
+
else:
|
|
236
|
+
# We monkey patch source into any object that came from a repo, so that users can
|
|
237
|
+
# find the source repo (for attribution, URL relative resolution, whatever...)
|
|
238
|
+
setattr(value, "source", repo)
|
|
239
|
+
|
|
240
|
+
merged_dict.add(key, value)
|
|
241
|
+
|
|
242
|
+
return merged_dict
|
|
243
|
+
|
|
244
|
+
def __str__(self):
|
|
245
|
+
lines = [f"RepoManager with {len(self.repos)} repositories:"]
|
|
246
|
+
for i, repo in enumerate(self.repos):
|
|
247
|
+
lines.append(f" [{i}] {repo.url}")
|
|
248
|
+
return "\n".join(lines)
|
starbash/tool.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
import textwrap
|
|
4
|
+
import tempfile
|
|
5
|
+
import subprocess
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
import RestrictedPython
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


class _SafeFormatter(dict):
    """A dictionary for safe string formatting that ignores missing keys during expansion.

    Used with ``str.format_map``: a missing key is rendered back as the literal
    placeholder ``{key}`` instead of raising KeyError.
    """

    def __missing__(self, key):
        return "{" + key + "}"


def expand_context(s: str, context: dict, max_iterations: int = 10) -> str:
    """Expand named ``{var}`` placeholders in ``s`` using values from ``context``.

    Expansion is repeated so that values which themselves contain placeholders
    are expanded too, until the string stops changing or ``max_iterations``
    formatting passes have been made. Only ``str.format_map`` is used, so no
    Python code is executed.

    Args:
        s: Template string, e.g. ``"MyStr{somevar}a{someothervar}"``.
        context: Mapping of variable names to values.
        max_iterations: Upper bound on formatting passes; guards against
            definitions that never converge, such as ``{"a": "x{a}"}``.

    Returns:
        The expanded string.

    Raises:
        KeyError: if any ``{name}`` placeholder remains after expansion.
    """
    mapping = _SafeFormatter(context)  # loop invariant: build once

    # Iteratively expand the string to handle nested placeholders. Convergence is
    # checked right after each pass, so a template that settles on the final
    # permitted pass is not misreported as runaway recursion.
    expanded = s
    for _ in range(max_iterations):
        next_expanded = expanded.format_map(mapping)
        if next_expanded == expanded:
            break  # Expansion is complete
        expanded = next_expanded
    else:
        logger.warning(
            f"Template expansion reached max iterations ({max_iterations}). Possible recursive definition in '{s}'."
        )

    logger.debug(f"Expanded '{s}' into '{expanded}'")

    # throw an error if any remaining unexpanded variables remain unexpanded
    unexpanded_vars = re.findall(r"\{([^{}]+)\}", expanded)
    if unexpanded_vars:
        raise KeyError("Missing context variable(s): " + ", ".join(unexpanded_vars))

    return expanded
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def make_safe_globals(context: dict | None = None) -> dict:
    """Build the globals mapping used to exec recipe scripts under RestrictedPython.

    Args:
        context: Exposed to the script as the global ``context``. The same dict
            object is placed in the globals, so the script can mutate it and the
            caller sees those changes. Defaults to a fresh empty dict per call.

    Returns:
        A globals dict (restricted ``__builtins__`` plus ``context`` and ``logger``)
        suitable for ``exec`` of RestrictedPython byte code.
    """
    # A shared ``{}`` default would be handed to every script as ``context`` and
    # could accumulate state across unrelated runs, so create one per call.
    if context is None:
        context = {}

    # Define the global and local namespaces for the restricted execution.
    # FIXME - this is still unsafe, policies need to be added to limit import/getattr etc...
    # see https://restrictedpython.readthedocs.io/en/latest/usage/policy.html#implementing-a-policy

    builtins = RestrictedPython.safe_builtins.copy()

    def write_test(obj):
        """``_write_`` is a guard function taking a single argument. If the
        object passed to it may be written to, it should be returned,
        otherwise the guard function should raise an exception. ``_write_``
        is typically called on an object before a ``setattr`` operation.

        This implementation permits every write."""
        return obj

    def getitem_glue(baseobj, index):
        """``_getitem_`` guard for subscript access; currently unrestricted."""
        return baseobj[index]

    extras = {
        # SECURITY: exposing the real __import__ lets a script import any module
        # (os, subprocess, ...) and escape the sandbox entirely. Do not run
        # untrusted recipe scripts until this is restricted.
        "__import__": __import__,  # FIXME very unsafe
        # NOTE(review): RestrictedPython does not install a _getitem_ by itself;
        # RestrictedPython.Eval.default_guarded_getitem may be usable here - confirm.
        "_getitem_": getitem_glue,
        "_getiter_": iter,  # Allows for loops and other iterations.
        "_write_": write_test,
        # Add common built-in types
        "list": list,
        "dict": dict,
        "str": str,
        "int": int,
        "all": all,
    }
    builtins.update(extras)

    execution_globals = {
        # Required for RestrictedPython
        "__builtins__": builtins,
        "__name__": "__starbash_script__",
        "__metaclass__": type,
        # Extra globals auto imported into the scripts context
        "context": context,
        "logger": logging.getLogger("script"),  # Allow logging within the script
    }
    return execution_globals
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def strip_comments(text: str) -> str:
    """Return ``text`` with ``#`` comments removed.

    On every line, everything from the first ``#`` onward is dropped, so a line
    that is entirely a comment becomes an empty line. Trailing whitespace is then
    stripped. The lines are re-joined with newlines; a trailing newline in
    ``text`` is not kept. A ``#`` inside quotes also starts a comment.
    """
    return "\n".join(
        code.partition("#")[0].rstrip() for code in text.splitlines()
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def tool_run(cmd: str, cwd: str, commands: str | None = None) -> None:
    """Run ``cmd`` through the shell in ``cwd``, optionally feeding ``commands`` on stdin.

    Any stderr output is logged as a warning. On success, stdout is logged at
    debug level; on failure it is logged as a warning before raising.

    Raises:
        RuntimeError: if the process exits with a non-zero status.
    """
    logger.debug(f"Running {cmd} in {cwd}: stdin={commands}")
    # SECURITY: ``cmd`` is interpreted by the shell; callers must not interpolate
    # untrusted text into it without quoting.
    proc = subprocess.run(
        cmd,
        shell=True,
        cwd=cwd,
        input=commands,
        text=True,
        capture_output=True,
    )
    out, err = proc.stdout, proc.stderr

    if err:
        logger.warning(f"Tool error message:\n{err.strip()}")

    if proc.returncode == 0:
        logger.debug("Tool command successful.")
        if out:
            logger.debug(f"Tool output:\n{out.strip()}")
        return

    # If we got an error, surface the entire tool stdout as a warning
    logger.warning(f"Tool output:\n{out.strip()}")
    raise RuntimeError(f"Tool failed with exit code {proc.returncode}")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# siril_path = "/home/kevinh/packages/Siril-1.4.0~beta3-x86_64.AppImage"
|
|
132
|
+
# Command used to launch Siril; siril_run() places it unquoted at the start of a
# shell command line. The default is the flatpak application ID, which presumably
# relies on flatpak's exported launcher being on PATH (TODO confirm). A direct path
# to a Siril executable/AppImage can be used instead.
siril_path: str = "org.siril.Siril" # flatpak
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def siril_run(
    temp_dir: str, commands: str, input_files: list[str] | None = None
) -> None:
    """Execute Siril in script mode with ``temp_dir`` as its working directory.

    Each input file is symlinked (by basename) into ``temp_dir`` so the script can
    refer to it by name. ``commands`` is comment-stripped, dedented, prefixed with
    a ``requires`` line and fed to Siril on stdin.

    Args:
        temp_dir: Working directory for Siril; also receives the input symlinks.
        commands: Siril script commands (may be an indented multi-line string).
        input_files: Paths of files to make available in ``temp_dir``.

    Raises:
        OSError: if a symlink cannot be created (e.g. FileExistsError when two
            input files share a basename).
        RuntimeError: if Siril exits with a non-zero status (raised by tool_run).
    """
    import shlex  # only needed here, to quote the working directory

    if input_files is None:
        input_files = []

    # Create symbolic links for all input files in the temp directory
    for f in input_files:
        source = os.path.abspath(str(f))
        os.symlink(source, os.path.join(temp_dir, os.path.basename(str(f))))

    # Build the script explicitly instead of dedenting an indented f-string:
    # interpolating a multi-line command block into such a literal leaves its
    # later lines at column 0, which defeats the outer dedent and leaves the
    # remaining lines inconsistently indented.
    body = textwrap.dedent(strip_comments(commands))
    script_content = f"requires 1.4.0-beta3\n{body}\n"

    logger.info(
        f"Running Siril in {temp_dir}, ({len(input_files)} input files) cmds:\n{script_content}"
    )

    # The `-s -` arguments tell Siril to run in script mode and read commands from stdin.
    # It seems like the -d command may also be required when siril is in a flatpak.
    # temp_dir is quoted because the command line is run through a shell;
    # siril_path is left as-is since it may be a full command.
    cmd = f"{siril_path} -d {shlex.quote(str(temp_dir))} -s -"

    tool_run(cmd, temp_dir, script_content)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def graxpert_run(cwd: str, arguments: str) -> None:
    """Run the ``graxpert`` command-line tool in ``cwd``.

    Args:
        cwd: Working directory for the process.
        arguments: Raw argument string appended after ``graxpert`` and interpreted
            by the shell, e.g.
            ``-cmd background-extraction -output /tmp/testout tests/test_images/real_crummy.fits``.

    Raises:
        RuntimeError: if graxpert exits with a non-zero status (raised by tool_run).
    """
    tool_run(f"graxpert {arguments}", cwd)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class Tool:
|
|
173
|
+
"""A tool for stage execution"""
|
|
174
|
+
|
|
175
|
+
def __init__(self, name: str) -> None:
|
|
176
|
+
self.name = name
|
|
177
|
+
|
|
178
|
+
# default script file name
|
|
179
|
+
self.default_script_file = None
|
|
180
|
+
|
|
181
|
+
def run_in_temp_dir(self, commands: str, context: dict = {}) -> None:
|
|
182
|
+
"""Run commands inside this tool (with cwd pointing to a temp directory)"""
|
|
183
|
+
# Create a temporary directory for processing
|
|
184
|
+
temp_dir = tempfile.mkdtemp(prefix=self.name)
|
|
185
|
+
|
|
186
|
+
context["temp_dir"] = (
|
|
187
|
+
temp_dir # pass our directory path in for the tool's usage
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
try:
|
|
191
|
+
self.run(temp_dir, commands, context=context)
|
|
192
|
+
finally:
|
|
193
|
+
shutil.rmtree(temp_dir)
|
|
194
|
+
|
|
195
|
+
def run(self, cwd: str, commands: str, context: dict = {}) -> None:
|
|
196
|
+
"""Run commands inside this tool (with cwd pointing to the specified directory)"""
|
|
197
|
+
raise NotImplementedError()
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class SirilTool(Tool):
    """Expose Siril as a tool; its commands are a Siril script template."""

    def __init__(self) -> None:
        super().__init__("siril")

    def run(self, cwd: str, commands: str, context: dict | None = None) -> None:
        """Expand ``{placeholders}`` in ``commands`` from ``context`` and run it in Siril.

        ``context["input_files"]``, if present, lists files to symlink into ``cwd``.

        Raises:
            KeyError: if a placeholder has no value in ``context`` (from expand_context).
        """
        # Avoid a shared mutable default; context is only read here.
        if context is None:
            context = {}

        # expand_context repeats substitution so nested placeholders are resolved.
        expanded = expand_context(commands, context)

        input_files = context.get("input_files", [])

        siril_run(cwd, expanded, input_files)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class GraxpertTool(Tool):
    """Expose Graxpert as a tool; its commands are the command-line arguments."""

    def __init__(self) -> None:
        super().__init__("graxpert")

    def run(self, cwd: str, commands: str, context: dict | None = None) -> None:
        """Run ``graxpert`` in ``cwd`` with ``commands`` as its argument string.

        ``context`` is accepted for interface compatibility with ``Tool`` but is
        not used (no placeholder expansion is performed).
        """
        graxpert_run(cwd, commands)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class PythonTool(Tool):
    """Expose Python as a tool; its commands are a script run under RestrictedPython."""

    def __init__(self) -> None:
        super().__init__("python")

        # default script file override
        self.default_script_file = "starbash.py"

    def run(self, cwd: str, commands: str, context: dict | None = None) -> None:
        """Compile ``commands`` with RestrictedPython and execute it with cwd set to ``cwd``.

        The script sees ``context`` as its global ``context`` (via make_safe_globals).
        The process-wide working directory is changed for the duration of the run
        and restored afterwards, so concurrent calls from several threads would
        interfere with each other.

        Raises:
            SyntaxError: if the script does not compile (logged, then re-raised).
            Exception: anything raised by the script itself (logged, then re-raised).
        """
        # A shared ``{}`` default would be exposed to every script as ``context``,
        # letting one run's writes leak into the next; use a fresh dict instead.
        if context is None:
            context = {}

        original_cwd = os.getcwd()
        try:
            os.chdir(cwd)  # cd to where this script expects to run

            logger.info(f"Executing python script in {cwd} using RestrictedPython")
            try:
                byte_code = RestrictedPython.compile_restricted(
                    commands, filename="<python script>", mode="exec"
                )
                # No separate locals: the script runs at "module level" in the globals.
                exec(byte_code, make_safe_globals(context))
            except SyntaxError as e:
                logger.error(f"Syntax error in python script: {e}")
                raise
            except Exception as e:
                logger.error(f"Error during python script execution: {e}")
                raise
        finally:
            os.chdir(original_cwd)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# Registry of every available tool instance, keyed by its ``name``
# (e.g. tools["siril"] is the SirilTool).
tools = {
    instance.name: instance
    for instance in (SirilTool(), GraxpertTool(), PythonTool())
}
|