arcade-core 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arcade_core/__init__.py +2 -0
- arcade_core/annotations.py +8 -0
- arcade_core/auth.py +177 -0
- arcade_core/catalog.py +894 -0
- arcade_core/config.py +23 -0
- arcade_core/config_model.py +146 -0
- arcade_core/errors.py +103 -0
- arcade_core/executor.py +129 -0
- arcade_core/output.py +64 -0
- arcade_core/parse.py +63 -0
- arcade_core/py.typed +0 -0
- arcade_core/schema.py +441 -0
- arcade_core/telemetry.py +130 -0
- arcade_core/toolkit.py +155 -0
- arcade_core/utils.py +99 -0
- arcade_core/version.py +1 -0
- arcade_core-2.0.0.dist-info/METADATA +77 -0
- arcade_core-2.0.0.dist-info/RECORD +19 -0
- arcade_core-2.0.0.dist-info/WHEEL +4 -0
arcade_core/catalog.py
ADDED
|
@@ -0,0 +1,894 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import typing
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from importlib import import_module
|
|
12
|
+
from types import ModuleType
|
|
13
|
+
from typing import (
|
|
14
|
+
Annotated,
|
|
15
|
+
Any,
|
|
16
|
+
Callable,
|
|
17
|
+
Literal,
|
|
18
|
+
Optional,
|
|
19
|
+
Union,
|
|
20
|
+
cast,
|
|
21
|
+
get_args,
|
|
22
|
+
get_origin,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from pydantic import BaseModel, Field, create_model
|
|
26
|
+
from pydantic.fields import FieldInfo
|
|
27
|
+
from pydantic_core import PydanticUndefined
|
|
28
|
+
|
|
29
|
+
from arcade_core.annotations import Inferrable
|
|
30
|
+
from arcade_core.auth import OAuth2, ToolAuthorization
|
|
31
|
+
from arcade_core.errors import ToolDefinitionError
|
|
32
|
+
from arcade_core.schema import (
|
|
33
|
+
TOOL_NAME_SEPARATOR,
|
|
34
|
+
FullyQualifiedName,
|
|
35
|
+
InputParameter,
|
|
36
|
+
OAuth2Requirement,
|
|
37
|
+
ToolAuthRequirement,
|
|
38
|
+
ToolContext,
|
|
39
|
+
ToolDefinition,
|
|
40
|
+
ToolInput,
|
|
41
|
+
ToolkitDefinition,
|
|
42
|
+
ToolMetadataKey,
|
|
43
|
+
ToolMetadataRequirement,
|
|
44
|
+
ToolOutput,
|
|
45
|
+
ToolRequirements,
|
|
46
|
+
ToolSecretRequirement,
|
|
47
|
+
ValueSchema,
|
|
48
|
+
)
|
|
49
|
+
from arcade_core.toolkit import Toolkit
|
|
50
|
+
from arcade_core.utils import (
|
|
51
|
+
does_function_return_value,
|
|
52
|
+
first_or_none,
|
|
53
|
+
is_strict_optional,
|
|
54
|
+
is_string_literal,
|
|
55
|
+
is_union,
|
|
56
|
+
snake_to_pascal_case,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
logger: logging.Logger = logging.getLogger(__name__)

# Scalar wire (serialization) types: what a value -- or each element of an array -- is sent as.
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
# Every wire type: one of the scalar types above, or "array" (elements described by an InnerWireType).
WireType = Union[InnerWireType, Literal["array"]]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class WireTypeInfo:
    """
    Describes how a value travels over the wire: its wire type, the element wire type
    when the value is a list, and the permitted values when it is enumerable.
    """

    # Outer wire type; "array" for list values.
    wire_type: WireType
    # Wire type of each list element; left as None for non-list values.
    inner_wire_type: Optional[InnerWireType] = None
    # Allowed values for string-Literal or Enum types; None when the type is not enumerable.
    enum_values: Optional[list[str]] = None
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ToolMeta(BaseModel):
    """
    Provenance recorded for a tool at the moment it is materialized into the catalog.
    """

    # Name of the module the tool function was taken from.
    module: str
    # Name of the toolkit the tool was registered under, when known.
    toolkit: str | None = None
    # Package name of the toolkit, when the tool came from a Toolkit object.
    package: str | None = None
    # Filesystem path of the source module, when known.
    path: str | None = None
    # Both timestamps default to datetime.now() (naive local time) at construction.
    date_added: datetime = Field(default_factory=datetime.now)
    date_updated: datetime = Field(default_factory=datetime.now)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class MaterializedTool(BaseModel):
    """
    A catalog entry: the tool callable together with its definition, provenance
    metadata, and the Pydantic models generated for its inputs and outputs.
    """

    # The callable that implements the tool.
    tool: Callable
    # Declarative description of the tool (name, toolkit, inputs, output, requirements).
    definition: ToolDefinition
    # Where the tool came from, plus timestamps.
    meta: ToolMeta

    # Thought (Sam): Should generate create these from ToolDefinition?
    # Models built from the function signature (see create_func_models).
    input_model: type[BaseModel]
    output_model: type[BaseModel]

    @property
    def name(self) -> str:
        """Tool name taken from the definition."""
        definition = self.definition
        return definition.name

    @property
    def version(self) -> str | None:
        """Version of the toolkit this tool belongs to, if any."""
        toolkit_definition = self.definition.toolkit
        return toolkit_definition.version

    @property
    def description(self) -> str:
        """Tool description taken from the definition."""
        definition = self.definition
        return definition.description

    @property
    def requires_auth(self) -> bool:
        """True when the definition declares an authorization requirement."""
        authorization = self.definition.requirements.authorization
        return authorization is not None
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ToolCatalog(BaseModel):
    """Registry that holds all tools for a given worker.

    Tools are keyed by their ``FullyQualifiedName``. Tools and toolkits listed in the
    ``ARCADE_DISABLED_TOOLS`` / ``ARCADE_DISABLED_TOOLKITS`` environment variables
    (read when the catalog is constructed) are skipped when added.
    """

    _tools: dict[FullyQualifiedName, MaterializedTool] = {}

    # Lower-cased entries parsed from the environment in __init__.
    _disabled_tools: set[str] = set()
    _disabled_toolkits: set[str] = set()

    def __init__(self, **data) -> None:  # type: ignore[no-untyped-def]
        super().__init__(**data)
        self._load_disabled_tools()
        self._load_disabled_toolkits()

    def _load_disabled_tools(self) -> None:
        """Load disabled tools from the ARCADE_DISABLED_TOOLS environment variable.

        The variable holds a comma-separated list of tools to exclude from the catalog.
        The expected format for each disabled tool is:
            - [CamelCaseToolkitName][TOOL_NAME_SEPARATOR][CamelCaseToolName]

        Whitespace around entries and empty entries are ignored. Malformed entries are
        skipped with a warning. Entries are stored lower-cased (case-insensitive match).
        """
        pattern = re.compile(rf"^[a-zA-Z]+{re.escape(TOOL_NAME_SEPARATOR)}[a-zA-Z]+$")

        for raw_entry in os.getenv("ARCADE_DISABLED_TOOLS", "").split(","):
            tool = raw_entry.strip()
            if not tool:
                continue
            if not pattern.match(tool):
                logger.warning(
                    "Ignoring malformed ARCADE_DISABLED_TOOLS entry '%s' "
                    "(expected ToolkitName%sToolName).",
                    tool,
                    TOOL_NAME_SEPARATOR,
                )
                continue
            self._disabled_tools.add(tool.lower())

    def _load_disabled_toolkits(self) -> None:
        """Load disabled toolkits from the ARCADE_DISABLED_TOOLKITS environment variable.

        The variable holds a comma-separated list of toolkits to exclude from the catalog.
        The expected format for each disabled toolkit is:
            - [CamelCaseToolkitName]

        Whitespace around entries and empty entries are ignored. Entries are stored
        lower-cased (case-insensitive match).
        """
        for raw_entry in os.getenv("ARCADE_DISABLED_TOOLKITS", "").split(","):
            toolkit = raw_entry.strip()
            if toolkit:
                self._disabled_toolkits.add(toolkit.lower())

    def _is_toolkit_disabled(self, toolkit_name: str) -> bool:
        """Return True if toolkit_name matches an ARCADE_DISABLED_TOOLKITS entry.

        Entries are documented as CamelCase, but toolkit names may be snake_case, so both
        the raw name and its PascalCase form are compared case-insensitively.
        """
        raw_name = str(toolkit_name)
        candidates = {raw_name.lower(), snake_to_pascal_case(raw_name).lower()}
        return not candidates.isdisjoint(self._disabled_toolkits)

    def add_tool(
        self,
        tool_func: Callable,
        toolkit_or_name: Union[str, Toolkit],
        module: ModuleType | None = None,
    ) -> None:
        """Add a function to the catalog as a tool.

        Args:
            tool_func: The tool function to catalog.
            toolkit_or_name: The Toolkit the tool belongs to, or only the toolkit's name.
            module: The module the function was loaded from, if known. It is used for the
                tool's metadata (module name and file path).

        Raises:
            ValueError: If neither a toolkit nor a non-empty toolkit name is provided.
            KeyError: If a tool with the same fully-qualified name is already cataloged.
            ToolDefinitionError: If the function is not a valid tool definition.
        """
        input_model, output_model = create_func_models(tool_func)

        # Start from "nothing provided" so an unexpected argument (e.g. None) reaches the
        # ValueError below instead of failing with an UnboundLocalError.
        toolkit: Toolkit | None = None
        toolkit_name: str | None = None
        if isinstance(toolkit_or_name, Toolkit):
            toolkit = toolkit_or_name
            toolkit_name = toolkit.name
        elif isinstance(toolkit_or_name, str):
            toolkit_name = toolkit_or_name

        if not toolkit_name:
            raise ValueError("A toolkit name or toolkit must be provided.")

        definition = ToolCatalog.create_tool_definition(
            tool_func,
            toolkit_name,
            toolkit.version if toolkit else None,
            toolkit.description if toolkit else None,
        )

        fully_qualified_name = definition.get_fully_qualified_name()

        if fully_qualified_name in self._tools:
            raise KeyError(f"Tool '{definition.name}' already exists in the catalog.")

        if str(fully_qualified_name).lower() in self._disabled_tools:
            logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.")
            return

        if self._is_toolkit_disabled(toolkit_name):
            logger.info(f"Toolkit '{toolkit_name!s}' is disabled and will not be cataloged.")
            return

        self._tools[fully_qualified_name] = MaterializedTool(
            definition=definition,
            tool=tool_func,
            meta=ToolMeta(
                module=module.__name__ if module else tool_func.__module__,
                toolkit=toolkit_name,
                package=toolkit.package_name if toolkit else None,
                path=module.__file__ if module else None,
            ),
            input_model=input_model,
            output_model=output_model,
        )

    def add_module(self, module: ModuleType) -> None:
        """
        Add all the tools in a module to the catalog.
        """
        toolkit = Toolkit.from_module(module)
        self.add_toolkit(toolkit)

    def add_toolkit(self, toolkit: Toolkit) -> None:
        """Add the tools from a loaded toolkit to the catalog.

        ``toolkit.tools`` maps module names to the names of tool functions in them. Each
        module is imported, and each function is looked up and added with ``add_tool``.

        Raises:
            ToolDefinitionError: If a module cannot be imported, a tool function cannot be
                found, or adding a tool fails for any other reason. The original
                exception is chained as ``__cause__``.
        """
        if self._is_toolkit_disabled(toolkit.name):
            logger.info(f"Toolkit '{toolkit.name!s}' is disabled and will not be cataloged.")
            return

        for module_name, tool_names in toolkit.tools.items():
            for tool_name in tool_names:
                try:
                    module = import_module(module_name)
                    tool_func = getattr(module, tool_name)
                    self.add_tool(tool_func, toolkit, module)

                except AttributeError as e:
                    raise ToolDefinitionError(
                        f"Could not import tool {tool_name} in module {module_name}. Reason: {e}"
                    ) from e
                except ImportError as e:
                    raise ToolDefinitionError(
                        f"Could not import module {module_name}. Reason: {e}"
                    ) from e
                except TypeError as e:
                    raise ToolDefinitionError(
                        f"Type error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
                    ) from e
                except Exception as e:
                    raise ToolDefinitionError(
                        f"Error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
                    ) from e

    def __getitem__(self, name: FullyQualifiedName) -> MaterializedTool:
        return self.get_tool(name)

    def __contains__(self, name: FullyQualifiedName) -> bool:
        return name in self._tools

    def __iter__(self) -> Iterator[MaterializedTool]:  # type: ignore[override]
        yield from self._tools.values()

    def __len__(self) -> int:
        return len(self._tools)

    def is_empty(self) -> bool:
        return not self._tools

    def get_tool_names(self) -> list[FullyQualifiedName]:
        return [tool.definition.get_fully_qualified_name() for tool in self._tools.values()]

    def find_tool_by_func(self, func: Callable) -> ToolDefinition:
        """Find a tool's definition by its function.

        Raises:
            ValueError: If no cataloged tool wraps ``func``.
        """
        for tool in self._tools.values():
            if tool.tool == func:
                return tool.definition
        raise ValueError(f"Tool {func} not found in the catalog.")

    def get_tool_by_name(
        self, name: str, version: Optional[str] = None, separator: str = TOOL_NAME_SEPARATOR
    ) -> MaterializedTool:
        """Get a tool from the catalog by name.

        Args:
            name: The name of the tool, potentially including the toolkit name separated by the `separator`.
            version: The version of the toolkit. Defaults to None.
            separator: The separator between toolkit and tool names. Defaults to `TOOL_NAME_SEPARATOR`.

        Returns:
            MaterializedTool: The matching tool from the catalog.

        Raises:
            ValueError: If the tool is not found in the catalog.
        """
        if separator in name:
            toolkit_name, tool_name = name.split(separator, 1)
            fq_name = FullyQualifiedName(
                name=tool_name, toolkit_name=toolkit_name, toolkit_version=version
            )
            return self.get_tool(fq_name)

        # No toolkit name provided: return the first tool whose name matches
        # (case-insensitively) and, when a version is given, whose toolkit version matches.
        wanted_name = name.lower()
        wanted_version = version.lower() if version is not None else None
        found = next(
            (
                tool
                for fq_name, tool in self._tools.items()
                if fq_name.name.lower() == wanted_name
                and (
                    wanted_version is None
                    or (fq_name.toolkit_version or "").lower() == wanted_version
                )
            ),
            None,
        )
        if found is not None:
            return found

        raise ValueError(f"Tool {name} not found in the catalog.")

    def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
        """Get a tool from the catalog by fully-qualified name.

        If ``name.toolkit_version`` is set, only an exact match is returned. Otherwise the
        first tool whose name matches while ignoring the version is returned.

        Raises:
            ValueError: If no matching tool is found.
        """
        if name.toolkit_version:
            try:
                return self._tools[name]
            except KeyError:
                raise ValueError(
                    f"Tool {name}@{name.toolkit_version} not found in the catalog."
                ) from None

        for key, tool in self._tools.items():
            if key.equals_ignoring_version(name):
                return tool

        raise ValueError(f"Tool {name} not found.")

    def get_tool_count(self) -> int:
        """
        Get the number of tools in the catalog.
        """
        return len(self._tools)

    @staticmethod
    def create_tool_definition(
        tool: Callable,
        toolkit_name: str,
        toolkit_version: Optional[str] = None,
        toolkit_desc: Optional[str] = None,
    ) -> ToolDefinition:
        """Given a tool function, create a ToolDefinition.

        It reads the ``__tool_name__``, ``__tool_description__`` and
        ``__tool_deprecation_message__`` attributes, plus the requirement attributes, from
        the function (presumably set by the tool decorator). The tool and toolkit names are
        converted to PascalCase.

        Raises:
            ToolDefinitionError: If the tool has no description, if it returns a value
                without a return type annotation, or if any input, output or requirement
                is invalid.
        """
        raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)

        # Hard requirement: tools must have descriptions
        tool_description = getattr(tool, "__tool_description__", None)
        if not tool_description:
            raise ToolDefinitionError(f"Tool {raw_tool_name} is missing a description")

        # If the function returns a value, it must have a type annotation
        if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
            raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")

        auth_requirement = create_auth_requirement(tool)
        secrets_requirement = create_secrets_requirement(tool)
        metadata_requirement = create_metadata_requirement(tool, auth_requirement)

        toolkit_definition = ToolkitDefinition(
            name=snake_to_pascal_case(toolkit_name),
            description=toolkit_desc,
            version=toolkit_version,
        )

        tool_name = snake_to_pascal_case(raw_tool_name)
        fully_qualified_name = FullyQualifiedName.from_toolkit(tool_name, toolkit_definition)
        deprecation_message = getattr(tool, "__tool_deprecation_message__", None)

        return ToolDefinition(
            name=tool_name,
            fully_qualified_name=str(fully_qualified_name),
            description=tool_description,
            toolkit=toolkit_definition,
            input=create_input_definition(tool),
            output=create_output_definition(tool),
            requirements=ToolRequirements(
                authorization=auth_requirement,
                secrets=secrets_requirement,
                metadata=metadata_requirement,
            ),
            deprecation_message=deprecation_message,
        )
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def create_input_definition(func: Callable) -> ToolInput:
    """
    Build the ToolInput schema describing a tool function's parameters.

    Every parameter except a ToolContext-typed one becomes an InputParameter. The name of
    the (at most one) ToolContext parameter is recorded separately.

    Raises:
        ToolDefinitionError: If more than one ToolContext parameter is declared, or if a
            parameter fails validation in extract_field_info.
    """
    signature = inspect.signature(func, follow_wrapped=True)
    parameters = []
    context_param_name: str | None = None

    for param in signature.parameters.values():
        if param.annotation is ToolContext:
            if context_param_name is not None:
                raise ToolDefinitionError(
                    f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple."
                )
            # Remember the context parameter, but never expose it as a tool input.
            context_param_name = param.name
            continue

        info = extract_field_info(param)
        wire_info = info.wire_type_info
        # Required only when the parameter is neither Optional nor given a non-None default.
        required = not (info.is_optional or info.default is not None)

        parameters.append(
            InputParameter(
                name=info.name,
                description=info.description,
                required=required,
                inferrable=info.is_inferrable,
                value_schema=ValueSchema(
                    val_type=wire_info.wire_type,
                    inner_val_type=wire_info.inner_wire_type,
                    enum=wire_info.enum_values,
                ),
            )
        )

    return ToolInput(parameters=parameters, tool_context_parameter_name=context_param_name)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def create_output_definition(func: Callable) -> ToolOutput:
    """
    Create an output model for a function based on its return annotation.

    - No return annotation, or ``-> None``: the tool produces no value, so only the
      "null" mode is available.
    - ``Annotated[T, "description", ...]``: the first string metadata entry becomes the
      output description, and ``T`` becomes the return type.
    - ``Optional[T]`` / ``T | None``: unwrapped to ``T``, and "null" is added to the
      available modes.

    Raises:
        ToolDefinitionError: If the return type has no wire representation
            (raised by get_wire_type_info).
    """
    default_description = "No description provided."
    return_type = inspect.signature(func, follow_wrapped=True).return_annotation

    if return_type is inspect.Signature.empty:
        return ToolOutput(
            value_schema=None,
            description=default_description,
            available_modes=["null"],
        )

    description: str | None = default_description
    if hasattr(return_type, "__metadata__"):
        # As with parameter annotations, only string metadata is treated as a description.
        description = next((m for m in return_type.__metadata__ if isinstance(m, str)), None)
        return_type = return_type.__origin__

    # `-> None` is recorded as the value None rather than NoneType. It has no wire type,
    # so report "no value" instead of failing with "Unsupported parameter type: None".
    if return_type is None or return_type is type(None):
        return ToolOutput(
            value_schema=None,
            description=description,
            available_modes=["null"],
        )

    # Unwrap Optional types
    # Both Optional[T] and T | None are supported
    is_optional = is_strict_optional(return_type)
    if is_optional:
        return_type = next(arg for arg in get_args(return_type) if arg is not type(None))

    wire_type_info = get_wire_type_info(return_type)

    available_modes = ["value", "error"]
    if is_optional:
        available_modes.append("null")

    return ToolOutput(
        description=description,
        available_modes=available_modes,
        value_schema=ValueSchema(
            val_type=wire_type_info.wire_type,
            inner_val_type=wire_type_info.inner_wire_type,
            enum=wire_type_info.enum_values,
        ),
    )
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None:
    """
    Translate the authorization declared on a tool into a ToolAuthRequirement.

    Reads ``__tool_requires_auth__`` from the tool. A ToolAuthorization is converted into
    a ToolAuthRequirement, with the OAuth2 details attached when it is an OAuth2 instance.
    Any other value (including the ``None`` default) is returned unchanged.
    """
    declared = getattr(tool, "__tool_requires_auth__", None)
    if not isinstance(declared, ToolAuthorization):
        return declared

    requirement = ToolAuthRequirement(
        provider_id=declared.provider_id,
        provider_type=declared.provider_type,
        id=declared.id,
    )
    if isinstance(declared, OAuth2):
        requirement.oauth2 = OAuth2Requirement(**declared.model_dump())
    return requirement
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None:
    """
    Build the secret requirements declared on a tool.

    Reads ``__tool_requires_secrets__`` from the tool. A list of string keys is validated
    and converted with to_tool_secret_requirements. Any other value (including the
    ``None`` default) is returned unchanged.

    Raises:
        ToolDefinitionError: If any declared key is not a string, or if a converted
            requirement has a missing or blank key.
    """
    tool_label = getattr(tool, "__tool_name__", tool.__name__)
    declared = getattr(tool, "__tool_requires_secrets__", None)
    if not isinstance(declared, list):
        return declared

    if not all(isinstance(key, str) for key in declared):
        raise ToolDefinitionError(
            f"Secret keys must be strings (error in tool {tool_label})."
        )

    requirements = to_tool_secret_requirements(declared)
    for requirement in requirements:
        if requirement.key is None or requirement.key.strip() == "":
            raise ToolDefinitionError(
                f"Secrets must have a non-empty key (error in tool {tool_label})."
            )

    return requirements
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def create_metadata_requirement(
    tool: Callable, auth_requirement: ToolAuthRequirement | None
) -> list[ToolMetadataRequirement] | None:
    """
    Build the metadata requirements declared on a tool.

    Reads ``__tool_requires_metadata__`` from the tool. A list of string keys is validated
    and converted with to_tool_metadata_requirements. Any other value (including the
    ``None`` default) is returned unchanged.

    Raises:
        ToolDefinitionError: In any of these cases:
            - a declared key is not a string;
            - a key needs authorization (per ToolMetadataKey.requires_auth) but
              auth_requirement is None;
            - a converted requirement has a missing or blank key.
    """
    tool_label = getattr(tool, "__tool_name__", tool.__name__)
    declared = getattr(tool, "__tool_requires_metadata__", None)
    if not isinstance(declared, list):
        return declared

    # Checked key by key, so the first offending key determines which error is raised.
    for key in declared:
        if not isinstance(key, str):
            raise ToolDefinitionError(
                f"Metadata must be strings (error in tool {tool_label})."
            )
        if ToolMetadataKey.requires_auth(key) and auth_requirement is None:
            raise ToolDefinitionError(
                f"Tool {tool_label} declares metadata key '{key}', "
                "which requires that the tool has an auth requirement, "
                "but no auth requirement was provided. Please specify an auth requirement."
            )

    requirements = to_tool_metadata_requirements(declared)
    for requirement in requirements:
        if requirement.key is None or requirement.key.strip() == "":
            raise ToolDefinitionError(
                f"Metadata must have a non-empty key (error in tool {tool_label})."
            )

    return requirements
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
@dataclass
class ParamInfo:
    """
    Raw facts about a single function parameter, gathered by inspecting its signature.
    """

    # Parameter name; may later be replaced by a name given in Annotated[] metadata.
    name: str
    # Default value; None when the parameter declares no default.
    default: Any
    # The annotation with any Annotated[] wrapper removed.
    original_type: type
    # original_type with an Optional[T] / T | None wrapper removed.
    field_type: type
    # Human-readable description, when one has been found.
    description: Optional[str] = None
    # Whether the annotation was Optional[T] / T | None.
    is_optional: bool = True
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
@dataclass
class ToolParamInfo:
    """
    Everything known about a tool parameter: the inspected ParamInfo facts plus the
    computed wire-type information and inferrability flag.
    """

    # Fields mirrored from ParamInfo.
    name: str
    default: Any
    original_type: type
    field_type: type
    # Serialization details computed for field_type.
    wire_type_info: WireTypeInfo
    description: str | None = None
    is_optional: bool = True
    # Whether the parameter is marked inferrable (True unless an Inferrable marker says otherwise).
    is_inferrable: bool = True

    @classmethod
    def from_param_info(
        cls,
        param_info: ParamInfo,
        wire_type_info: WireTypeInfo,
        is_inferrable: bool = True,
    ) -> "ToolParamInfo":
        """Combine inspected ParamInfo with computed wire-type info and inferrability."""
        source = param_info
        return cls(
            name=source.name,
            default=source.default,
            original_type=source.original_type,
            field_type=source.field_type,
            wire_type_info=wire_type_info,
            description=source.description,
            is_optional=source.is_optional,
            is_inferrable=is_inferrable,
        )
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
|
|
609
|
+
"""
|
|
610
|
+
Extract type and field parameters from a function parameter.
|
|
611
|
+
"""
|
|
612
|
+
annotation = param.annotation
|
|
613
|
+
if annotation == inspect.Parameter.empty:
|
|
614
|
+
raise ToolDefinitionError(f"Parameter {param} has no type annotation.")
|
|
615
|
+
|
|
616
|
+
# Get the majority of the param info from either the Pydantic Field() or regular inspection
|
|
617
|
+
if isinstance(param.default, FieldInfo):
|
|
618
|
+
param_info = extract_pydantic_param_info(param)
|
|
619
|
+
else:
|
|
620
|
+
param_info = extract_python_param_info(param)
|
|
621
|
+
|
|
622
|
+
metadata = getattr(annotation, "__metadata__", [])
|
|
623
|
+
str_annotations = [m for m in metadata if isinstance(m, str)]
|
|
624
|
+
|
|
625
|
+
# Get the description from annotations, if present
|
|
626
|
+
if len(str_annotations) == 0:
|
|
627
|
+
pass
|
|
628
|
+
elif len(str_annotations) == 1:
|
|
629
|
+
param_info.description = str_annotations[0]
|
|
630
|
+
elif len(str_annotations) == 2:
|
|
631
|
+
new_name = str_annotations[0]
|
|
632
|
+
if not new_name.isidentifier():
|
|
633
|
+
raise ToolDefinitionError(
|
|
634
|
+
f"Invalid parameter name: '{new_name}' is not a valid identifier. "
|
|
635
|
+
"Identifiers must start with a letter or underscore, "
|
|
636
|
+
"and can only contain letters, digits, or underscores."
|
|
637
|
+
)
|
|
638
|
+
param_info.name = new_name
|
|
639
|
+
param_info.description = str_annotations[1]
|
|
640
|
+
else:
|
|
641
|
+
raise ToolDefinitionError(
|
|
642
|
+
f"Parameter {param} has too many string annotations. Expected 0, 1, or 2, got {len(str_annotations)}."
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# Get the Inferrable annotation, if it exists
|
|
646
|
+
inferrable_annotation = first_or_none(Inferrable, get_args(annotation))
|
|
647
|
+
|
|
648
|
+
# Params are inferrable by default
|
|
649
|
+
is_inferrable = inferrable_annotation.value if inferrable_annotation else True
|
|
650
|
+
|
|
651
|
+
# Get the wire (serialization) type information for the type
|
|
652
|
+
wire_type_info = get_wire_type_info(param_info.field_type)
|
|
653
|
+
|
|
654
|
+
# Final reality check
|
|
655
|
+
if param_info.description is None:
|
|
656
|
+
raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description")
|
|
657
|
+
|
|
658
|
+
if wire_type_info.wire_type is None:
|
|
659
|
+
raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}")
|
|
660
|
+
|
|
661
|
+
return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def get_wire_type_info(_type: type) -> WireTypeInfo:
    """
    Get the wire type information for a given type.

    A list becomes wire type "array", with the element's wire type as the inner type.
    ``Literal["a", "b"]`` string literals and Enum subclasses (also as list elements)
    additionally report their permitted values in ``enum_values``.

    Raises:
        ToolDefinitionError: If the type or its list element type has no wire
            representation, or if a list type does not declare its element type.
    """
    # Is this a list type?
    # If so, get the inner (enclosed) type
    is_list = get_origin(_type) is list
    inner_type: Any = None
    inner_wire_type: InnerWireType | None = None

    if is_list:
        list_args = get_args(_type)
        if not list_args:
            # A bare `typing.List` has origin `list` but no element type; indexing the
            # empty args tuple would otherwise raise an IndexError.
            raise ToolDefinitionError(
                f"List type {_type} must declare its element type (e.g. list[str])."
            )
        inner_type = list_args[0]
        inner_wire_type = cast(
            InnerWireType,
            get_wire_type(str) if is_string_literal(inner_type) else get_wire_type(inner_type),
        )

    # Get the outer wire type
    wire_type = get_wire_type(str) if is_string_literal(_type) else get_wire_type(_type)

    # Enumerable values come from the element type for lists, otherwise from the type itself.
    type_to_check = inner_type if is_list else _type

    # Strip generic parameters if type_to_check is a parameterized generic
    actual_type = get_origin(type_to_check) or type_to_check

    enum_values: list[str] | None = None

    # Special case: Literal["string1", "string2"] can be enumerated on the wire
    if is_string_literal(type_to_check):
        enum_values = [str(e) for e in get_args(type_to_check)]

    # Special case: Enum can be enumerated on the wire.
    # The isinstance guard keeps issubclass from raising TypeError on non-class types.
    elif isinstance(actual_type, type) and issubclass(actual_type, Enum):
        # NOTE(review): member values are used as-is; non-string values (e.g. IntEnum) are not converted.
        enum_values = [e.value for e in actual_type]

    return WireTypeInfo(wire_type, inner_wire_type, enum_values)
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
    """
    Collect ParamInfo for a parameter whose default is a plain Python value, not a
    pydantic Field().

    Raises:
        ToolDefinitionError: If the type is a Union other than Optional[T] / T | None.
    """
    annotation = param.annotation

    # Peel an Annotated[T, ...] wrapper down to T; any other annotation is used as-is.
    if get_origin(annotation) is Annotated:
        original_type = annotation.__args__[0]
    else:
        original_type = annotation

    # Optional[T] and T | None both unwrap to T.
    is_optional = is_strict_optional(original_type)
    field_type = (
        next(arg for arg in get_args(original_type) if arg is not type(None))
        if is_optional
        else original_type
    )

    # Only the Optional flavor of Union is supported.
    if is_union(field_type):
        raise ToolDefinitionError(
            f"Parameter {param.name} is a union type. Only optional types are supported."
        )

    default = None if param.default is inspect.Parameter.empty else param.default

    return ParamInfo(
        name=param.name,
        default=default,
        is_optional=is_optional,
        original_type=original_type,
        field_type=field_type,
    )
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
def extract_pydantic_param_info(param: inspect.Parameter) -> ParamInfo:
    """
    Collect ParamInfo for a parameter whose default is a pydantic ``Field(...)``.

    The description comes from ``Field(description=...)``. The default comes from
    ``Field(default=...)`` or, when set, from calling ``Field(default_factory=...)``.

    Raises:
        ToolDefinitionError: If the default factory is not callable, or if the type is a
            Union other than Optional[T] / T | None (matching extract_python_param_info).
    """
    field_info = param.default
    default_value = None if field_info.default is PydanticUndefined else field_info.default

    factory = field_info.default_factory
    if factory is not None:
        if not callable(factory):
            raise ToolDefinitionError(f"Default factory for parameter {param} is not callable.")
        default_value = factory()

    # If the param is Annotated[], unwrap the annotation to get the "real" type
    # Otherwise, use the literal type
    original_type = (
        param.annotation.__args__[0]
        if get_origin(param.annotation) is Annotated
        else param.annotation
    )
    field_type = original_type

    # Unwrap Optional types
    # Both Optional[T] and T | None are supported
    is_optional = is_strict_optional(field_type)
    if is_optional:
        field_type = next(arg for arg in get_args(field_type) if arg is not type(None))

    # Reject non-Optional unions with the same clear message as the plain-Python path,
    # instead of a generic "Unsupported parameter type" later in get_wire_type.
    if is_union(field_type):
        raise ToolDefinitionError(
            f"Parameter {param.name} is a union type. Only optional types are supported."
        )

    return ParamInfo(
        name=param.name,
        description=field_info.description,
        default=default_value,
        is_optional=is_optional,
        original_type=original_type,
        field_type=field_type,
    )
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def get_wire_type(
    _type: type,
) -> WireType:
    """
    Translate a Python type annotation into its HTTP/JSON wire type.

    Resolution order:
      1. Exact match on a plain type: str -> "string", bool -> "boolean",
         int -> "integer", float -> "number", dict -> "json".
      2. A generic alias whose origin is list -> "array", or dict -> "json".
      3. Any Enum subclass -> "string".
      4. Any pydantic BaseModel subclass -> "json".

    Step 1 is a dictionary lookup, so only those exact classes match there
    (a bare, unparameterized ``list`` is not accepted).

    Raises:
        ToolDefinitionError: When none of the rules above applies.
    """
    # TODO ensure Any is not allowed
    scalar_wire_types: dict[type, WireType] = {
        str: "string",
        bool: "boolean",
        int: "integer",
        float: "number",
        dict: "json",
    }
    generic_wire_types: dict[type, WireType] = {
        list: "array",
        dict: "json",
    }

    direct_match = scalar_wire_types.get(_type)
    if direct_match:
        return direct_match

    # Parameterized generics such as list[str] / dict[str, int] expose __origin__.
    if hasattr(_type, "__origin__"):
        origin_match = generic_wire_types.get(cast(type, get_origin(_type)))
        if origin_match:
            return origin_match

    if isinstance(_type, type):
        if issubclass(_type, Enum):
            return "string"
        if issubclass(_type, BaseModel):
            return "json"

    raise ToolDefinitionError(f"Unsupported parameter type: {_type}")
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]]:
    """
    Derive a pair of Pydantic models (input, output) from a tool function.

    The input model is named ``<PascalCaseName>Input``. It has one field per
    parameter, except parameters annotated exactly as ``ToolContext``. Each field
    is typed with the unwrapped type from ``extract_field_info`` and carries that
    helper's default and description. The output model is built by
    ``determine_output_model``.

    A coroutine function exposing ``__wrapped__`` is replaced by its wrapped
    target before inspection, so names and signatures come from the inner function.
    """
    # TODO figure this out (Sam)
    target = (
        func.__wrapped__
        if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__")
        else func
    )

    # TODO make this cleaner
    extracted = (
        (name, extract_field_info(param))
        for name, param in inspect.signature(target, follow_wrapped=True).parameters.items()
        # ToolContext parameters are not part of the tool's input schema.
        if param.annotation is not ToolContext
    )
    input_fields = {
        name: (info.field_type, Field(default=info.default, description=info.description))
        for name, info in extracted
    }

    input_model = create_model(f"{snake_to_pascal_case(target.__name__)}Input", **input_fields)  # type: ignore[call-overload]
    output_model = determine_output_model(target)
    return input_model, output_model
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
def determine_output_model(func: Callable) -> type[BaseModel]:
    """
    Build the ``<PascalCaseName>Output`` Pydantic model from a function's return annotation.

    - No return annotation: a model with no fields.
    - ``Annotated[T, desc, ...]`` whose first metadata item is truthy: a required
      ``result: T`` field described by ``str(desc)``.
    - A ``typing.Union`` (including ``Optional[T]``): ``result`` is typed with the
      first non-None member only; any other members are dropped.
    - Anything else: ``result`` is typed with the annotation as written.

    Except for the Annotated case, ``result`` is described as "No description provided.".
    """
    annotation = inspect.signature(func).return_annotation
    model_name = f"{snake_to_pascal_case(func.__name__)}Output"

    if annotation is inspect.Signature.empty:
        return create_model(model_name)

    result_type = annotation

    if hasattr(annotation, "__origin__"):
        if hasattr(annotation, "__metadata__"):
            inner_type = annotation.__args__[0]
            description = annotation.__metadata__[0] if annotation.__metadata__ else ""
            if description:
                return create_model(
                    model_name,
                    result=(inner_type, Field(description=str(description))),
                )

        # TODO handle multiple non-None arguments. Raise error?
        # NOTE(review): PEP 604 unions (``T | None``) appear to lack ``__origin__``,
        # so they skip this branch and keep None in the field type, unlike
        # Optional[T] — confirm whether that asymmetry is intended.
        if annotation.__origin__ is typing.Union:
            none_type = type(None)
            result_type = next(
                (member for member in get_args(annotation) if member is not none_type),
                annotation,
            )

    return create_model(
        model_name,
        result=(result_type, Field(description="No description provided.")),
    )
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
def _dedupe_lowercase(names: list[str]) -> list[str]:
    """Lowercase each name and drop duplicates, keeping first-occurrence order."""
    # dict.fromkeys preserves insertion order and ignores repeated keys.
    return list(dict.fromkeys(name.lower() for name in names))


def to_tool_secret_requirements(
    secrets_requirement: list[str],
) -> list[ToolSecretRequirement]:
    """
    Convert secret names into ToolSecretRequirement objects.

    Names are lowercased and de-duplicated case-insensitively; the order of
    first occurrence is kept.
    """
    return [ToolSecretRequirement(key=key) for key in _dedupe_lowercase(secrets_requirement)]


def to_tool_metadata_requirements(
    metadata_requirement: list[str],
) -> list[ToolMetadataRequirement]:
    """
    Convert metadata names into ToolMetadataRequirement objects.

    Names are lowercased and de-duplicated case-insensitively; the order of
    first occurrence is kept.
    """
    return [ToolMetadataRequirement(key=key) for key in _dedupe_lowercase(metadata_requirement)]
|