stores 0.0.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stores/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from stores.format import ProviderFormat
2
+ from stores.indexes import Index
3
+ from stores.parse import llm_parse_json
4
+
5
# Names exported by ``from stores import *``.
__all__ = ["Index", "ProviderFormat", "llm_parse_json"]
stores/constants.py ADDED
@@ -0,0 +1,2 @@
1
# Name of the virtualenv directory created inside an index folder, and the
# manifest filename looked up in each local index folder.
VENV_NAME, TOOLS_CONFIG_FILENAME = ".venv", "tools.toml"
stores/format.py ADDED
@@ -0,0 +1,214 @@
1
+ import inspect
2
+ import logging
3
+ import types as T
4
+ from enum import Enum
5
+ from itertools import chain
6
+ from typing import (
7
+ Callable,
8
+ Dict,
9
+ GenericAlias,
10
+ List,
11
+ Literal,
12
+ Tuple,
13
+ Type,
14
+ Union,
15
+ get_args,
16
+ get_origin,
17
+ get_type_hints,
18
+ )
19
+
20
+ from stores.utils import check_duplicates
21
+
22
# Module logger used for schema-formatting warnings.
# NOTE(review): logging.basicConfig() at import time configures the host
# application's *root* logger; libraries usually leave that to the app.
# Kept to preserve the existing import-time behaviour.
logger = logging.getLogger("stores.format")
logger.setLevel(logging.INFO)
logging.basicConfig()
25
+
26
+
27
class ProviderFormat(str, Enum):
    """Tool-declaration format to produce, one member per supported LLM API.

    Because the enum subclasses ``str``, a plain value string (for example
    ``"anthropic"``) compares equal to its member.
    """

    ANTHROPIC = "anthropic"
    GOOGLE_GEMINI = "google-gemini"
    OPENAI_CHAT = "openai-chat-completions"
    OPENAI_RESPONSES = "openai-responses"
32
+
33
+
34
+ def get_type_repr(typ: Type | GenericAlias) -> list[str]:
35
+ origin = get_origin(typ)
36
+ args = get_args(typ)
37
+
38
+ if origin is Literal:
39
+ return list(dict.fromkeys(chain(*[get_type_repr(type(arg)) for arg in args])))
40
+ if inspect.isclass(typ) and issubclass(typ, Enum):
41
+ return list(dict.fromkeys(chain(*[get_type_repr(type(v.value)) for v in typ])))
42
+ if isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
43
+ return ["object"]
44
+ if origin in (list, List) or typ is list:
45
+ return ["array"]
46
+ if origin in (dict, Dict) or typ is dict:
47
+ return ["object"]
48
+ if origin in (tuple, Tuple) or typ is tuple:
49
+ return ["array"]
50
+ if origin is Union or origin is T.UnionType:
51
+ return list(dict.fromkeys(chain(*[get_type_repr(arg) for arg in args])))
52
+
53
+ type_mappings = {
54
+ "str": "string",
55
+ "int": "integer",
56
+ "bool": "boolean",
57
+ "float": "number",
58
+ "NoneType": "null",
59
+ }
60
+ if typ.__name__ in type_mappings:
61
+ return [type_mappings[typ.__name__]]
62
+
63
+
64
def get_type_schema(typ: Type | GenericAlias):
    """Build a JSON-schema fragment describing the annotation ``typ``.

    The fragment always contains ``"type"`` (from ``get_type_repr``, collapsed
    to a plain string when exactly one type remains, because Gemini does not
    accept a list of types) and an empty ``"description"``. Depending on
    ``typ`` it also carries:

    * ``enum`` for ``Literal`` and ``Enum`` annotations,
    * ``properties`` / ``required`` / ``additionalProperties: False`` for TypedDicts,
    * ``items`` for ``list[X]`` and ``tuple[X, ...]``,
    * the merged non-``type`` keys of each member's schema for unions.

    Raises:
        TypeError: for a ``list``/``tuple`` annotation without an item type,
            and for every ``dict`` annotation.
    """
    origin = get_origin(typ)
    args = get_args(typ)

    schema = {
        "type": get_type_repr(typ),
        # TODO: Retrieve description from Annotation if available
        "description": "",
    }

    if origin is Literal:
        schema["enum"] = list(args)
    elif inspect.isclass(typ) and issubclass(typ, Enum):
        schema["enum"] = [member.value for member in typ]
    elif isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
        hints = get_type_hints(typ)
        schema["properties"] = {key: get_type_schema(hint) for key, hint in hints.items()}
        schema["additionalProperties"] = False
        schema["required"] = list(hints.keys())
    elif origin in (list, List) or typ is list:
        # A bare ``list`` / ``List`` carries no item type to describe.
        if not args:
            raise TypeError("Insufficient argument type information")
        schema["items"] = get_type_schema(args[0])
    elif origin in (dict, Dict) or typ is dict:
        # No key/value schema is produced for dicts, with or without type args.
        raise TypeError("Insufficient argument type information")
    elif origin in (tuple, Tuple) or typ is tuple:
        if not args:
            raise TypeError("Insufficient argument type information")
        # Only the first element type is described: exact for tuple[X, ...],
        # an approximation for heterogeneous tuples.
        schema["items"] = get_type_schema(args[0])
    elif origin is Union or origin is T.UnionType:
        for arg in args:
            subschema = get_type_schema(arg)
            # The union's own "type" list already covers every member.
            del subschema["type"]
            # Later members win on conflicting keys (e.g. "enum", "items").
            schema = {
                **schema,
                **subschema,
            }

    # Un-nest single member type lists since Gemini does not accept list of types
    # Optional for OpenAI or Anthropic
    if schema["type"] and len(schema["type"]) == 1:
        schema["type"] = schema["type"][0]

    return schema
110
+
111
+
112
def get_param_schema(param: inspect.Parameter, provider: ProviderFormat):
    """Build the JSON-schema for one tool parameter, adjusted for ``provider``.

    Parameters that have a default are made nullable by adding ``"null"`` to
    ``type``. For ``ProviderFormat.GOOGLE_GEMINI``:

    * ``"null"`` is removed from ``type``,
    * only one type is kept (the first, with a warning),
    * a boolean ``nullable`` flag is set from whether a default exists,
    * a warning is logged for ``"object"`` parameters.

    Raises:
        TypeError: if ``param.annotation`` is unsupported, or if
            ``get_type_schema`` rejects it.
    """
    param_schema = get_type_schema(param.annotation)

    if param_schema["type"] is None:
        # Some annotation objects (e.g. string annotations) have no __name__;
        # fall back to repr() so the intended TypeError is what surfaces.
        annotation_name = getattr(param.annotation, "__name__", repr(param.annotation))
        raise TypeError(f"Unsupported type: {annotation_name}")

    has_default = param.default is not inspect.Parameter.empty

    # "type" is a single name (str) after un-nesting, or a list of names.
    # Normalise to a list so "null" is tested by membership, not substring.
    current = param_schema["type"]
    type_names = current if isinstance(current, list) else [current]
    if has_default and "null" not in type_names:
        param_schema["type"] = [*type_names, "null"]

    if provider == ProviderFormat.GOOGLE_GEMINI:
        # Filter out "null" type
        if isinstance(param_schema["type"], list):
            param_schema["type"] = [t for t in param_schema["type"] if t != "null"]
            if len(param_schema["type"]) == 1:
                param_schema["type"] = param_schema["type"][0]
        # Check if there are still multiple types are provided for a single argument
        if isinstance(param_schema["type"], list):
            logger.warning(
                f"Gemini does not support a function argument with multiple types e.g. Union[str, int]; defaulting to first found non-null type: {param_schema['type'][0]}"
            )
            param_schema["type"] = param_schema["type"][0]
        # Gemini expresses optionality with a separate flag instead of "null"
        param_schema["nullable"] = has_default
        if param_schema["type"] == "object":
            logger.warning(
                f'Type of argument {param.name} is {param.annotation}, which is being formatted as an "object" type. However, Gemini does not seem to officially support an "object" parameter type yet and success rate might be spotty. Proceed with caution, or refactor {param.name} into one of the basic supported types: [string, integer, boolean, array].'
            )
    return param_schema
146
+
147
+
148
def _provider_tool_name(name: str) -> str:
    """Rewrite dots to dashes for providers whose names must match ``^[a-zA-Z0-9_-]{1,64}$``.

    Only dots are rewritten; other characters and the length are not checked.
    """
    return name.replace(".", "-")


def format_tools(
    tools: list[Callable],
    provider: ProviderFormat,
):
    """Format tools based on the provider's requirements.

    Args:
        tools: Callables to describe. Each tool's signature supplies the
            parameter schema and its docstring the description.
        provider: A ``ProviderFormat`` member or its string value.

    Returns:
        A list with one provider-specific tool declaration (dict) per tool.

    Raises:
        ValueError: if ``provider`` is not a valid ``ProviderFormat`` value.
        TypeError: if a parameter annotation cannot be turned into a schema.
    """

    # Check for duplicate tool names
    check_duplicates([t.__name__ for t in tools])

    # Validate the provider once, up front: an unrecognised value would match
    # none of the branches below and leave ``formatted_tool`` unbound.
    provider = ProviderFormat(provider)

    formatted_tools = []
    for tool in tools:
        # Extract parameters and their types from the tool's function signature
        signature = inspect.signature(tool)
        parameters = {
            param_name: get_param_schema(param, provider)
            for param_name, param in signature.parameters.items()
        }
        # Every parameter is listed as required; parameters with defaults are
        # marked nullable by get_param_schema instead.
        required_params = list(parameters)

        # Create formatted tool structure based on provider
        description = inspect.getdoc(tool) or "No description available."
        input_schema = {
            "type": "object",
            "properties": parameters,
            "required": required_params,
        }

        # Format tool based on provider
        if provider == ProviderFormat.OPENAI_CHAT:
            formatted_tool = {
                "type": "function",
                "function": {
                    "name": _provider_tool_name(tool.__name__),
                    "description": description,
                    "parameters": {**input_schema, "additionalProperties": False},
                    "strict": True,
                },
            }
        elif provider == ProviderFormat.OPENAI_RESPONSES:
            formatted_tool = {
                "type": "function",
                "name": _provider_tool_name(tool.__name__),
                "description": description,
                "parameters": {**input_schema, "additionalProperties": False},
            }
        elif provider == ProviderFormat.ANTHROPIC:
            formatted_tool = {
                "name": _provider_tool_name(tool.__name__),
                "description": description,
                "input_schema": input_schema,
            }
        elif provider == ProviderFormat.GOOGLE_GEMINI:
            # NOTE(review): the tool description is placed inside the
            # "parameters" schema rather than at the declaration's top level;
            # confirm this against the Gemini FunctionDeclaration format.
            formatted_tool = {
                "name": tool.__name__,
                "parameters": {
                    "type": "object",
                    "description": description,
                    "properties": parameters,
                    "required": required_params,
                },
            }
        else:
            # Only reachable if a ProviderFormat member is added without a branch here
            raise ValueError(f"Unsupported provider: {provider}")

        formatted_tools.append(formatted_tool)
    return formatted_tools
@@ -0,0 +1,11 @@
1
+ from .base_index import BaseIndex
2
+ from .index import Index
3
+ from .local_index import LocalIndex
4
+ from .remote_index import RemoteIndex
5
+
6
# Names exported by a star-import of this indexes package.
__all__ = ["BaseIndex", "Index", "LocalIndex", "RemoteIndex"]
@@ -0,0 +1,283 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import re
5
+ from inspect import Parameter
6
+ from types import NoneType, UnionType
7
+ from typing import (
8
+ Any,
9
+ Callable,
10
+ List,
11
+ Literal,
12
+ Optional,
13
+ Tuple,
14
+ Union,
15
+ get_args,
16
+ get_origin,
17
+ get_type_hints,
18
+ )
19
+
20
+ from makefun import create_function
21
+
22
+ from stores.format import ProviderFormat, format_tools
23
+ from stores.parse import llm_parse_json
24
+ from stores.utils import check_duplicates
25
+
26
# Module logger for argument-casting warnings.
# NOTE(review): logging.basicConfig() at import time configures the host
# application's *root* logger; libraries usually leave that to the app.
# Kept to preserve the existing import-time behaviour.
logger = logging.getLogger("stores.indexes.base_index")
logger.setLevel(logging.INFO)
logging.basicConfig()
29
+
30
+
31
+ def _cast_arg(value: Any, typ: type | tuple[type]):
32
+ try:
33
+ if isinstance(typ, tuple) and len(typ) == 1:
34
+ typ = typ[0]
35
+ typ_origin = get_origin(typ)
36
+ if typ in [float, int, str]:
37
+ return typ(value)
38
+ if typ is bool:
39
+ if isinstance(value, str) and value.lower() == "false":
40
+ return False
41
+ else:
42
+ return typ(value)
43
+ if typ_origin in (list, List) and isinstance(value, (list, tuple)):
44
+ return [_cast_arg(v, get_args(typ)) for v in value]
45
+ if typ_origin in (tuple, Tuple) and isinstance(value, (list, tuple)):
46
+ return tuple(_cast_arg(v, get_args(typ)) for v in value)
47
+ if isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
48
+ hints = get_type_hints(typ)
49
+ for k, v in value.items():
50
+ value[k] = _cast_arg(v, hints[k])
51
+ return value
52
+ if typ_origin in [Union, UnionType]:
53
+ if NoneType in get_args(typ) and value is None:
54
+ return value
55
+ valid_types = [a for a in get_args(typ) if a is not NoneType]
56
+ if len(valid_types) == 1:
57
+ return _cast_arg(value, valid_types[0])
58
+ except Exception:
59
+ pass
60
+ # If not in one of the cases above, we return value unchanged
61
+ return value
62
+
63
+
64
def _cast_bound_args(bound_args: inspect.BoundArguments):
    """
    Cast bound argument values to their annotated types, in place.

    In some packages, passed argument types are incorrect
    e.g. LangChain returns float even when argtype is int.
    Only the basic cases handled by ``_cast_arg`` are converted; other values
    pass through. A warning is logged whenever a value's type or value changes.

    Returns:
        ``bound_args`` itself (mutated in place).
    """
    for arg, argparam in bound_args.signature.parameters.items():
        if arg not in bound_args.arguments:
            # Not bound (apply_defaults() was not called): nothing to cast
            continue
        argtype = argparam.annotation
        value = bound_args.arguments[arg]
        new_value = _cast_arg(value, argtype)
        # Compare types too: 3 == 3.0, so a plain != test would discard exactly
        # the float -> int cast this function exists for.
        if type(new_value) is not type(value) or new_value != value:
            # Warn that we are modifying value since this might not be expected
            logger.warning(
                f'Argument "{arg}" is type {argtype} but passed value is {value} of type {type(value)} - modifying value to {new_value} instead.'
            )
        # Always store the result: nested casts (e.g. [1.0] -> [1]) compare
        # equal to the input but still need to be applied.
        bound_args.arguments[arg] = new_value

    return bound_args
82
+
83
+
84
+ # TODO: Support more nested types
85
+ def _handle_non_string_literal(annotation: type):
86
+ origin = get_origin(annotation)
87
+ if origin is Literal:
88
+ if any([not isinstance(a, str) for a in get_args(annotation)]):
89
+ # TODO: Handle duplicates
90
+ literal_map = {str(a): a for a in get_args(annotation)}
91
+ new_annotation = Literal.__getitem__(tuple(literal_map.keys()))
92
+ return new_annotation, literal_map
93
+ else:
94
+ return annotation, {}
95
+ if origin in (list, List):
96
+ args = get_args(annotation)
97
+ new_annotation, literal_map = _handle_non_string_literal(args[0])
98
+ return list[new_annotation], {"item": literal_map}
99
+ if origin is Union or origin is UnionType:
100
+ union_literal_maps = {}
101
+ argtype_args = [a for a in get_args(annotation) if a != NoneType]
102
+ new_union, literal_map = _handle_non_string_literal(argtype_args[0])
103
+ union_literal_maps[new_union.__name__] = literal_map
104
+ for child_argtype in argtype_args[1:]:
105
+ new_annotation, literal_map = _handle_non_string_literal(child_argtype)
106
+ new_union = new_union | new_annotation
107
+ union_literal_maps[new_annotation.__name__] = literal_map
108
+ return new_union, union_literal_maps
109
+ return annotation, {}
110
+
111
+
112
+ # TODO: Support more nested types
113
+ def _undo_non_string_literal(annotation: type, value: Any, literal_map: dict):
114
+ origin = get_origin(annotation)
115
+ if origin is Literal:
116
+ return literal_map.get(value, value)
117
+ if origin in (list, List) and isinstance(value, (list, tuple)):
118
+ args = get_args(annotation)
119
+ return [
120
+ _undo_non_string_literal(args[0], v, literal_map["item"]) for v in value
121
+ ]
122
+ if origin is Union or origin is UnionType:
123
+ for arg in get_args(annotation):
124
+ try:
125
+ return _undo_non_string_literal(arg, value, literal_map[arg.__name__])
126
+ except Exception:
127
+ pass
128
+ return value
129
+
130
+
131
def wrap_tool(tool: Callable):
    """
    Wrap tool to make it compatible with LLM libraries
    - Gemini does not accept non-None default values
      If there are any default args, we set default value to None
      and inject the correct default value at runtime.
    - Gemini does not accept non-string Literals
      We convert non-string Literals to strings and reset this at runtime

    The returned function exposes the rewritten signature, with every
    parameter POSITIONAL_OR_KEYWORD. At call time it:

    1. re-binds the arguments to the tool's original signature,
    2. replaces ``None`` placeholders with the real defaults,
    3. casts basic types,
    4. maps stringified Literals back.

    Sync and async tools share this preparation step.

    Tools whose ``_wrapped`` attribute is truthy are returned unchanged.
    """
    if getattr(tool, "_wrapped", False):
        return tool

    # Retrieve default arguments
    original_signature = inspect.signature(tool)
    new_args = []
    literal_maps = {}
    for arg in original_signature.parameters.values():
        # Handle non-string Literals
        new_annotation, literal_map = _handle_non_string_literal(arg.annotation)
        literal_maps[arg.name] = literal_map
        new_arg = arg.replace(
            kind=Parameter.POSITIONAL_OR_KEYWORD,
            annotation=new_annotation,
        )

        # Handle defaults
        argtype = new_arg.annotation
        if new_arg.default is Parameter.empty:
            # If it's annotated with Optional or Union[None, X]
            # remove the Optional tag since no default value is supplied
            if get_origin(argtype) in [Union, UnionType] and NoneType in get_args(argtype):
                members = [a for a in get_args(argtype) if a != NoneType]
                new_annotation = members[0]
                for member in members[1:]:
                    new_annotation = new_annotation | member
                new_arg = new_arg.replace(
                    kind=Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=new_annotation,
                )
        else:
            # Process args with default values: make sure type includes None
            new_annotation = argtype
            if new_annotation is Parameter.empty:
                new_annotation = Optional[type(new_arg.default)]
            origin = get_origin(new_annotation)
            if origin not in [Union, UnionType] or NoneType not in get_args(
                new_annotation
            ):
                new_annotation = Optional[new_annotation]
            # None is advertised as the default; the real one is injected at call time
            new_arg = new_arg.replace(
                default=None,
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                annotation=new_annotation,
            )
        new_args.append(new_arg)
    new_sig = original_signature.replace(parameters=new_args)

    def _prepare_call(args, kwargs) -> inspect.BoundArguments:
        """Bind a call made through ``new_sig`` back onto the tool's original signature."""
        bound_args = original_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        # Inject default values: None is the placeholder advertised in new_sig
        for k, v in bound_args.arguments.items():
            default = original_signature.parameters[k].default
            if v is None and default is not Parameter.empty:
                bound_args.arguments[k] = default

        _cast_bound_args(bound_args)
        # Inject correct Literals
        for k, v in bound_args.arguments.items():
            if k in literal_maps:
                param = original_signature.parameters[k]
                bound_args.arguments[k] = _undo_non_string_literal(
                    param.annotation, v, literal_maps[k]
                )
        return bound_args

    if inspect.iscoroutinefunction(tool):

        async def wrapper(*args, **kwargs):
            bound_args = _prepare_call(args, kwargs)
            return await tool(*bound_args.args, **bound_args.kwargs)
    else:

        def wrapper(*args, **kwargs):
            bound_args = _prepare_call(args, kwargs)
            return tool(*bound_args.args, **bound_args.kwargs)

    wrapped = create_function(
        new_sig,
        wrapper,
        qualname=tool.__name__,
        doc=inspect.getdoc(tool),
    )

    wrapped.__name__ = tool.__name__
    wrapped._wrapped = True

    return wrapped
241
+
242
+
243
class BaseIndex:
    """A set of tools, each wrapped with ``wrap_tool``, that can be formatted and executed by name."""

    def __init__(self, tools: list[Callable]):
        check_duplicates([t.__name__ for t in tools])
        self.tools = [wrap_tool(t) for t in tools]

    @property
    def tools_dict(self):
        """Mapping of each tool's ``__name__`` to the wrapped tool (rebuilt on every access)."""
        return {tool.__name__: tool for tool in self.tools}

    def execute(self, toolname: str, kwargs: dict | None = None):
        """Run the single tool matching ``toolname`` with keyword arguments ``kwargs``.

        Each ``-`` or ``.`` in ``toolname`` matches either character, since
        provider formats rewrite dots in tool names to dashes. Coroutine tools
        are run to completion with ``asyncio.run``.

        Raises:
            ValueError: if no tool, or more than one tool, matches ``toolname``.
        """
        kwargs = kwargs or {}
        tools_by_name = self.tools_dict

        # Use regex since we need to match cases where we perform
        # substitutions such as replace(".", "-"). Every other character is
        # escaped, because toolname may come straight from model output.
        name_regex = "(-|\\.)".join(
            re.escape(part) for part in re.split(r"[-.]", toolname)
        )
        # NOTE(review): match() anchors at the start, so ":?" only permits a
        # single leading ":"; confirm whether a prefix search was intended.
        pattern = re.compile(":?" + name_regex + "$")

        matching_tools = [name for name in tools_by_name if pattern.match(name)]
        if not matching_tools:
            raise ValueError(f"No tool matching '{toolname}'")
        if len(matching_tools) > 1:
            raise ValueError(f"'{toolname}' matches multiple tools - {matching_tools}")

        tool = tools_by_name[matching_tools[0]]
        if inspect.iscoroutinefunction(tool):
            # asyncio.run() closes its event loop when done, so no loop is
            # leaked (or left installed as the current loop) per call.
            return asyncio.run(tool(**kwargs))
        return tool(**kwargs)

    def parse_and_execute(self, msg: str):
        """Extract ``toolname`` and ``kwargs`` from ``msg`` via ``llm_parse_json`` and execute the call."""
        toolcall = llm_parse_json(msg, keys=["toolname", "kwargs"])
        return self.execute(toolcall.get("toolname"), toolcall.get("kwargs"))

    def format_tools(self, provider: ProviderFormat):
        """Return provider-specific declarations for this index's (wrapped) tools."""
        return format_tools(self.tools, provider)
@@ -0,0 +1,56 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Callable
5
+
6
+ from stores.indexes.base_index import BaseIndex
7
+ from stores.indexes.local_index import LocalIndex
8
+ from stores.indexes.remote_index import RemoteIndex
9
+
10
# Module logger for index-loading warnings.
# NOTE(review): logging.basicConfig() at import time configures the host
# application's *root* logger; libraries usually leave that to the app.
# Kept to preserve the existing import-time behaviour.
logger = logging.getLogger("stores.index")
logger.setLevel(logging.INFO)
logging.basicConfig()
13
+
14
+
15
class Index(BaseIndex):
    """Index built from tool callables and/or references to other indexes.

    Each entry of ``tools`` is handled as follows:

    * a callable is used as-is;
    * a ``str`` or ``os.PathLike`` names an index. An existing path is loaded
      as a ``LocalIndex``. Failing that, a ``str`` is tried as a
      ``RemoteIndex``, with ``env_var.get(name)`` passed through.

    All collected tools are handed to ``BaseIndex``.
    """

    def __init__(
        self,
        tools: list[Callable | os.PathLike | str] | None = None,
        env_var: dict[str, dict] | None = None,
    ):
        self.env_var = env_var or {}
        tools = tools or []

        _tools = []
        for tool in tools:
            if isinstance(tool, (str, os.PathLike)):
                _tools += self._load_index(tool).tools
            elif callable(tool):
                _tools.append(tool)
            else:
                # Surface entries that would otherwise be dropped silently
                logger.warning(
                    f"Ignoring tool entry {tool!r}: expected a callable, or a str/os.PathLike naming an index."
                )

        super().__init__(_tools)

    def _load_index(self, index_name):
        """Load ``index_name`` as a LocalIndex if the path exists, else (str only) as a RemoteIndex.

        Raises:
            ValueError: if neither kind of index could be loaded.
        """
        loaded_index = None
        if Path(index_name).exists():
            # Load LocalIndex
            try:
                loaded_index = LocalIndex(index_name)
            except Exception:
                logger.warning(
                    f'Unable to load index "{index_name}"', exc_info=True
                )
        if loaded_index is None and isinstance(index_name, str):
            # Load RemoteIndex
            try:
                loaded_index = RemoteIndex(
                    index_name, env_var=self.env_var.get(index_name)
                )
            except Exception:
                logger.warning(
                    f'Unable to load index "{index_name}"\nIf this is a local index, make sure it can be found as a directory and contains a tools.toml file.',
                    exc_info=True,
                )
        if loaded_index is None:
            raise ValueError(
                f'Unable to load index "{index_name}"\nIf this is a local index, make sure it can be found as a directory and contains a tools.toml file.'
            )
        return loaded_index
@@ -0,0 +1,84 @@
1
+ import importlib
2
+ import logging
3
+ import os
4
+ import sys
5
+ import venv
6
+ from pathlib import Path
7
+
8
+ from stores.constants import TOOLS_CONFIG_FILENAME, VENV_NAME
9
+ from stores.indexes.base_index import BaseIndex
10
+ from stores.indexes.venv_utils import init_venv_tools, install_venv_deps
11
+
12
+ if sys.version_info >= (3, 11):
13
+ import tomllib
14
+ else:
15
+ import tomli as tomllib
16
+
17
# Module logger for local-index loading.
# NOTE(review): logging.basicConfig() at import time configures the host
# application's *root* logger; libraries usually leave that to the app.
# Kept to preserve the existing import-time behaviour.
logger = logging.getLogger("stores.indexes.local_index")
logger.setLevel(logging.INFO)
logging.basicConfig()
20
+
21
+
22
class LocalIndex(BaseIndex):
    """Index of tools imported from a local folder described by a manifest.

    Args:
        index_folder: Folder holding the ``TOOLS_CONFIG_FILENAME`` manifest and
            the tool modules.
        create_venv: If True:

            * create ``index_folder / VENV_NAME`` when it does not exist,
            * call ``install_venv_deps`` on the folder,
            * obtain the tools from ``init_venv_tools``.

            The last two are defined in venv_utils.
        env_var: Passed to ``init_venv_tools``; only accepted together with
            ``create_venv=True``.

    Raises:
        ValueError: if ``index_folder`` does not exist, if ``env_var`` is given
            without ``create_venv``, or (without a venv) if the manifest is missing.
    """

    def __init__(
        self,
        index_folder: os.PathLike,
        create_venv: bool = False,
        env_var: dict | None = None,
    ):
        self.index_folder = Path(index_folder)
        self.env_var = env_var or {}

        if not self.index_folder.exists():
            raise ValueError(
                f"Unable to load index - {self.index_folder} does not exist"
            )

        if create_venv:
            # Create venv and install deps
            self.venv = self.index_folder / VENV_NAME
            if not self.venv.exists():
                venv.create(self.venv, symlinks=True, with_pip=True, upgrade_deps=True)
            install_venv_deps(self.index_folder)
            # Initialize tools
            tools = init_venv_tools(self.index_folder, self.env_var)
        else:
            if self.env_var:
                raise ValueError(
                    "Environment variables will only be restricted if create_venv=True when initializing LocalIndex"
                )
            tools = self._init_tools()
        super().__init__(tools)

    def _init_tools(self):
        """
        Load local tools.toml file and import tool functions

        Each entry of the manifest's ``[index].tools`` list is a dotted
        ``module.function`` path relative to the index folder. ``module`` is
        either a ``.py`` file or a package directory containing
        ``__init__.py``. Each module is executed once, even when several tools
        come from it, so those tools share module-level state. Each returned
        function's ``__name__`` is set to its full tool id.

        NOTE: Can we just add index_folder to sys.path and import the functions?
        """
        # importlib.util is a submodule; ``import importlib`` alone does not
        # guarantee it has been loaded.
        import importlib.util

        index_manifest = self.index_folder / TOOLS_CONFIG_FILENAME
        if not index_manifest.exists():
            raise ValueError(f"Unable to load index - {index_manifest} does not exist")

        with open(index_manifest, "rb") as file:
            manifest = tomllib.load(file)["index"]

        loaded_modules = {}
        tools = []
        for tool_id in manifest.get("tools", []):
            module_name, _, tool_name = tool_id.rpartition(".")

            module = loaded_modules.get(module_name)
            if module is None:
                module_file = self.index_folder / module_name.replace(".", "/")
                if (module_file / "__init__.py").exists():
                    module_file = module_file / "__init__.py"
                else:
                    module_file = Path(str(module_file) + ".py")

                spec = importlib.util.spec_from_file_location(module_name, module_file)
                module = importlib.util.module_from_spec(spec)
                # NOTE(review): registering under the bare module name can
                # replace an unrelated module already in sys.modules.
                sys.modules[spec.name] = module
                try:
                    spec.loader.exec_module(module)
                except BaseException:
                    # Do not leave a half-initialised module registered
                    if sys.modules.get(spec.name) is module:
                        del sys.modules[spec.name]
                    raise
                loaded_modules[module_name] = module

            tool = getattr(module, tool_name)
            tool.__name__ = tool_id
            tools.append(tool)
        return tools