stores 0.0.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stores/__init__.py +9 -0
- stores/constants.py +2 -0
- stores/format.py +214 -0
- stores/indexes/__init__.py +11 -0
- stores/indexes/base_index.py +283 -0
- stores/indexes/index.py +56 -0
- stores/indexes/local_index.py +84 -0
- stores/indexes/remote_index.py +76 -0
- stores/indexes/venv_utils.py +376 -0
- stores/parse.py +144 -0
- stores/utils.py +8 -0
- stores-0.1.1.dist-info/METADATA +85 -0
- stores-0.1.1.dist-info/RECORD +15 -0
- {stores-0.0.0.dist-info → stores-0.1.1.dist-info}/WHEEL +1 -2
- stores-0.1.1.dist-info/licenses/LICENSE +21 -0
- stores-0.0.0.dist-info/METADATA +0 -19
- stores-0.0.0.dist-info/RECORD +0 -4
- stores-0.0.0.dist-info/top_level.txt +0 -1
stores/__init__.py
ADDED
stores/constants.py
ADDED
stores/format.py
ADDED
@@ -0,0 +1,214 @@
|
|
1
|
+
import inspect
|
2
|
+
import logging
|
3
|
+
import types as T
|
4
|
+
from enum import Enum
|
5
|
+
from itertools import chain
|
6
|
+
from typing import (
|
7
|
+
Callable,
|
8
|
+
Dict,
|
9
|
+
GenericAlias,
|
10
|
+
List,
|
11
|
+
Literal,
|
12
|
+
Tuple,
|
13
|
+
Type,
|
14
|
+
Union,
|
15
|
+
get_args,
|
16
|
+
get_origin,
|
17
|
+
get_type_hints,
|
18
|
+
)
|
19
|
+
|
20
|
+
from stores.utils import check_duplicates
|
21
|
+
|
22
|
+
logging.basicConfig()
|
23
|
+
logger = logging.getLogger("stores.format")
|
24
|
+
logger.setLevel(logging.INFO)
|
25
|
+
|
26
|
+
|
27
|
+
class ProviderFormat(str, Enum):
    """Identifiers for the provider-specific tool formats produced by this module.

    Members are also ``str`` instances, so each one compares equal to its raw
    string value (e.g. ``ProviderFormat.ANTHROPIC == "anthropic"``).
    """

    ANTHROPIC = "anthropic"
    GOOGLE_GEMINI = "google-gemini"
    OPENAI_CHAT = "openai-chat-completions"
    OPENAI_RESPONSES = "openai-responses"
|
32
|
+
|
33
|
+
|
34
|
+
def get_type_repr(typ: Type | GenericAlias) -> list[str]:
|
35
|
+
origin = get_origin(typ)
|
36
|
+
args = get_args(typ)
|
37
|
+
|
38
|
+
if origin is Literal:
|
39
|
+
return list(dict.fromkeys(chain(*[get_type_repr(type(arg)) for arg in args])))
|
40
|
+
if inspect.isclass(typ) and issubclass(typ, Enum):
|
41
|
+
return list(dict.fromkeys(chain(*[get_type_repr(type(v.value)) for v in typ])))
|
42
|
+
if isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
|
43
|
+
return ["object"]
|
44
|
+
if origin in (list, List) or typ is list:
|
45
|
+
return ["array"]
|
46
|
+
if origin in (dict, Dict) or typ is dict:
|
47
|
+
return ["object"]
|
48
|
+
if origin in (tuple, Tuple) or typ is tuple:
|
49
|
+
return ["array"]
|
50
|
+
if origin is Union or origin is T.UnionType:
|
51
|
+
return list(dict.fromkeys(chain(*[get_type_repr(arg) for arg in args])))
|
52
|
+
|
53
|
+
type_mappings = {
|
54
|
+
"str": "string",
|
55
|
+
"int": "integer",
|
56
|
+
"bool": "boolean",
|
57
|
+
"float": "number",
|
58
|
+
"NoneType": "null",
|
59
|
+
}
|
60
|
+
if typ.__name__ in type_mappings:
|
61
|
+
return [type_mappings[typ.__name__]]
|
62
|
+
|
63
|
+
|
64
|
+
def get_type_schema(typ: Type | GenericAlias):
    """Build a JSON-schema fragment describing a type annotation.

    Args:
        typ: The annotation to describe (class, generic alias, ``Literal``,
            ``Enum`` subclass, ``TypedDict`` or union).

    Returns:
        A dict with at least ``"type"`` and ``"description"`` keys. Depending
        on the annotation it also carries ``"enum"``, ``"items"`` or
        ``"properties"`` / ``"required"`` / ``"additionalProperties"``.
        ``"type"`` is a bare string when there is exactly one type name, a
        list when there are several, and ``None`` when ``get_type_repr``
        does not recognise the annotation.

    Raises:
        TypeError: For a ``list`` / ``tuple`` without an element type and for
            any ``dict`` annotation, because the item/property schema cannot
            be derived.
    """
    origin = get_origin(typ)
    args = get_args(typ)

    schema = {
        "type": get_type_repr(typ),
        # TODO: Retrieve description from Annotation if available
        "description": "",
    }

    if origin is Literal:
        schema["enum"] = list(args)
    elif inspect.isclass(typ) and issubclass(typ, Enum):
        schema["enum"] = [v.value for v in typ]
    elif isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
        hints = get_type_hints(typ)
        schema["properties"] = {k: get_type_schema(v) for k, v in hints.items()}
        schema["additionalProperties"] = False
        schema["required"] = list(hints.keys())
    elif origin in (list, List) or typ is list:
        # Fixed: this branch used to test `typ is dict`, so a bare `list`
        # slipped through and produced an array schema without "items",
        # unlike the bare `tuple` case below which raises.
        if args:
            schema["items"] = get_type_schema(args[0])
        else:
            raise TypeError("Insufficient argument type information")
    elif origin in (dict, Dict) or typ is dict:
        # Plain dicts carry no per-key type information; a TypedDict is
        # the only mapping shape handled above.
        raise TypeError("Insufficient argument type information")
    elif origin in (tuple, Tuple) or typ is tuple:
        if args:
            # Only the first element type is used for the item schema.
            schema["items"] = get_type_schema(args[0])
        else:
            raise TypeError("Insufficient argument type information")
    elif origin is Union or origin is T.UnionType:
        # Merge every member's extra keys (enum, items, ...) into the union
        # schema; the combined "type" list was already computed above.
        for arg in args:
            subschema = get_type_schema(arg)
            del subschema["type"]
            schema = {
                **schema,
                **subschema,
            }

    # Un-nest single member type lists since Gemini does not accept list of types
    # Optional for OpenAI or Anthropic
    if schema["type"] and len(schema["type"]) == 1:
        schema["type"] = schema["type"][0]

    return schema
|
110
|
+
|
111
|
+
|
112
|
+
def get_param_schema(param: inspect.Parameter, provider: ProviderFormat):
    """Produce the schema entry for one function parameter.

    Args:
        param: The parameter whose annotation and default are inspected.
        provider: Target provider; Gemini gets extra post-processing.

    Returns:
        The schema dict from ``get_type_schema`` with ``"null"`` added to the
        type for parameters that have a default. For Gemini the type is
        collapsed to a single non-null name and a ``"nullable"`` flag is set.

    Raises:
        TypeError: When no schema type could be derived for the annotation.
    """
    schema = get_type_schema(param.annotation)
    if schema["type"] is None:
        raise TypeError(f"Unsupported type: {param.annotation.__name__}")

    has_default = param.default is not inspect.Parameter.empty

    # Parameters with a default may be omitted by the model, so allow null.
    if has_default and "null" not in schema["type"]:
        declared = schema["type"]
        if isinstance(declared, list):
            declared.append("null")
        else:
            schema["type"] = [declared, "null"]

    if provider == ProviderFormat.GOOGLE_GEMINI:
        kinds = schema["type"]

        # Filter out "null" type
        if isinstance(kinds, list):
            kinds = [kind for kind in kinds if kind != "null"]
            if len(kinds) == 1:
                kinds = kinds[0]

        # Check whether multiple types are still provided for a single argument
        if isinstance(kinds, list):
            logger.warning(
                f"Gemini does not support a function argument with multiple types e.g. Union[str, int]; defaulting to first found non-null type: {kinds[0]}"
            )
            kinds = kinds[0]

        schema["type"] = kinds
        # Add nullable property for Gemini
        schema["nullable"] = has_default

        if kinds == "object":
            logger.warning(
                f'Type of argument {param.name} is {param.annotation}, which is being formatted as an "object" type. However, Gemini does not seem to officially support an "object" parameter type yet and success rate might be spotty. Proceed with caution, or refactor {param.name} into one of the basic supported types: [string, integer, boolean, array].'
            )

    return schema
|
146
|
+
|
147
|
+
|
148
|
+
def format_tools(
    tools: list[Callable],
    provider: ProviderFormat,
):
    """Format tools based on the provider's requirements.

    Args:
        tools: Callables to describe; their ``__name__`` values must be
            unique.
        provider: The provider format to emit.

    Returns:
        One dict per tool, shaped for the requested provider.

    Raises:
        ValueError: If ``provider`` is not one of the supported formats
            (raised when the first tool is formatted).
    """

    # Check for duplicate tool names
    check_duplicates([t.__name__ for t in tools])

    formatted_tools = []
    for tool in tools:
        # Extract parameters and their types from the tool's function signature
        signature = inspect.signature(tool)
        parameters = {
            param_name: get_param_schema(param, provider)
            for param_name, param in signature.parameters.items()
        }
        # Every parameter is listed as required, including those with
        # defaults (their schema type admits "null" instead).
        required_params = list(parameters)

        # Create formatted tool structure based on provider
        description = inspect.getdoc(tool) or "No description available."
        input_schema = {
            "type": "object",
            "properties": parameters,
            "required": required_params,
        }

        # Format tool based on provider
        if provider == ProviderFormat.OPENAI_CHAT:
            formatted_tool = {
                "type": "function",
                "function": {
                    # OpenAI only supports ^[a-zA-Z0-9_-]{1,64}$
                    "name": tool.__name__.replace(".", "-"),
                    "description": description,
                    "parameters": {**input_schema, "additionalProperties": False},
                    "strict": True,
                },
            }
        elif provider == ProviderFormat.OPENAI_RESPONSES:
            formatted_tool = {
                "type": "function",
                # OpenAI only supports ^[a-zA-Z0-9_-]{1,64}$
                "name": tool.__name__.replace(".", "-"),
                "description": description,
                "parameters": {**input_schema, "additionalProperties": False},
            }
        elif provider == ProviderFormat.ANTHROPIC:
            formatted_tool = {
                # Claude only supports ^[a-zA-Z0-9_-]{1,64}$
                "name": tool.__name__.replace(".", "-"),
                "description": description,
                "input_schema": input_schema,
            }
        elif provider == ProviderFormat.GOOGLE_GEMINI:
            formatted_tool = {
                "name": tool.__name__,
                "parameters": {
                    "type": "object",
                    "description": description,
                    "properties": parameters,
                    "required": required_params,
                },
            }
        else:
            # Without this branch `formatted_tool` would be unbound (or still
            # hold the previous iteration's value) for an unknown provider.
            raise ValueError(f"Unsupported provider format: {provider!r}")

        formatted_tools.append(formatted_tool)
    return formatted_tools
|
@@ -0,0 +1,283 @@
|
|
1
|
+
import asyncio
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
import re
|
5
|
+
from inspect import Parameter
|
6
|
+
from types import NoneType, UnionType
|
7
|
+
from typing import (
|
8
|
+
Any,
|
9
|
+
Callable,
|
10
|
+
List,
|
11
|
+
Literal,
|
12
|
+
Optional,
|
13
|
+
Tuple,
|
14
|
+
Union,
|
15
|
+
get_args,
|
16
|
+
get_origin,
|
17
|
+
get_type_hints,
|
18
|
+
)
|
19
|
+
|
20
|
+
from makefun import create_function
|
21
|
+
|
22
|
+
from stores.format import ProviderFormat, format_tools
|
23
|
+
from stores.parse import llm_parse_json
|
24
|
+
from stores.utils import check_duplicates
|
25
|
+
|
26
|
+
logging.basicConfig()
|
27
|
+
logger = logging.getLogger("stores.indexes.base_index")
|
28
|
+
logger.setLevel(logging.INFO)
|
29
|
+
|
30
|
+
|
31
|
+
def _cast_arg(value: Any, typ: type | tuple[type]):
|
32
|
+
try:
|
33
|
+
if isinstance(typ, tuple) and len(typ) == 1:
|
34
|
+
typ = typ[0]
|
35
|
+
typ_origin = get_origin(typ)
|
36
|
+
if typ in [float, int, str]:
|
37
|
+
return typ(value)
|
38
|
+
if typ is bool:
|
39
|
+
if isinstance(value, str) and value.lower() == "false":
|
40
|
+
return False
|
41
|
+
else:
|
42
|
+
return typ(value)
|
43
|
+
if typ_origin in (list, List) and isinstance(value, (list, tuple)):
|
44
|
+
return [_cast_arg(v, get_args(typ)) for v in value]
|
45
|
+
if typ_origin in (tuple, Tuple) and isinstance(value, (list, tuple)):
|
46
|
+
return tuple(_cast_arg(v, get_args(typ)) for v in value)
|
47
|
+
if isinstance(typ, type) and typ.__class__.__name__ == "_TypedDictMeta":
|
48
|
+
hints = get_type_hints(typ)
|
49
|
+
for k, v in value.items():
|
50
|
+
value[k] = _cast_arg(v, hints[k])
|
51
|
+
return value
|
52
|
+
if typ_origin in [Union, UnionType]:
|
53
|
+
if NoneType in get_args(typ) and value is None:
|
54
|
+
return value
|
55
|
+
valid_types = [a for a in get_args(typ) if a is not NoneType]
|
56
|
+
if len(valid_types) == 1:
|
57
|
+
return _cast_arg(value, valid_types[0])
|
58
|
+
except Exception:
|
59
|
+
pass
|
60
|
+
# If not in one of the cases above, we return value unchanged
|
61
|
+
return value
|
62
|
+
|
63
|
+
|
64
|
+
def _cast_bound_args(bound_args: inspect.BoundArguments):
    """Coerce bound argument values to their annotated types, in place.

    In some packages, passed argument types are incorrect
    e.g. LangChain returns float even when argtype is int
    This only casts basic argtypes

    Args:
        bound_args: Bound arguments with a value for every parameter of the
            signature (callers apply defaults first).

    Returns:
        The same ``bound_args`` object, with cast values written back.
    """
    for arg, argparam in bound_args.signature.parameters.items():
        argtype = argparam.annotation
        value = bound_args.arguments[arg]
        new_value = _cast_arg(value, argtype)
        # Equality alone is not enough to detect a cast: 3.0 == 3, so the
        # float -> int case from the docstring would look "unchanged".
        if new_value != value or type(new_value) is not type(value):
            # Warn that we are modifying value since this might not be expected
            logger.warning(
                f'Argument "{arg}" is type {argtype} but passed value is {value} of type {type(value)} - modifying value to {new_value} instead.'
            )
        # Always store the result; _cast_arg returns the original value when
        # no cast applies, and this also keeps casts nested inside containers
        # (e.g. [1.0] -> [1]) that compare equal to the input.
        bound_args.arguments[arg] = new_value

    return bound_args
|
82
|
+
|
83
|
+
|
84
|
+
# TODO: Support more nested types
|
85
|
+
def _handle_non_string_literal(annotation: type):
|
86
|
+
origin = get_origin(annotation)
|
87
|
+
if origin is Literal:
|
88
|
+
if any([not isinstance(a, str) for a in get_args(annotation)]):
|
89
|
+
# TODO: Handle duplicates
|
90
|
+
literal_map = {str(a): a for a in get_args(annotation)}
|
91
|
+
new_annotation = Literal.__getitem__(tuple(literal_map.keys()))
|
92
|
+
return new_annotation, literal_map
|
93
|
+
else:
|
94
|
+
return annotation, {}
|
95
|
+
if origin in (list, List):
|
96
|
+
args = get_args(annotation)
|
97
|
+
new_annotation, literal_map = _handle_non_string_literal(args[0])
|
98
|
+
return list[new_annotation], {"item": literal_map}
|
99
|
+
if origin is Union or origin is UnionType:
|
100
|
+
union_literal_maps = {}
|
101
|
+
argtype_args = [a for a in get_args(annotation) if a != NoneType]
|
102
|
+
new_union, literal_map = _handle_non_string_literal(argtype_args[0])
|
103
|
+
union_literal_maps[new_union.__name__] = literal_map
|
104
|
+
for child_argtype in argtype_args[1:]:
|
105
|
+
new_annotation, literal_map = _handle_non_string_literal(child_argtype)
|
106
|
+
new_union = new_union | new_annotation
|
107
|
+
union_literal_maps[new_annotation.__name__] = literal_map
|
108
|
+
return new_union, union_literal_maps
|
109
|
+
return annotation, {}
|
110
|
+
|
111
|
+
|
112
|
+
# TODO: Support more nested types
|
113
|
+
def _undo_non_string_literal(annotation: type, value: Any, literal_map: dict):
|
114
|
+
origin = get_origin(annotation)
|
115
|
+
if origin is Literal:
|
116
|
+
return literal_map.get(value, value)
|
117
|
+
if origin in (list, List) and isinstance(value, (list, tuple)):
|
118
|
+
args = get_args(annotation)
|
119
|
+
return [
|
120
|
+
_undo_non_string_literal(args[0], v, literal_map["item"]) for v in value
|
121
|
+
]
|
122
|
+
if origin is Union or origin is UnionType:
|
123
|
+
for arg in get_args(annotation):
|
124
|
+
try:
|
125
|
+
return _undo_non_string_literal(arg, value, literal_map[arg.__name__])
|
126
|
+
except Exception:
|
127
|
+
pass
|
128
|
+
return value
|
129
|
+
|
130
|
+
|
131
|
+
def wrap_tool(tool: Callable):
    """
    Wrap tool to make it compatible with LLM libraries
    - Gemini does not accept non-None default values
      If there are any default args, we set default value to None
      and inject the correct default value at runtime.
    - Gemini does not accept non-string Literals
      We convert non-string Literals to strings and reset this at runtime

    Args:
        tool: A sync or async callable. A callable that already carries a
            truthy ``_wrapped`` attribute is returned unchanged.

    Returns:
        A function with the rewritten signature, the original ``__name__``
        and docstring, and ``_wrapped = True``. It is a coroutine function
        when ``tool`` is one.
    """
    if getattr(tool, "_wrapped", False):
        return tool

    # Build the LLM-facing signature from the original one
    original_signature = inspect.signature(tool)
    new_args = []
    literal_maps = {}
    for arg in original_signature.parameters.values():
        # Handle non-string Literals
        new_annotation, literal_map = _handle_non_string_literal(arg.annotation)
        literal_maps[arg.name] = literal_map
        new_arg = arg.replace(
            kind=Parameter.POSITIONAL_OR_KEYWORD,
            annotation=new_annotation,
        )

        # Handle defaults
        argtype = new_arg.annotation
        if new_arg.default is Parameter.empty:
            # If it's annotated with Optional or Union[None, X]
            # remove the Optional tag since no default value is supplied
            origin = get_origin(argtype)
            if (origin in [Union, UnionType]) and NoneType in get_args(argtype):
                argtype_args = [a for a in get_args(argtype) if a != NoneType]
                new_annotation = argtype_args[0]
                for child_argtype in argtype_args[1:]:
                    new_annotation = new_annotation | child_argtype
                new_arg = new_arg.replace(
                    kind=Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=new_annotation,
                )
        else:
            # Process args with default values: make sure type includes None
            new_annotation = argtype
            if new_annotation is Parameter.empty:
                new_annotation = Optional[type(new_arg.default)]
            origin = get_origin(new_annotation)
            if origin not in [Union, UnionType] or NoneType not in get_args(
                new_annotation
            ):
                new_annotation = Optional[new_annotation]
            new_arg = new_arg.replace(
                default=None,
                kind=Parameter.POSITIONAL_OR_KEYWORD,
                annotation=new_annotation,
            )
        new_args.append(new_arg)
    new_sig = original_signature.replace(parameters=new_args)

    def _bind_runtime_args(args, kwargs):
        # Shared by the sync and async wrappers so both treat arguments the
        # same way (previously only the sync wrapper restored defaults).
        bound_args = original_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Inject default values: the exposed signature uses None as the
        # placeholder default, so a None for a defaulted parameter is
        # replaced by the tool's real default.
        for k, v in list(bound_args.arguments.items()):
            param = original_signature.parameters[k]
            if v is None and param.default is not Parameter.empty:
                bound_args.arguments[k] = param.default

        _cast_bound_args(bound_args)

        # Inject correct Literals
        for k, v in list(bound_args.arguments.items()):
            if k in literal_maps:
                param = original_signature.parameters[k]
                bound_args.arguments[k] = _undo_non_string_literal(
                    param.annotation, v, literal_maps[k]
                )
        return bound_args

    if inspect.iscoroutinefunction(tool):

        async def wrapper(*args, **kwargs):
            bound_args = _bind_runtime_args(args, kwargs)
            return await tool(*bound_args.args, **bound_args.kwargs)

    else:

        def wrapper(*args, **kwargs):
            bound_args = _bind_runtime_args(args, kwargs)
            return tool(*bound_args.args, **bound_args.kwargs)

    wrapped = create_function(
        new_sig,
        wrapper,
        qualname=tool.__name__,
        doc=inspect.getdoc(tool),
    )

    wrapped.__name__ = tool.__name__
    wrapped._wrapped = True

    return wrapped
|
241
|
+
|
242
|
+
|
243
|
+
class BaseIndex:
    """A collection of wrapped tools that can be looked up by name and run."""

    def __init__(self, tools: list[Callable]):
        """Wrap and store ``tools`` after checking their names for duplicates."""
        check_duplicates([t.__name__ for t in tools])
        self.tools = [wrap_tool(t) for t in tools]

    @property
    def tools_dict(self):
        """Mapping of each tool's ``__name__`` to the tool, rebuilt on access."""
        return {tool.__name__: tool for tool in self.tools}

    def execute(self, toolname: str, kwargs: dict | None = None):
        """Run the single tool whose name matches ``toolname``.

        Args:
            toolname: Tool name; ``-`` and ``.`` are treated as
                interchangeable when matching.
            kwargs: Keyword arguments for the tool (``None`` means none).

        Returns:
            Whatever the tool returns. Coroutine tools are run to completion
            on a newly created event loop.

        Raises:
            ValueError: If no tool, or more than one tool, matches.
        """
        kwargs = kwargs or {}

        # Use regex since we need to match cases where we perform
        # substitutions such as replace(".", "-")
        # Each name segment is escaped so regex metacharacters in `toolname`
        # are matched literally instead of raising re.error or mis-matching.
        segments = [re.escape(part) for part in re.split("-|\\.", toolname)]
        pattern = re.compile(":?" + "(-|\\.)".join(segments) + "$")

        tools_dict = self.tools_dict
        matching_tools = [key for key in tools_dict if pattern.match(key)]
        if len(matching_tools) == 0:
            raise ValueError(f"No tool matching '{toolname}'")
        elif len(matching_tools) > 1:
            raise ValueError(f"'{toolname}' matches multiple tools - {matching_tools}")

        tool = tools_dict[matching_tools[0]]
        if inspect.iscoroutinefunction(tool):
            # NOTE(review): the loop is installed as the current event loop
            # and is not closed afterwards - confirm this is intended.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(tool(**kwargs))
        else:
            return tool(**kwargs)

    def parse_and_execute(self, msg: str):
        """Parse ``toolname`` and ``kwargs`` out of ``msg`` and execute them."""
        toolcall = llm_parse_json(msg, keys=["toolname", "kwargs"])
        return self.execute(toolcall.get("toolname"), toolcall.get("kwargs"))

    def format_tools(self, provider: ProviderFormat):
        """Return this index's tools formatted for ``provider``."""
        return format_tools(self.tools, provider)
|
stores/indexes/index.py
ADDED
@@ -0,0 +1,56 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Callable
|
5
|
+
|
6
|
+
from stores.indexes.base_index import BaseIndex
|
7
|
+
from stores.indexes.local_index import LocalIndex
|
8
|
+
from stores.indexes.remote_index import RemoteIndex
|
9
|
+
|
10
|
+
logging.basicConfig()
|
11
|
+
logger = logging.getLogger("stores.index")
|
12
|
+
logger.setLevel(logging.INFO)
|
13
|
+
|
14
|
+
|
15
|
+
class Index(BaseIndex):
    """Index assembled from plain callables and/or named indexes.

    String or ``Path`` entries are resolved to a ``LocalIndex`` when the path
    exists, otherwise (for strings) to a ``RemoteIndex``; callables are used
    as they are. Entries of any other kind are ignored.
    """

    def __init__(
        self,
        tools: list[Callable, os.PathLike] | None = None,
        env_var: dict[str, dict] | None = None,
    ):
        """Collect tools from every entry and hand them to ``BaseIndex``.

        Args:
            tools: Callables and/or index names or paths.
            env_var: Per-index settings keyed by index name; the entry for a
                name is passed to ``RemoteIndex`` as ``env_var``.

        Raises:
            ValueError: If a named index cannot be loaded.
        """
        self.env_var = env_var or {}

        collected = []
        for entry in tools or []:
            if isinstance(entry, (str, Path)):
                collected.extend(self._load_index(entry).tools)
            elif isinstance(entry, Callable):
                collected.append(entry)

        super().__init__(collected)

    def _load_index(self, index_name):
        """Resolve ``index_name`` to a loaded local or remote index."""
        loaded = None

        # Load LocalIndex
        if Path(index_name).exists():
            try:
                loaded = LocalIndex(index_name)
            except Exception:
                logger.warning(f'Unable to load index "{index_name}"', exc_info=True)

        # Load RemoteIndex (only string names can refer to one)
        if loaded is None and isinstance(index_name, str):
            try:
                loaded = RemoteIndex(index_name, env_var=self.env_var.get(index_name))
            except Exception:
                logger.warning(
                    f'Unable to load index "{index_name}"\nIf this is a local index, make sure it can be found as a directory and contains a tools.toml file.',
                    exc_info=True,
                )

        if loaded is None:
            raise ValueError(
                f'Unable to load index "{index_name}"\nIf this is a local index, make sure it can be found as a directory and contains a tools.toml file.'
            )
        return loaded
|
@@ -0,0 +1,84 @@
|
|
1
|
+
import importlib
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import sys
|
5
|
+
import venv
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
from stores.constants import TOOLS_CONFIG_FILENAME, VENV_NAME
|
9
|
+
from stores.indexes.base_index import BaseIndex
|
10
|
+
from stores.indexes.venv_utils import init_venv_tools, install_venv_deps
|
11
|
+
|
12
|
+
if sys.version_info >= (3, 11):
|
13
|
+
import tomllib
|
14
|
+
else:
|
15
|
+
import tomli as tomllib
|
16
|
+
|
17
|
+
logging.basicConfig()
|
18
|
+
logger = logging.getLogger("stores.indexes.local_index")
|
19
|
+
logger.setLevel(logging.INFO)
|
20
|
+
|
21
|
+
|
22
|
+
class LocalIndex(BaseIndex):
    """Index whose tools are loaded from a folder on disk."""

    def __init__(
        self,
        index_folder: os.PathLike,
        create_venv: bool = False,
        env_var: dict | None = None,
    ):
        """Load the tools found in ``index_folder``.

        Args:
            index_folder: Folder containing the index.
            create_venv: When true, create a virtual environment inside the
                folder (if it does not exist yet), install dependencies and
                initialise the tools through the venv helpers.
            env_var: Environment variables for the venv tools; only accepted
                together with ``create_venv=True``.

        Raises:
            ValueError: If the folder does not exist, or if ``env_var`` is
                given while ``create_venv`` is false.
        """
        self.index_folder = Path(index_folder)
        self.env_var = env_var or {}

        if not self.index_folder.exists():
            raise ValueError(
                f"Unable to load index - {self.index_folder} does not exist"
            )

        if create_venv:
            # Create venv and install deps
            self.venv = self.index_folder / VENV_NAME
            if not self.venv.exists():
                venv.create(self.venv, symlinks=True, with_pip=True, upgrade_deps=True)
                install_venv_deps(self.index_folder)
            # Initialize tools
            tools = init_venv_tools(self.index_folder, self.env_var)
        else:
            if self.env_var:
                raise ValueError(
                    "Environment variables will only be restricted if create_venv=True when initializing LocalIndex"
                )
            tools = self._init_tools()
        super().__init__(tools)

    def _init_tools(self):
        """
        Load local tools.toml file and import tool functions

        NOTE: Can we just add index_folder to sys.path and import the functions?

        Returns:
            The imported tool callables, each with ``__name__`` set to its
            dotted tool id from the manifest.

        Raises:
            ValueError: If the manifest file does not exist.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded; import it explicitly before using it.
        import importlib.util

        index_manifest = self.index_folder / TOOLS_CONFIG_FILENAME
        if not index_manifest.exists():
            raise ValueError(f"Unable to load index - {index_manifest} does not exist")

        with open(index_manifest, "rb") as file:
            manifest = tomllib.load(file)["index"]

        tools = []
        for tool_id in manifest.get("tools", []):
            # "pkg.module.func" -> module "pkg.module", attribute "func"
            module_name, _, tool_name = tool_id.rpartition(".")

            # A package directory is loaded through its __init__.py,
            # anything else is treated as a plain "<module>.py" file.
            module_file = self.index_folder / module_name.replace(".", "/")
            if (module_file / "__init__.py").exists():
                module_file = module_file / "__init__.py"
            else:
                module_file = Path(str(module_file) + ".py")

            spec = importlib.util.spec_from_file_location(module_name, module_file)
            module = importlib.util.module_from_spec(spec)
            # Registered before execution so the module is importable by
            # name while its body runs.
            sys.modules[spec.name] = module
            spec.loader.exec_module(module)
            tool = getattr(module, tool_name)
            tool.__name__ = tool_id
            tools.append(tool)
        return tools
|