arcade-core 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arcade_core/catalog.py ADDED
@@ -0,0 +1,894 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import os
5
+ import re
6
+ import typing
7
+ from collections.abc import Iterator
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from enum import Enum
11
+ from importlib import import_module
12
+ from types import ModuleType
13
+ from typing import (
14
+ Annotated,
15
+ Any,
16
+ Callable,
17
+ Literal,
18
+ Optional,
19
+ Union,
20
+ cast,
21
+ get_args,
22
+ get_origin,
23
+ )
24
+
25
+ from pydantic import BaseModel, Field, create_model
26
+ from pydantic.fields import FieldInfo
27
+ from pydantic_core import PydanticUndefined
28
+
29
+ from arcade_core.annotations import Inferrable
30
+ from arcade_core.auth import OAuth2, ToolAuthorization
31
+ from arcade_core.errors import ToolDefinitionError
32
+ from arcade_core.schema import (
33
+ TOOL_NAME_SEPARATOR,
34
+ FullyQualifiedName,
35
+ InputParameter,
36
+ OAuth2Requirement,
37
+ ToolAuthRequirement,
38
+ ToolContext,
39
+ ToolDefinition,
40
+ ToolInput,
41
+ ToolkitDefinition,
42
+ ToolMetadataKey,
43
+ ToolMetadataRequirement,
44
+ ToolOutput,
45
+ ToolRequirements,
46
+ ToolSecretRequirement,
47
+ ValueSchema,
48
+ )
49
+ from arcade_core.toolkit import Toolkit
50
+ from arcade_core.utils import (
51
+ does_function_return_value,
52
+ first_or_none,
53
+ is_strict_optional,
54
+ is_string_literal,
55
+ is_union,
56
+ snake_to_pascal_case,
57
+ )
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
62
+ WireType = Union[InnerWireType, Literal["array"]]
63
+
64
+
65
+ @dataclass
66
+ class WireTypeInfo:
67
+ """
68
+ Represents the wire type information for a value, including its inner type if it's a list.
69
+ """
70
+
71
+ wire_type: WireType
72
+ inner_wire_type: InnerWireType | None = None
73
+ enum_values: list[str] | None = None
74
+
75
+
76
+ class ToolMeta(BaseModel):
77
+ """
78
+ Metadata for a tool once it's been materialized.
79
+ """
80
+
81
+ module: str
82
+ toolkit: Optional[str] = None
83
+ package: Optional[str] = None
84
+ path: Optional[str] = None
85
+ date_added: datetime = Field(default_factory=datetime.now)
86
+ date_updated: datetime = Field(default_factory=datetime.now)
87
+
88
+
89
+ class MaterializedTool(BaseModel):
90
+ """
91
+ Data structure that holds tool information while stored in the Catalog
92
+ """
93
+
94
+ tool: Callable
95
+ definition: ToolDefinition
96
+ meta: ToolMeta
97
+
98
+ # Thought (Sam): Should generate create these from ToolDefinition?
99
+ input_model: type[BaseModel]
100
+ output_model: type[BaseModel]
101
+
102
+ @property
103
+ def name(self) -> str:
104
+ return self.definition.name
105
+
106
+ @property
107
+ def version(self) -> str | None:
108
+ return self.definition.toolkit.version
109
+
110
+ @property
111
+ def description(self) -> str:
112
+ return self.definition.description
113
+
114
+ @property
115
+ def requires_auth(self) -> bool:
116
+ return self.definition.requirements.authorization is not None
117
+
118
+
119
+ class ToolCatalog(BaseModel):
120
+ """Singleton class that holds all tools for a given worker"""
121
+
122
+ _tools: dict[FullyQualifiedName, MaterializedTool] = {}
123
+
124
+ _disabled_tools: set[str] = set()
125
+ _disabled_toolkits: set[str] = set()
126
+
127
+ def __init__(self, **data) -> None: # type: ignore[no-untyped-def]
128
+ super().__init__(**data)
129
+ self._load_disabled_tools()
130
+ self._load_disabled_toolkits()
131
+
132
+ def _load_disabled_tools(self) -> None:
133
+ """Load disabled tools from the environment variable.
134
+
135
+ The ARCADE_DISABLED_TOOLS environment variable should contain a
136
+ comma-separated list of tools that are to be excluded from the
137
+ catalog.
138
+
139
+ The expected format for each disabled tool is:
140
+ - [CamelCaseToolkitName][TOOL_NAME_SEPARATOR][CamelCaseToolName]
141
+ """
142
+ disabled_tools = os.getenv("ARCADE_DISABLED_TOOLS", "").strip().split(",")
143
+ if not disabled_tools:
144
+ return
145
+
146
+ pattern = re.compile(rf"^[a-zA-Z]+{re.escape(TOOL_NAME_SEPARATOR)}[a-zA-Z]+$")
147
+
148
+ for tool in disabled_tools:
149
+ if not pattern.match(tool):
150
+ continue
151
+
152
+ self._disabled_tools.add(tool.lower())
153
+
154
+ def _load_disabled_toolkits(self) -> None:
155
+ """Load disabled toolkits from the environment variable.
156
+
157
+ The ARCADE_DISABLED_TOOLKITS environment variable should contain a
158
+ comma-separated list of toolkits that are to be excluded from the
159
+ catalog.
160
+
161
+ The expected format for each disabled toolkit is:
162
+ - [CamelCaseToolkitName]
163
+ """
164
+ disabled_toolkits = os.getenv("ARCADE_DISABLED_TOOLKITS", "").strip().split(",")
165
+ if not disabled_toolkits:
166
+ return
167
+
168
+ for toolkit in disabled_toolkits:
169
+ self._disabled_toolkits.add(toolkit.lower())
170
+
171
+ def add_tool(
172
+ self,
173
+ tool_func: Callable,
174
+ toolkit_or_name: Union[str, Toolkit],
175
+ module: ModuleType | None = None,
176
+ ) -> None:
177
+ """
178
+ Add a function to the catalog as a tool.
179
+ """
180
+
181
+ input_model, output_model = create_func_models(tool_func)
182
+
183
+ if isinstance(toolkit_or_name, Toolkit):
184
+ toolkit = toolkit_or_name
185
+ toolkit_name = toolkit.name
186
+ elif isinstance(toolkit_or_name, str):
187
+ toolkit = None
188
+ toolkit_name = toolkit_or_name
189
+
190
+ if not toolkit_name:
191
+ raise ValueError("A toolkit name or toolkit must be provided.")
192
+
193
+ definition = ToolCatalog.create_tool_definition(
194
+ tool_func,
195
+ toolkit_name,
196
+ toolkit.version if toolkit else None,
197
+ toolkit.description if toolkit else None,
198
+ )
199
+
200
+ fully_qualified_name = definition.get_fully_qualified_name()
201
+
202
+ if fully_qualified_name in self._tools:
203
+ raise KeyError(f"Tool '{definition.name}' already exists in the catalog.")
204
+
205
+ if str(fully_qualified_name).lower() in self._disabled_tools:
206
+ logger.info(f"Tool '{fully_qualified_name!s}' is disabled and will not be cataloged.")
207
+ return
208
+
209
+ if str(toolkit_name).lower() in self._disabled_toolkits:
210
+ logger.info(f"Toolkit '{toolkit_name!s}' is disabled and will not be cataloged.")
211
+ return
212
+
213
+ self._tools[fully_qualified_name] = MaterializedTool(
214
+ definition=definition,
215
+ tool=tool_func,
216
+ meta=ToolMeta(
217
+ module=module.__name__ if module else tool_func.__module__,
218
+ toolkit=toolkit_name,
219
+ package=toolkit.package_name if toolkit else None,
220
+ path=module.__file__ if module else None,
221
+ ),
222
+ input_model=input_model,
223
+ output_model=output_model,
224
+ )
225
+
226
+ def add_module(self, module: ModuleType) -> None:
227
+ """
228
+ Add all the tools in a module to the catalog.
229
+ """
230
+ toolkit = Toolkit.from_module(module)
231
+ self.add_toolkit(toolkit)
232
+
233
+ def add_toolkit(self, toolkit: Toolkit) -> None:
234
+ """
235
+ Add the tools from a loaded toolkit to the catalog.
236
+ """
237
+
238
+ if str(toolkit).lower() in self._disabled_toolkits:
239
+ logger.info(f"Toolkit '{toolkit.name!s}' is disabled and will not be cataloged.")
240
+ return
241
+
242
+ for module_name, tool_names in toolkit.tools.items():
243
+ for tool_name in tool_names:
244
+ try:
245
+ module = import_module(module_name)
246
+ tool_func = getattr(module, tool_name)
247
+ self.add_tool(tool_func, toolkit, module)
248
+
249
+ except AttributeError as e:
250
+ raise ToolDefinitionError(
251
+ f"Could not import tool {tool_name} in module {module_name}. Reason: {e}"
252
+ )
253
+ except ImportError as e:
254
+ raise ToolDefinitionError(f"Could not import module {module_name}. Reason: {e}")
255
+ except TypeError as e:
256
+ raise ToolDefinitionError(
257
+ f"Type error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
258
+ )
259
+ except Exception as e:
260
+ raise ToolDefinitionError(
261
+ f"Error encountered while adding tool {tool_name} from {module_name}. Reason: {e}"
262
+ )
263
+
264
+ def __getitem__(self, name: FullyQualifiedName) -> MaterializedTool:
265
+ return self.get_tool(name)
266
+
267
+ def __contains__(self, name: FullyQualifiedName) -> bool:
268
+ return name in self._tools
269
+
270
+ def __iter__(self) -> Iterator[MaterializedTool]: # type: ignore[override]
271
+ yield from self._tools.values()
272
+
273
+ def __len__(self) -> int:
274
+ return len(self._tools)
275
+
276
+ def is_empty(self) -> bool:
277
+ return len(self._tools) == 0
278
+
279
+ def get_tool_names(self) -> list[FullyQualifiedName]:
280
+ return [tool.definition.get_fully_qualified_name() for tool in self._tools.values()]
281
+
282
+ def find_tool_by_func(self, func: Callable) -> ToolDefinition:
283
+ """
284
+ Find a tool by its function.
285
+ """
286
+ for _, tool in self._tools.items():
287
+ if tool.tool == func:
288
+ return tool.definition
289
+ raise ValueError(f"Tool {func} not found in the catalog.")
290
+
291
+ def get_tool_by_name(
292
+ self, name: str, version: Optional[str] = None, separator: str = TOOL_NAME_SEPARATOR
293
+ ) -> MaterializedTool:
294
+ """Get a tool from the catalog by name.
295
+
296
+ Args:
297
+ name: The name of the tool, potentially including the toolkit name separated by the `separator`.
298
+ version: The version of the toolkit. Defaults to None.
299
+ separator: The separator between toolkit and tool names. Defaults to `TOOL_NAME_SEPARATOR`.
300
+
301
+ Returns:
302
+ MaterializedTool: The matching tool from the catalog.
303
+
304
+ Raises:
305
+ ValueError: If the tool is not found in the catalog.
306
+ """
307
+ if separator in name:
308
+ toolkit_name, tool_name = name.split(separator, 1)
309
+ fq_name = FullyQualifiedName(
310
+ name=tool_name, toolkit_name=toolkit_name, toolkit_version=version
311
+ )
312
+ return self.get_tool(fq_name)
313
+ else:
314
+ # No toolkit name provided, search tools with matching tool name
315
+ matching_tools = [
316
+ tool
317
+ for fq_name, tool in self._tools.items()
318
+ if fq_name.name.lower() == name.lower()
319
+ and (
320
+ version is None
321
+ or (fq_name.toolkit_version or "").lower() == (version or "").lower()
322
+ )
323
+ ]
324
+ if matching_tools:
325
+ return matching_tools[0]
326
+
327
+ raise ValueError(f"Tool {name} not found in the catalog.")
328
+
329
+ def get_tool(self, name: FullyQualifiedName) -> MaterializedTool:
330
+ """
331
+ Get a tool from the catalog by fully-qualified name and version.
332
+ If the version is not specified, the any version is returned.
333
+ """
334
+ if name.toolkit_version:
335
+ try:
336
+ return self._tools[name]
337
+ except KeyError:
338
+ raise ValueError(f"Tool {name}@{name.toolkit_version} not found in the catalog.")
339
+
340
+ for key, tool in self._tools.items():
341
+ if key.equals_ignoring_version(name):
342
+ return tool
343
+
344
+ raise ValueError(f"Tool {name} not found.")
345
+
346
+ def get_tool_count(self) -> int:
347
+ """
348
+ Get the number of tools in the catalog.
349
+ """
350
+ return len(self._tools)
351
+
352
+ @staticmethod
353
+ def create_tool_definition(
354
+ tool: Callable,
355
+ toolkit_name: str,
356
+ toolkit_version: Optional[str] = None,
357
+ toolkit_desc: Optional[str] = None,
358
+ ) -> ToolDefinition:
359
+ """
360
+ Given a tool function, create a ToolDefinition
361
+ """
362
+
363
+ raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
364
+
365
+ # Hard requirement: tools must have descriptions
366
+ tool_description = getattr(tool, "__tool_description__", None)
367
+ if not tool_description:
368
+ raise ToolDefinitionError(f"Tool {raw_tool_name} is missing a description")
369
+
370
+ # If the function returns a value, it must have a type annotation
371
+ if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
372
+ raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
373
+
374
+ auth_requirement = create_auth_requirement(tool)
375
+ secrets_requirement = create_secrets_requirement(tool)
376
+ metadata_requirement = create_metadata_requirement(tool, auth_requirement)
377
+
378
+ toolkit_definition = ToolkitDefinition(
379
+ name=snake_to_pascal_case(toolkit_name),
380
+ description=toolkit_desc,
381
+ version=toolkit_version,
382
+ )
383
+
384
+ tool_name = snake_to_pascal_case(raw_tool_name)
385
+ fully_qualified_name = FullyQualifiedName.from_toolkit(tool_name, toolkit_definition)
386
+ deprecation_message = getattr(tool, "__tool_deprecation_message__", None)
387
+
388
+ return ToolDefinition(
389
+ name=tool_name,
390
+ fully_qualified_name=str(fully_qualified_name),
391
+ description=tool_description,
392
+ toolkit=toolkit_definition,
393
+ input=create_input_definition(tool),
394
+ output=create_output_definition(tool),
395
+ requirements=ToolRequirements(
396
+ authorization=auth_requirement,
397
+ secrets=secrets_requirement,
398
+ metadata=metadata_requirement,
399
+ ),
400
+ deprecation_message=deprecation_message,
401
+ )
402
+
403
+
404
+ def create_input_definition(func: Callable) -> ToolInput:
405
+ """
406
+ Create an input model for a function based on its parameters.
407
+ """
408
+ input_parameters = []
409
+ tool_context_param_name: str | None = None
410
+
411
+ for _, param in inspect.signature(func, follow_wrapped=True).parameters.items():
412
+ if param.annotation is ToolContext:
413
+ if tool_context_param_name is not None:
414
+ raise ToolDefinitionError(
415
+ f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple."
416
+ )
417
+
418
+ tool_context_param_name = param.name
419
+ continue # No further processing of this param (don't add it to the list of inputs)
420
+
421
+ tool_field_info = extract_field_info(param)
422
+
423
+ # If the field has a default value, it is not required
424
+ # If the field is optional, it is not required
425
+ has_default_value = tool_field_info.default is not None
426
+ is_required = not tool_field_info.is_optional and not has_default_value
427
+
428
+ input_parameters.append(
429
+ InputParameter(
430
+ name=tool_field_info.name,
431
+ description=tool_field_info.description,
432
+ required=is_required,
433
+ inferrable=tool_field_info.is_inferrable,
434
+ value_schema=ValueSchema(
435
+ val_type=tool_field_info.wire_type_info.wire_type,
436
+ inner_val_type=tool_field_info.wire_type_info.inner_wire_type,
437
+ enum=tool_field_info.wire_type_info.enum_values,
438
+ ),
439
+ )
440
+ )
441
+
442
+ return ToolInput(
443
+ parameters=input_parameters, tool_context_parameter_name=tool_context_param_name
444
+ )
445
+
446
+
447
+ def create_output_definition(func: Callable) -> ToolOutput:
448
+ """
449
+ Create an output model for a function based on its return annotation.
450
+ """
451
+ return_type = inspect.signature(func, follow_wrapped=True).return_annotation
452
+ description = "No description provided."
453
+
454
+ if return_type is inspect.Signature.empty:
455
+ return ToolOutput(
456
+ value_schema=None,
457
+ description="No description provided.",
458
+ available_modes=["null"],
459
+ )
460
+
461
+ if hasattr(return_type, "__metadata__"):
462
+ description = return_type.__metadata__[0] if return_type.__metadata__ else None # type: ignore[assignment]
463
+ return_type = return_type.__origin__
464
+
465
+ # Unwrap Optional types
466
+ # Both Optional[T] and T | None are supported
467
+ is_optional = is_strict_optional(return_type)
468
+ if is_optional:
469
+ return_type = next(arg for arg in get_args(return_type) if arg is not type(None))
470
+
471
+ wire_type_info = get_wire_type_info(return_type)
472
+
473
+ available_modes = ["value", "error"]
474
+
475
+ if is_optional:
476
+ available_modes.append("null")
477
+
478
+ return ToolOutput(
479
+ description=description,
480
+ available_modes=available_modes,
481
+ value_schema=ValueSchema(
482
+ val_type=wire_type_info.wire_type,
483
+ inner_val_type=wire_type_info.inner_wire_type,
484
+ enum=wire_type_info.enum_values,
485
+ ),
486
+ )
487
+
488
+
489
+ def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None:
490
+ """
491
+ Create an auth requirement for a tool.
492
+ """
493
+ auth_requirement = getattr(tool, "__tool_requires_auth__", None)
494
+ if isinstance(auth_requirement, ToolAuthorization):
495
+ new_auth_requirement = ToolAuthRequirement(
496
+ provider_id=auth_requirement.provider_id,
497
+ provider_type=auth_requirement.provider_type,
498
+ id=auth_requirement.id,
499
+ )
500
+ if isinstance(auth_requirement, OAuth2):
501
+ new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
502
+ auth_requirement = new_auth_requirement
503
+
504
+ return auth_requirement
505
+
506
+
507
+ def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None:
508
+ """
509
+ Create a secrets requirement for a tool.
510
+ """
511
+ raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
512
+ secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
513
+ if isinstance(secrets_requirement, list):
514
+ if any(not isinstance(secret, str) for secret in secrets_requirement):
515
+ raise ToolDefinitionError(
516
+ f"Secret keys must be strings (error in tool {raw_tool_name})."
517
+ )
518
+
519
+ secrets_requirement = to_tool_secret_requirements(secrets_requirement)
520
+ if any(secret.key is None or secret.key.strip() == "" for secret in secrets_requirement):
521
+ raise ToolDefinitionError(
522
+ f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
523
+ )
524
+
525
+ return secrets_requirement
526
+
527
+
528
+ def create_metadata_requirement(
529
+ tool: Callable, auth_requirement: ToolAuthRequirement | None
530
+ ) -> list[ToolMetadataRequirement] | None:
531
+ """
532
+ Create a metadata requirement for a tool.
533
+ """
534
+ raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
535
+ metadata_requirement = getattr(tool, "__tool_requires_metadata__", None)
536
+ if isinstance(metadata_requirement, list):
537
+ for metadata in metadata_requirement:
538
+ if not isinstance(metadata, str):
539
+ raise ToolDefinitionError(
540
+ f"Metadata must be strings (error in tool {raw_tool_name})."
541
+ )
542
+ if ToolMetadataKey.requires_auth(metadata) and auth_requirement is None:
543
+ raise ToolDefinitionError(
544
+ f"Tool {raw_tool_name} declares metadata key '{metadata}', "
545
+ "which requires that the tool has an auth requirement, "
546
+ "but no auth requirement was provided. Please specify an auth requirement."
547
+ )
548
+
549
+ metadata_requirement = to_tool_metadata_requirements(metadata_requirement)
550
+ if any(
551
+ metadata.key is None or metadata.key.strip() == "" for metadata in metadata_requirement
552
+ ):
553
+ raise ToolDefinitionError(
554
+ f"Metadata must have a non-empty key (error in tool {raw_tool_name})."
555
+ )
556
+
557
+ return metadata_requirement
558
+
559
+
560
+ @dataclass
561
+ class ParamInfo:
562
+ """
563
+ Information about a function parameter found through inspection.
564
+ """
565
+
566
+ name: str
567
+ default: Any
568
+ original_type: type
569
+ field_type: type
570
+ description: str | None = None
571
+ is_optional: bool = True
572
+
573
+
574
+ @dataclass
575
+ class ToolParamInfo:
576
+ """
577
+ Information about a tool parameter, including computed values.
578
+ """
579
+
580
+ name: str
581
+ default: Any
582
+ original_type: type
583
+ field_type: type
584
+ wire_type_info: WireTypeInfo
585
+ description: str | None = None
586
+ is_optional: bool = True
587
+ is_inferrable: bool = True
588
+
589
+ @classmethod
590
+ def from_param_info(
591
+ cls,
592
+ param_info: ParamInfo,
593
+ wire_type_info: WireTypeInfo,
594
+ is_inferrable: bool = True,
595
+ ) -> "ToolParamInfo":
596
+ return cls(
597
+ name=param_info.name,
598
+ default=param_info.default,
599
+ original_type=param_info.original_type,
600
+ field_type=param_info.field_type,
601
+ description=param_info.description,
602
+ is_optional=param_info.is_optional,
603
+ wire_type_info=wire_type_info,
604
+ is_inferrable=is_inferrable,
605
+ )
606
+
607
+
608
+ def extract_field_info(param: inspect.Parameter) -> ToolParamInfo:
609
+ """
610
+ Extract type and field parameters from a function parameter.
611
+ """
612
+ annotation = param.annotation
613
+ if annotation == inspect.Parameter.empty:
614
+ raise ToolDefinitionError(f"Parameter {param} has no type annotation.")
615
+
616
+ # Get the majority of the param info from either the Pydantic Field() or regular inspection
617
+ if isinstance(param.default, FieldInfo):
618
+ param_info = extract_pydantic_param_info(param)
619
+ else:
620
+ param_info = extract_python_param_info(param)
621
+
622
+ metadata = getattr(annotation, "__metadata__", [])
623
+ str_annotations = [m for m in metadata if isinstance(m, str)]
624
+
625
+ # Get the description from annotations, if present
626
+ if len(str_annotations) == 0:
627
+ pass
628
+ elif len(str_annotations) == 1:
629
+ param_info.description = str_annotations[0]
630
+ elif len(str_annotations) == 2:
631
+ new_name = str_annotations[0]
632
+ if not new_name.isidentifier():
633
+ raise ToolDefinitionError(
634
+ f"Invalid parameter name: '{new_name}' is not a valid identifier. "
635
+ "Identifiers must start with a letter or underscore, "
636
+ "and can only contain letters, digits, or underscores."
637
+ )
638
+ param_info.name = new_name
639
+ param_info.description = str_annotations[1]
640
+ else:
641
+ raise ToolDefinitionError(
642
+ f"Parameter {param} has too many string annotations. Expected 0, 1, or 2, got {len(str_annotations)}."
643
+ )
644
+
645
+ # Get the Inferrable annotation, if it exists
646
+ inferrable_annotation = first_or_none(Inferrable, get_args(annotation))
647
+
648
+ # Params are inferrable by default
649
+ is_inferrable = inferrable_annotation.value if inferrable_annotation else True
650
+
651
+ # Get the wire (serialization) type information for the type
652
+ wire_type_info = get_wire_type_info(param_info.field_type)
653
+
654
+ # Final reality check
655
+ if param_info.description is None:
656
+ raise ToolDefinitionError(f"Parameter {param_info.name} is missing a description")
657
+
658
+ if wire_type_info.wire_type is None:
659
+ raise ToolDefinitionError(f"Unknown parameter type: {param_info.field_type}")
660
+
661
+ return ToolParamInfo.from_param_info(param_info, wire_type_info, is_inferrable)
662
+
663
+
664
+ def get_wire_type_info(_type: type) -> WireTypeInfo:
665
+ """
666
+ Get the wire type information for a given type.
667
+ """
668
+
669
+ # Is this a list type?
670
+ # If so, get the inner (enclosed) type
671
+ is_list = get_origin(_type) is list
672
+ if is_list:
673
+ inner_type = get_args(_type)[0]
674
+ inner_wire_type = cast(
675
+ InnerWireType,
676
+ get_wire_type(str) if is_string_literal(inner_type) else get_wire_type(inner_type),
677
+ )
678
+ else:
679
+ inner_wire_type = None
680
+
681
+ # Get the outer wire type
682
+ wire_type = get_wire_type(str) if is_string_literal(_type) else get_wire_type(_type)
683
+
684
+ # Handle enums (known/fixed lists of values)
685
+ is_enum = False
686
+ enum_values: list[str] = []
687
+
688
+ type_to_check = inner_type if is_list else _type
689
+
690
+ # Strip generic parameters if type_to_check is a parameterized generic
691
+ actual_type = get_origin(type_to_check) or type_to_check
692
+
693
+ # Special case: Literal["string1", "string2"] can be enumerated on the wire
694
+ if is_string_literal(type_to_check):
695
+ is_enum = True
696
+ enum_values = [str(e) for e in get_args(type_to_check)]
697
+
698
+ # Special case: Enum can be enumerated on the wire
699
+ elif issubclass(actual_type, Enum):
700
+ is_enum = True
701
+ enum_values = [e.value for e in actual_type] # type: ignore[union-attr]
702
+
703
+ return WireTypeInfo(wire_type, inner_wire_type, enum_values if is_enum else None)
704
+
705
+
706
+ def extract_python_param_info(param: inspect.Parameter) -> ParamInfo:
707
+ # If the param is Annotated[], unwrap the annotation to get the "real" type
708
+ # Otherwise, use the literal type
709
+ annotation = param.annotation
710
+ original_type = annotation.__args__[0] if get_origin(annotation) is Annotated else annotation
711
+ field_type = original_type
712
+
713
+ # Handle optional types
714
+ # Both Optional[T] and T | None are supported
715
+ is_optional = is_strict_optional(field_type)
716
+ if is_optional:
717
+ field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
718
+
719
+ # Union types are not currently supported
720
+ # (other than optional, which is handled above)
721
+ if is_union(field_type):
722
+ raise ToolDefinitionError(
723
+ f"Parameter {param.name} is a union type. Only optional types are supported."
724
+ )
725
+
726
+ return ParamInfo(
727
+ name=param.name,
728
+ default=param.default if param.default is not inspect.Parameter.empty else None,
729
+ is_optional=is_optional,
730
+ original_type=original_type,
731
+ field_type=field_type,
732
+ )
733
+
734
+
735
+ def extract_pydantic_param_info(param: inspect.Parameter) -> ParamInfo:
736
+ default_value = None if param.default.default is PydanticUndefined else param.default.default
737
+
738
+ if param.default.default_factory is not None:
739
+ if callable(param.default.default_factory):
740
+ default_value = param.default.default_factory()
741
+ else:
742
+ raise ToolDefinitionError(f"Default factory for parameter {param} is not callable.")
743
+
744
+ # If the param is Annotated[], unwrap the annotation to get the "real" type
745
+ # Otherwise, use the literal type
746
+ original_type = (
747
+ param.annotation.__args__[0]
748
+ if get_origin(param.annotation) is Annotated
749
+ else param.annotation
750
+ )
751
+ field_type = original_type
752
+
753
+ # Unwrap Optional types
754
+ # Both Optional[T] and T | None are supported
755
+ is_optional = is_strict_optional(field_type)
756
+ if is_optional:
757
+ field_type = next(arg for arg in get_args(field_type) if arg is not type(None))
758
+
759
+ return ParamInfo(
760
+ name=param.name,
761
+ description=param.default.description,
762
+ default=default_value,
763
+ is_optional=is_optional,
764
+ original_type=original_type,
765
+ field_type=field_type,
766
+ )
767
+
768
+
769
+ def get_wire_type(
770
+ _type: type,
771
+ ) -> WireType:
772
+ """
773
+ Mapping between Python types and HTTP/JSON types
774
+ """
775
+ # TODO ensure Any is not allowed
776
+ type_mapping: dict[type, WireType] = {
777
+ str: "string",
778
+ bool: "boolean",
779
+ int: "integer",
780
+ float: "number",
781
+ dict: "json",
782
+ }
783
+ outer_type_mapping: dict[type, WireType] = {
784
+ list: "array",
785
+ dict: "json",
786
+ }
787
+ wire_type = type_mapping.get(_type)
788
+ if wire_type:
789
+ return wire_type
790
+
791
+ if hasattr(_type, "__origin__"):
792
+ wire_type = outer_type_mapping.get(cast(type, get_origin(_type)))
793
+ if wire_type:
794
+ return wire_type
795
+
796
+ if isinstance(_type, type) and issubclass(_type, Enum):
797
+ return "string"
798
+
799
+ if isinstance(_type, type) and issubclass(_type, BaseModel):
800
+ return "json"
801
+
802
+ raise ToolDefinitionError(f"Unsupported parameter type: {_type}")
803
+
804
+
805
+ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]]:
806
+ """
807
+ Analyze a function to create corresponding Pydantic models for its input and output.
808
+ """
809
+ input_fields = {}
810
+ # TODO figure this out (Sam)
811
+ if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__"):
812
+ func = func.__wrapped__
813
+ for name, param in inspect.signature(func, follow_wrapped=True).parameters.items():
814
+ # Skip ToolContext parameters
815
+ if param.annotation is ToolContext:
816
+ continue
817
+
818
+ # TODO make this cleaner
819
+ tool_field_info = extract_field_info(param)
820
+ param_fields = {
821
+ "default": tool_field_info.default,
822
+ "description": tool_field_info.description,
823
+ # TODO more here?
824
+ }
825
+ input_fields[name] = (tool_field_info.field_type, Field(**param_fields))
826
+
827
+ input_model = create_model(f"{snake_to_pascal_case(func.__name__)}Input", **input_fields) # type: ignore[call-overload]
828
+
829
+ output_model = determine_output_model(func)
830
+
831
+ return input_model, output_model
832
+
833
+
834
+ def determine_output_model(func: Callable) -> type[BaseModel]:
835
+ """
836
+ Determine the output model for a function based on its return annotation.
837
+ """
838
+ return_annotation = inspect.signature(func).return_annotation
839
+ output_model_name = f"{snake_to_pascal_case(func.__name__)}Output"
840
+ if return_annotation is inspect.Signature.empty:
841
+ return create_model(output_model_name)
842
+ elif hasattr(return_annotation, "__origin__"):
843
+ if hasattr(return_annotation, "__metadata__"):
844
+ field_type = return_annotation.__args__[0]
845
+ description = (
846
+ return_annotation.__metadata__[0] if return_annotation.__metadata__ else ""
847
+ )
848
+ if description:
849
+ return create_model(
850
+ output_model_name,
851
+ result=(field_type, Field(description=str(description))),
852
+ )
853
+ # Handle Union types
854
+ origin = return_annotation.__origin__
855
+ if origin is typing.Union:
856
+ # For union types, create a model with the first non-None argument
857
+ # TODO handle multiple non-None arguments. Raise error?
858
+ for arg in get_args(return_annotation):
859
+ if arg is not type(None):
860
+ return create_model(
861
+ output_model_name,
862
+ result=(arg, Field(description="No description provided.")),
863
+ )
864
+ # when the return_annotation has an __origin__ attribute
865
+ # and does not have a __metadata__ attribute.
866
+ return create_model(
867
+ output_model_name,
868
+ result=(
869
+ return_annotation,
870
+ Field(description="No description provided."),
871
+ ),
872
+ )
873
+ else:
874
+ # Handle simple return types (like str)
875
+ return create_model(
876
+ output_model_name,
877
+ result=(return_annotation, Field(description="No description provided.")),
878
+ )
879
+
880
+
881
+ def to_tool_secret_requirements(
882
+ secrets_requirement: list[str],
883
+ ) -> list[ToolSecretRequirement]:
884
+ # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
885
+ unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
886
+ return [ToolSecretRequirement(key=name) for name in unique_secrets]
887
+
888
+
889
+ def to_tool_metadata_requirements(
890
+ metadata_requirement: list[str],
891
+ ) -> list[ToolMetadataRequirement]:
892
+ # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolMetadataRequirement
893
+ unique_metadata = {name.lower(): name.lower() for name in metadata_requirement}.values()
894
+ return [ToolMetadataRequirement(key=name) for name in unique_metadata]