dais-sdk 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,157 @@
1
+ import httpx
2
+ import webbrowser
3
+ from dataclasses import dataclass
4
+ from contextlib import AsyncExitStack
5
+ from typing import Any, NamedTuple
6
+ from mcp import ClientSession
7
+ from mcp.client.auth import OAuthClientProvider
8
+ from mcp.client.streamable_http import streamable_http_client
9
+ from mcp.shared.auth import OAuthClientMetadata
10
+ from pydantic import AnyUrl, BaseModel, Field, ConfigDict, SkipValidation
11
+ from .oauth_server import LocalOAuthServer, OAuthCode, TokenStorage, InMemoryTokenStorage
12
+ from .base_mcp_client import McpClient, Tool, ToolResult, McpSessionNotEstablishedError
13
+ from ..logger import logger
14
+
15
@dataclass
class OAuthParams:
    # Requested OAuth scopes; joined with spaces at request time. None means
    # no explicit scope is sent.
    oauth_scopes: list[str] | None = None
    # Seconds to wait for the local OAuth callback before giving up.
    oauth_timeout: int = 120
    # NOTE(review): pydantic's Field() inside a *stdlib* @dataclass is only
    # honored when the instance is created/validated through pydantic (e.g.
    # as a field of RemoteServerParams). Constructing OAuthParams() directly
    # would leave the FieldInfo object as the default — confirm callers always
    # pass an explicit storage or go through pydantic validation.
    oauth_token_storage: SkipValidation[TokenStorage] = Field(
        default_factory=InMemoryTokenStorage,
        exclude=True,
    )
23
+
24
class RemoteServerParams(BaseModel):
    """Connection settings for a remote MCP server."""

    # arbitrary_types_allowed: OAuthParams carries a non-pydantic TokenStorage.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Server URL passed to the streamable-HTTP transport.
    url: str
    # Static bearer token; when set it is sent as the Authorization header.
    bearer_token: str | None = None
    # When set, the client performs the OAuth authorization-code flow.
    oauth_params: OAuthParams | None = None
    # Extra HTTP headers merged into every request.
    http_headers: dict[str, str] | None = None
31
+
32
+ # --- --- --- --- --- ---
33
+
34
class OAuthContext(NamedTuple):
    # Pairs the OAuth-authenticated HTTP client with the local callback
    # server that receives the authorization code.
    client: httpx.AsyncClient
    server: LocalOAuthServer
37
+
38
class RemoteMcpClient(McpClient):
    """MCP client that talks to a remote server over streamable HTTP.

    Supports three auth modes: none, a static bearer token, and the OAuth
    authorization-code flow (local callback server + browser redirect).
    Resources are managed with an AsyncExitStack: call connect() before use
    and disconnect() when done.
    """

    def __init__(self,
                 name: str,
                 params: RemoteServerParams,
                 storage: TokenStorage | None = None):
        """
        Args:
            name: Client name; also used as the OAuth client_name.
            params: Remote server connection settings.
            storage: Optional token storage overriding the default from
                params.oauth_params.
        """
        self._name = name
        self._params = params
        self._session: ClientSession | None = None
        self._exit_stack: AsyncExitStack | None = None
        # BUGFIX: install the caller-supplied token storage BEFORE building
        # the OAuth context. OAuthClientProvider captures the storage object
        # inside _init_oauth(), so the previous order (assigning storage
        # after _init_oauth) silently ignored the custom storage.
        if self._params.oauth_params is not None and storage is not None:
            self._params.oauth_params.oauth_token_storage = storage
        self._oauth_context: OAuthContext | None = self._init_oauth()

    @property
    def name(self) -> str:
        """Client name."""
        return self._name

    def _init_http_headers(self) -> dict[str, str] | None:
        """Merge configured headers with the bearer token; None if neither is set."""
        if self._params.http_headers is None and self._params.bearer_token is None:
            return None
        headers: dict[str, str] = {}
        if self._params.http_headers is not None:
            headers.update(self._params.http_headers)
        if self._params.bearer_token is not None:
            headers["Authorization"] = f"Bearer {self._params.bearer_token}"
        return headers

    def _init_oauth(self) -> OAuthContext | None:
        """Build the OAuth-enabled HTTP client and local callback server, if configured."""
        if self._params.oauth_params is None:
            return None

        oauth_params = self._params.oauth_params

        server = LocalOAuthServer(timeout=oauth_params.oauth_timeout)
        scopes = None
        if oauth_params.oauth_scopes is not None:
            # OAuth scopes are conveyed as a single space-separated string.
            scopes = " ".join(oauth_params.oauth_scopes)

        client_provider = OAuthClientProvider(
            server_url=self._params.url,
            client_metadata=OAuthClientMetadata(
                client_name=self._name,
                redirect_uris=[AnyUrl(server.callback_url)],
                grant_types=["authorization_code", "refresh_token"],
                response_types=["code"],
                scope=scopes,
                # Public client: no client secret, so no token-endpoint auth.
                token_endpoint_auth_method="none",
            ),
            storage=oauth_params.oauth_token_storage,
            redirect_handler=self._handle_redirect,
            callback_handler=self._handle_oauth_callback,
        )
        client = httpx.AsyncClient(auth=client_provider,
                                   headers=self._init_http_headers(),
                                   follow_redirects=True)
        return OAuthContext(client, server)

    async def _handle_redirect(self, url: str) -> None:
        """Open the authorization URL in the user's browser (best effort)."""
        logger.info("[OAuth] Authentication required, opening browser...")
        logger.info(f"[OAuth] If browser does not open automatically, copy and open the following link: \n{url}\n")
        try:
            webbrowser.open(url)
        except Exception as e:
            # Non-fatal: the user can still open the logged URL manually.
            logger.error(f"[OAuth] Not able to open browser: {e}")

    async def _handle_oauth_callback(self) -> OAuthCode:
        """Block until the local callback server receives the authorization code."""
        if self._oauth_context is None:
            raise ValueError("OAuth context not initialized")
        return await self._oauth_context.server.wait_for_code()

    async def connect(self):
        """Establish the MCP session, starting the OAuth flow if configured."""
        self._exit_stack = AsyncExitStack()
        if self._oauth_context:
            # The OAuth client is owned by the context and closed explicitly
            # in disconnect(), not by the exit stack.
            http_client = self._oauth_context.client
            await self._oauth_context.server.start()
        else:
            http_client = await self._exit_stack.enter_async_context(
                httpx.AsyncClient(headers=self._init_http_headers(), follow_redirects=True))

        try:
            read_stream, write_stream, _ = await self._exit_stack.enter_async_context(
                streamable_http_client(self._params.url, http_client=http_client)
            )
            self._session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await self._session.initialize()
        except Exception:
            # Roll back partially-established resources before re-raising.
            await self.disconnect()
            raise

    async def list_tools(self) -> list[Tool]:
        """List tools exposed by the server.

        Raises:
            McpSessionNotEstablishedError: If connect() has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        result = await self._session.list_tools()
        return result.tools

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        """Invoke a tool by name with optional arguments.

        Raises:
            McpSessionNotEstablishedError: If connect() has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        response = await self._session.call_tool(tool_name, arguments)
        return ToolResult(response.isError, response.content)

    async def disconnect(self):
        """Tear down the session, HTTP client and OAuth server (idempotent)."""
        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
        finally:
            self._session = None
            self._exit_stack = None

            # Moved into finally so the OAuth resources are cleaned up even
            # when aclose() raises. Best-effort: a failure closing one
            # resource must not prevent closing the other.
            if self._oauth_context:
                try:
                    await self._oauth_context.client.aclose()
                except Exception:
                    pass
                try:
                    await self._oauth_context.server.stop()
                except Exception:
                    pass
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+ from typing import Any, TYPE_CHECKING
3
+ from litellm.types.utils import LlmProviders
4
+ from .tool.prepare import prepare_tools
5
+ from .types.message import ToolMessage
6
+
7
+ if TYPE_CHECKING:
8
+ from .types.request_params import LlmRequestParams
9
+
10
+ ParsedParams = dict[str, Any]
11
+
12
class ParamParser:
    """Builds litellm-compatible keyword arguments from LlmRequestParams."""

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        self._provider = provider
        self._base_url = base_url
        self._api_key = api_key

    def _parse(self, params: LlmRequestParams) -> ParsedParams:
        """Translate params into the kwarg dict shared by stream/non-stream calls."""
        tool_likes = params.extract_tools()
        # A falsy value (None / empty) passes through unchanged; otherwise
        # convert the tool-likes into litellm tool definitions.
        prepared_tools = prepare_tools(tool_likes) if tool_likes else tool_likes

        # Unresolved ToolMessages (no result and no error yet) are dropped.
        messages = [
            message.to_litellm_message()
            for message in params.messages
            if not (type(message) is ToolMessage
                    and message.result is None
                    and message.error is None)
        ]

        parsed: ParsedParams = {
            "model": f"{self._provider.value}/{params.model}",
            "messages": messages,
            "base_url": self._base_url,
            "api_key": self._api_key,
            "tools": prepared_tools,
            "tool_choice": params.tool_choice,
            "timeout": params.timeout_sec,
            "extra_headers": params.headers,
        }
        # Caller-supplied extras win over the defaults above.
        parsed.update(params.extra_args or {})
        return parsed

    def parse_nonstream(self, params: LlmRequestParams) -> ParsedParams:
        """Kwargs for a non-streaming completion call."""
        return {**self._parse(params), "stream": False}

    def parse_stream(self, params: LlmRequestParams) -> ParsedParams:
        """Kwargs for a streaming completion call (usage reporting enabled)."""
        return {**self._parse(params),
                "stream": True,
                "stream_options": {"include_usage": True}}
dais_sdk/stream.py ADDED
@@ -0,0 +1,82 @@
1
+ import dataclasses
2
+ from litellm import ChatCompletionAssistantToolCall
3
+ from litellm.types.utils import (ChatCompletionDeltaToolCall,
4
+ ModelResponseStream as LiteLlmModelResponseStream)
5
+ from .types.message import AssistantMessage
6
+
7
@dataclasses.dataclass
class ToolCallTemp:
    """Mutable accumulator for one streamed tool call."""
    id: str | None = None
    name: str = ""
    arguments: str = ""

class ToolCallCollector:
    """Reassembles complete tool calls from indexed streaming fragments."""

    def __init__(self):
        self.tool_call_map: dict[int, ToolCallTemp] = {}

    def collect(self, tool_call_chunk: ChatCompletionDeltaToolCall):
        """Merge one delta fragment into the accumulator at its index."""
        entry = self.tool_call_map.setdefault(tool_call_chunk.index, ToolCallTemp())
        if tool_call_chunk.get("id"):
            entry.id = tool_call_chunk.id
        fragment_fn = tool_call_chunk.function
        if fragment_fn.get("name"):
            assert fragment_fn.name is not None
            entry.name += fragment_fn.name
        if fragment_fn.get("arguments"):
            assert fragment_fn.arguments is not None
            entry.arguments += fragment_fn.arguments

    def get_tool_calls(self) -> list[ChatCompletionAssistantToolCall]:
        """Materialize all accumulated fragments as assistant tool calls."""
        calls: list[ChatCompletionAssistantToolCall] = []
        for entry in self.tool_call_map.values():
            calls.append({
                "id": entry.id,
                "function": {
                    "name": entry.name,
                    "arguments": entry.arguments,
                },
                "type": "function",
            })
        return calls
40
+
41
class AssistantMessageCollector:
    """Accumulates streamed response deltas into a single AssistantMessage."""

    def __init__(self):
        self.message_buf = AssistantMessage(content=None)
        self.tool_call_collector = ToolCallCollector()

    def collect(self, chunk: LiteLlmModelResponseStream):
        """Fold one stream chunk's delta into the buffered message."""
        delta = chunk.choices[0].delta
        if delta.get("content"):
            assert delta.content is not None
            if self.message_buf.content is None:
                self.message_buf.content = ""
            self.message_buf.content += delta.content

        if delta.get("reasoning_content"):
            assert delta.reasoning_content is not None
            if self.message_buf.reasoning_content is None:
                self.message_buf.reasoning_content = ""
            self.message_buf.reasoning_content += delta.reasoning_content

        if delta.get("tool_calls"):
            assert delta.tool_calls is not None
            for tool_call_chunk in delta.tool_calls:
                self.tool_call_collector.collect(tool_call_chunk)

        if delta.get("images"):
            assert delta.images is not None
            if self.message_buf.images is None:
                self.message_buf.images = []
            # BUGFIX: accumulate images across chunks. The original assigned
            # `delta.images` here, which made the empty-list init above dead
            # code and silently dropped images from earlier chunks.
            self.message_buf.images.extend(delta.images)

        if delta.get("audio"):
            assert delta.audio is not None
            # Audio deltas are taken as-is; the last chunk wins.
            self.message_buf.audio = delta.audio

    def get_message(self) -> AssistantMessage:
        """Return the buffered message with the collected tool calls attached."""
        self.message_buf.tool_calls = self.tool_call_collector.get_tool_calls()
        return self.message_buf

    def clear(self):
        """
        This class will be created for each stream, so we don't need to clear it.
        """
File without changes
@@ -0,0 +1,65 @@
1
+ import asyncio
2
+ import json
3
+ from functools import singledispatch
4
+ from typing import Any, Awaitable, Callable, cast
5
+ from types import FunctionType, MethodType, CoroutineType
6
+ from ..types.tool import ToolDef, ToolLike
7
+
8
+ async def _coroutine_wrapper(awaitable: Awaitable[Any]) -> CoroutineType:
9
+ return await awaitable
10
+
11
+ def _arguments_normalizer(arguments: str | dict) -> dict:
12
+ if isinstance(arguments, str):
13
+ parsed = json.loads(arguments)
14
+ return cast(dict, parsed)
15
+ elif isinstance(arguments, dict):
16
+ return arguments
17
+ else:
18
+ raise ValueError(f"Invalid arguments type: {type(arguments)}")
19
+
20
+ def _result_normalizer(result: Any) -> str:
21
+ if isinstance(result, str):
22
+ return result
23
+ return json.dumps(result, ensure_ascii=False)
24
+
25
@singledispatch
def execute_tool_sync(tool: ToolLike, arguments: str | dict) -> str:
    """Synchronously execute a tool-like object; dispatches on the tool's type."""
    # Fallback for unregistered tool types; concrete handlers for
    # FunctionType/MethodType and ToolDef are registered below.
    raise ValueError(f"Invalid tool type: {type(tool)}")
28
+
29
@execute_tool_sync.register(FunctionType)
@execute_tool_sync.register(MethodType)
def _(toolfn: Callable, arguments: str | dict) -> str:
    """Run a plain function/method tool synchronously and normalize its output."""
    kwargs = _arguments_normalizer(arguments)
    if asyncio.iscoroutinefunction(toolfn):
        # Async tools are driven to completion on a fresh event loop.
        result = asyncio.run(_coroutine_wrapper(toolfn(**kwargs)))
    else:
        result = toolfn(**kwargs)
    return _result_normalizer(result)
37
+
38
@execute_tool_sync.register(ToolDef)
def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Run a ToolDef's execute callable synchronously and normalize its output."""
    kwargs = _arguments_normalizer(arguments)
    if asyncio.iscoroutinefunction(tooldef.execute):
        # Async execute callables are driven to completion on a fresh loop.
        result = asyncio.run(_coroutine_wrapper(tooldef.execute(**kwargs)))
    else:
        result = tooldef.execute(**kwargs)
    return _result_normalizer(result)
45
+
46
@singledispatch
async def execute_tool(tool: ToolLike, arguments: str | dict) -> str:
    """Asynchronously execute a tool-like object; dispatches on the tool's type."""
    # Fallback for unregistered tool types; concrete handlers for
    # FunctionType/MethodType and ToolDef are registered below.
    raise ValueError(f"Invalid tool type: {type(tool)}")
49
+
50
@execute_tool.register(FunctionType)
@execute_tool.register(MethodType)
async def _(toolfn: Callable, arguments: str | dict) -> str:
    """Run a plain function/method tool, awaiting it when it is a coroutine function."""
    kwargs = _arguments_normalizer(arguments)
    if asyncio.iscoroutinefunction(toolfn):
        result = await toolfn(**kwargs)
    else:
        result = toolfn(**kwargs)
    return _result_normalizer(result)
58
+
59
@execute_tool.register(ToolDef)
async def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Run a ToolDef's execute callable, awaiting it when it is a coroutine function."""
    kwargs = _arguments_normalizer(arguments)
    if asyncio.iscoroutinefunction(tooldef.execute):
        result = await tooldef.execute(**kwargs)
    else:
        result = tooldef.execute(**kwargs)
    return _result_normalizer(result)
@@ -0,0 +1,283 @@
1
+ """
2
+ source: https://github.com/mozilla-ai/any-llm/blob/main/src/any_llm/tools.py
3
+ """
4
+
5
+ import dataclasses
6
+ import enum
7
+ import inspect
8
+ import types as _types
9
+ from collections.abc import Mapping, Sequence
10
+ from datetime import date, datetime, time
11
+ from typing import (Annotated as _Annotated, Literal as _Literal,
12
+ is_typeddict as _is_typeddict, Any, get_args,
13
+ get_origin, get_type_hints)
14
+ from pydantic import BaseModel as PydanticBaseModel
15
+ from ..types.tool import ToolFn, ToolDef, RawToolDef, ToolLike
16
+
17
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
    """Convert Python type annotation to a JSON Schema for a parameter.

    Supported mappings (subset tailored for LLM tool schemas):
    - Primitives: str/int/float/bool -> string/integer/number/boolean
    - bytes -> string with contentEncoding base64
    - datetime/date/time -> string with format date-time/date/time
    - list[T] / Sequence[T] / set[T] / frozenset[T] -> array with items=schema(T)
      - set/frozenset include uniqueItems=true
      - list without type args defaults items to string
    - dict[K,V] / Mapping[K,V] -> object with additionalProperties=schema(V)
      - dict without type args defaults additionalProperties to string
    - tuple[T1, T2, ...] -> array with prefixItems per element and min/maxItems
    - tuple[T, ...] -> array with items=schema(T)
    - Union[X, Y] and X | Y -> oneOf=[schema(X), schema(Y)] (without top-level type)
    - Optional[T] (Union[T, None]) -> schema(T) (nullability not encoded)
    - Literal[...]/Enum -> enum with appropriate type inference when uniform
    - TypedDict -> object with properties/required per annotations
    - dataclass/Pydantic BaseModel -> object with nested properties inferred from fields
    """
    origin = get_origin(python_type)
    args = get_args(python_type)

    # Unwrap Annotated[T, ...] to its underlying type T and re-derive
    # origin/args from it (the metadata arguments are discarded).
    if _Annotated is not None and origin is _Annotated and len(args) >= 1:
        python_type = args[0]
        origin = get_origin(python_type)
        args = get_args(python_type)

    # Any carries no constraint; default it to string (the fallback type
    # used throughout this function).
    if python_type is Any:
        return {"type": "string"}

    primitive_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
    if python_type in primitive_map:
        return {"type": primitive_map[python_type]}

    if python_type is bytes:
        return {"type": "string", "contentEncoding": "base64"}
    if python_type is datetime:
        return {"type": "string", "format": "date-time"}
    if python_type is date:
        return {"type": "string", "format": "date"}
    if python_type is time:
        return {"type": "string", "format": "time"}

    # Bare list/dict (no type arguments): element/value type defaults to string.
    if python_type is list:
        return {"type": "array", "items": {"type": "string"}}
    if python_type is dict:
        return {"type": "object", "additionalProperties": {"type": "string"}}

    if origin is _Literal:
        literal_values = list(args)
        schema_lit: dict[str, Any] = {"enum": literal_values}
        # Attach a uniform "type" only when every literal shares one; the
        # bool check comes first because bool is a subclass of int.
        if all(isinstance(v, bool) for v in literal_values):
            schema_lit["type"] = "boolean"
        elif all(isinstance(v, str) for v in literal_values):
            schema_lit["type"] = "string"
        elif all(isinstance(v, int) and not isinstance(v, bool) for v in literal_values):
            schema_lit["type"] = "integer"
        elif all(isinstance(v, int | float) and not isinstance(v, bool) for v in literal_values):
            schema_lit["type"] = "number"
        return schema_lit

    if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
        enum_values = [e.value for e in python_type]
        value_types = {type(v) for v in enum_values}
        schema: dict[str, Any] = {"enum": enum_values}
        # Infer a uniform "type" from the member value types when possible.
        if value_types == {str}:
            schema["type"] = "string"
        elif value_types == {int}:
            schema["type"] = "integer"
        elif value_types <= {int, float}:
            schema["type"] = "number"
        elif value_types == {bool}:
            schema["type"] = "boolean"
        return schema

    if _is_typeddict(python_type):
        annotations: dict[str, Any] = getattr(python_type, "__annotations__", {}) or {}
        # __required_keys__ distinguishes required keys from NotRequired/total=False ones.
        required_keys = set(getattr(python_type, "__required_keys__", set()))
        td_properties: dict[str, Any] = {}
        td_required: list[str] = []
        for field_name, field_type in annotations.items():
            td_properties[field_name] = _python_type_to_json_schema(field_type)
            if field_name in required_keys:
                td_required.append(field_name)
        schema_td: dict[str, Any] = {
            "type": "object",
            "properties": td_properties,
        }
        if td_required:
            schema_td["required"] = td_required
        return schema_td

    if inspect.isclass(python_type) and dataclasses.is_dataclass(python_type):
        type_hints = get_type_hints(python_type)
        dc_properties: dict[str, Any] = {}
        dc_required: list[str] = []
        for field in dataclasses.fields(python_type):
            field_type = type_hints.get(field.name, Any)
            dc_properties[field.name] = _python_type_to_json_schema(field_type)
            # A field is required only when it has neither a default value
            # nor a default_factory.
            if (
                field.default is dataclasses.MISSING
                and getattr(field, "default_factory", dataclasses.MISSING) is dataclasses.MISSING
            ):
                dc_required.append(field.name)
        schema_dc: dict[str, Any] = {"type": "object", "properties": dc_properties}
        if dc_required:
            schema_dc["required"] = dc_required
        return schema_dc

    if inspect.isclass(python_type) and issubclass(python_type, PydanticBaseModel):
        model_type_hints = get_type_hints(python_type)
        pd_properties: dict[str, Any] = {}
        pd_required: list[str] = []
        # model_fields is the pydantic v2 field registry; accessed defensively
        # via getattr so a missing attribute yields an empty schema.
        model_fields = getattr(python_type, "model_fields", {})
        for name, field_info in model_fields.items():
            pd_properties[name] = _python_type_to_json_schema(model_type_hints.get(name, Any))
            is_required = getattr(field_info, "is_required", None)
            if callable(is_required) and is_required():
                pd_required.append(name)
        schema_pd: dict[str, Any] = {"type": "object", "properties": pd_properties}
        if pd_required:
            schema_pd["required"] = pd_required
        return schema_pd

    if origin in (list, Sequence, set, frozenset):
        item_type = args[0] if args else Any
        item_schema = _python_type_to_json_schema(item_type)
        schema_arr: dict[str, Any] = {"type": "array", "items": item_schema or {"type": "string"}}
        if origin in (set, frozenset):
            schema_arr["uniqueItems"] = True
        return schema_arr
    if origin is tuple:
        if not args:
            return {"type": "array", "items": {"type": "string"}}
        # tuple[T, ...] — homogeneous, variable length.
        if len(args) == 2 and args[1] is Ellipsis:
            return {"type": "array", "items": _python_type_to_json_schema(args[0])}
        # Fixed-length heterogeneous tuple: one prefixItems schema per slot.
        prefix_items = [_python_type_to_json_schema(a) for a in args]
        return {
            "type": "array",
            "prefixItems": prefix_items,
            "minItems": len(prefix_items),
            "maxItems": len(prefix_items),
        }

    if origin in (dict, Mapping):
        # Key type is ignored (JSON object keys are strings); only the value
        # type contributes to the schema.
        value_type = args[1] if len(args) >= 2 else Any
        value_schema = _python_type_to_json_schema(value_type)
        return {"type": "object", "additionalProperties": value_schema or {"type": "string"}}

    # Handles both typing.Union[...] and the X | Y (types.UnionType) form.
    typing_union = getattr(__import__("typing"), "Union", None)
    if origin in (typing_union, _types.UnionType):
        # Optional[T] collapses to schema(T): None members are dropped.
        non_none_args = [a for a in args if a is not type(None)]
        if len(non_none_args) > 1:
            schemas = [_python_type_to_json_schema(arg) for arg in non_none_args]
            return {"oneOf": schemas}
        if non_none_args:
            return _python_type_to_json_schema(non_none_args[0])
        return {"type": "string"}

    # Unknown/unsupported annotation: fall back to string.
    return {"type": "string"}
178
+
179
def _parse_callable_properties(func: ToolFn) -> tuple[dict[str, dict[str, Any]], list[str]]:
    """Derive a JSON-schema property map and required-name list from a callable's signature."""
    signature = inspect.signature(func)
    hints = get_type_hints(func)
    skip_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    properties: dict[str, dict[str, Any]] = {}
    required: list[str] = []

    for name, parameter in signature.parameters.items():
        # *args/**kwargs cannot be described as named tool parameters.
        if parameter.kind in skip_kinds:
            continue

        # Unannotated parameters default to str.
        annotation = hints.get(name, str)
        schema = _python_type_to_json_schema(annotation)

        label = getattr(annotation, "__name__", str(annotation))
        properties[name] = {
            **schema,
            "description": f"Parameter {name} of type {label}",
        }

        if parameter.default == inspect.Parameter.empty:
            required.append(name)

    return properties, required
203
+
204
def generate_tool_definition_from_callable(func: ToolFn) -> dict[str, Any]:
    """Convert a Python callable to OpenAI tools format.

    The function's name becomes the tool name, its cleaned docstring the
    description, and its signature plus type hints the parameter schema.

    Args:
        func: A Python callable (function) to convert to a tool

    Returns:
        Dictionary in OpenAI tools format

    Raises:
        ValueError: If the function doesn't have a docstring

    Example:
        >>> def get_weather(location: str, unit: str = "celsius") -> str:
        ...     '''Get weather information for a location.'''
        ...     return f"Weather in {location} is sunny, 25°{unit[0].upper()}"
        >>>
        >>> tool = generate_tool_definition_from_callable(get_weather)
        >>> # Returns OpenAI tools format dict
    """
    docstring = func.__doc__
    if not docstring:
        msg = f"Function {func.__name__} must have a docstring"
        raise ValueError(msg)

    properties, required = _parse_callable_properties(func)
    parameters = {"type": "object", "properties": properties, "required": required}
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": inspect.cleandoc(docstring),
            "parameters": parameters,
        },
    }
+ }
238
+
239
def generate_tool_definition_from_tool_def(tool_def: ToolDef) -> dict[str, Any]:
    """Convert a ToolDef to OpenAI tools format.

    Args:
        tool_def: A ToolDef to convert to a tool

    Returns:
        Dictionary in OpenAI tools format

    Example:
        >>> tool_def = ToolDef(
        ...     name="get_weather",
        ...     description="Get weather information for a location.",
        ...     execute=SomeFunction(),
        ... )
        >>> tool = generate_tool_definition_from_tool_def(tool_def)
        >>> # Returns OpenAI tools format dict
    """
    # Only introspect the execute callable when no explicit parameter schema
    # was supplied. The original always ran _parse_callable_properties and
    # then discarded the result whenever tool_def.parameters was set — wasted
    # work, and a spurious failure for un-introspectable callables.
    parameters = tool_def.parameters
    if not parameters:
        properties, required = _parse_callable_properties(tool_def.execute)
        parameters = {"type": "object", "properties": properties, "required": required}
    return {
        "type": "function",
        "function": {
            "name": tool_def.name,
            "description": tool_def.description,
            "parameters": parameters,
        },
    }
267
+
268
def generate_tool_definition_from_raw_tool_def(raw_tool_def: RawToolDef) -> dict[str, Any]:
    """Wrap an already-formed function spec in the OpenAI tool envelope."""
    return {"type": "function", "function": raw_tool_def}
273
+
274
def prepare_tools(tools: Sequence[ToolLike]) -> list[dict]:
    """Normalize a heterogeneous sequence of tool-likes into OpenAI tool dicts."""
    def _to_definition(tool: ToolLike) -> dict:
        # Check order matters: callables first, then ToolDef, then raw specs.
        if callable(tool):
            return generate_tool_definition_from_callable(tool)
        if isinstance(tool, ToolDef):
            return generate_tool_definition_from_tool_def(tool)
        return generate_tool_definition_from_raw_tool_def(tool)

    return [_to_definition(tool) for tool in tools]
@@ -0,0 +1,18 @@
1
# Public toolset API: re-export the concrete implementations so callers can
# import them from this package directly.
from .toolset import Toolset
from .python_toolset import PythonToolset, python_tool
from .mcp_toolset import (
    McpToolset,
    LocalMcpToolset,
    RemoteMcpToolset,
)

__all__ = [
    "Toolset",

    "PythonToolset",
    "python_tool",

    "McpToolset",
    "LocalMcpToolset",
    "RemoteMcpToolset",
]