dais-sdk 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dais_sdk/__init__.py +320 -0
- dais_sdk/debug.py +4 -0
- dais_sdk/logger.py +22 -0
- dais_sdk/mcp_client/__init__.py +15 -0
- dais_sdk/mcp_client/base_mcp_client.py +38 -0
- dais_sdk/mcp_client/local_mcp_client.py +55 -0
- dais_sdk/mcp_client/oauth_server.py +100 -0
- dais_sdk/mcp_client/remote_mcp_client.py +157 -0
- dais_sdk/param_parser.py +55 -0
- dais_sdk/stream.py +82 -0
- dais_sdk/tool/__init__.py +0 -0
- dais_sdk/tool/execute.py +65 -0
- dais_sdk/tool/prepare.py +283 -0
- dais_sdk/tool/toolset/__init__.py +18 -0
- dais_sdk/tool/toolset/mcp_toolset.py +94 -0
- dais_sdk/tool/toolset/python_toolset.py +31 -0
- dais_sdk/tool/toolset/toolset.py +13 -0
- dais_sdk/tool/utils.py +11 -0
- dais_sdk/types/__init__.py +20 -0
- dais_sdk/types/exceptions.py +27 -0
- dais_sdk/types/message.py +211 -0
- dais_sdk/types/request_params.py +55 -0
- dais_sdk/types/tool.py +47 -0
- dais_sdk-0.6.0.dist-info/METADATA +100 -0
- dais_sdk-0.6.0.dist-info/RECORD +27 -0
- dais_sdk-0.6.0.dist-info/WHEEL +4 -0
- dais_sdk-0.6.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import httpx
|
|
2
|
+
import webbrowser
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from contextlib import AsyncExitStack
|
|
5
|
+
from typing import Any, NamedTuple
|
|
6
|
+
from mcp import ClientSession
|
|
7
|
+
from mcp.client.auth import OAuthClientProvider
|
|
8
|
+
from mcp.client.streamable_http import streamable_http_client
|
|
9
|
+
from mcp.shared.auth import OAuthClientMetadata
|
|
10
|
+
from pydantic import AnyUrl, BaseModel, Field, ConfigDict, SkipValidation
|
|
11
|
+
from .oauth_server import LocalOAuthServer, OAuthCode, TokenStorage, InMemoryTokenStorage
|
|
12
|
+
from .base_mcp_client import McpClient, Tool, ToolResult, McpSessionNotEstablishedError
|
|
13
|
+
from ..logger import logger
|
|
14
|
+
|
|
15
|
+
@dataclass
class OAuthParams:
    """Optional OAuth settings for a remote MCP server connection.

    NOTE(review): this is a stdlib ``dataclasses.dataclass`` whose last field
    uses pydantic's ``Field(...)`` as its default value. Presumably pydantic
    interprets it when the value is validated as part of ``RemoteServerParams``,
    but a direct ``OAuthParams()`` call may keep the raw ``Field`` object as the
    attribute instead of calling ``default_factory`` -- TODO confirm.
    """
    # Scopes to request; RemoteMcpClient joins them with single spaces.
    # None means no explicit scope is sent.
    oauth_scopes: list[str] | None = None
    # Passed as ``timeout`` to LocalOAuthServer (presumably seconds -- verify
    # against LocalOAuthServer).
    oauth_timeout: int = 120
    # Token storage handed to OAuthClientProvider; skipped by validation and
    # excluded from serialization through the pydantic field options below.
    oauth_token_storage: SkipValidation[TokenStorage] = Field(
        default_factory=InMemoryTokenStorage,
        exclude=True,
    )
|
|
23
|
+
|
|
24
|
+
class RemoteServerParams(BaseModel):
    """Connection settings for a remote MCP server.

    Consumed by ``RemoteMcpClient``: ``url`` is the streamable-HTTP endpoint,
    and the optional fields select how requests are authenticated.
    """
    # Needed so pydantic accepts field types it has no built-in schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Endpoint passed to streamable_http_client and used as the OAuth server URL.
    url: str
    # When set, sent as an "Authorization: Bearer <token>" header.
    bearer_token: str | None = None
    # When set, RemoteMcpClient builds an OAuth-enabled HTTP client.
    oauth_params: OAuthParams | None = None
    # Extra static headers; a bearer token overrides any "Authorization" entry.
    http_headers: dict[str, str] | None = None
|
|
31
|
+
|
|
32
|
+
# --- --- --- --- --- ---
|
|
33
|
+
|
|
34
|
+
class OAuthContext(NamedTuple):
    """Pair of objects RemoteMcpClient keeps for an OAuth-protected server."""
    # HTTP client whose ``auth`` is the OAuthClientProvider.
    client: httpx.AsyncClient
    # Local server that supplies the redirect URI and waits for the auth code.
    server: LocalOAuthServer
|
|
37
|
+
|
|
38
|
+
class RemoteMcpClient(McpClient):
    """MCP client for a remote server reached over streamable HTTP.

    Authentication is chosen from ``params``: when ``oauth_params`` is set, an
    OAuth-enabled ``httpx.AsyncClient`` and a local callback server are created
    at construction time; otherwise a plain client is created in ``connect()``
    using the optional static headers / bearer token.
    """

    def __init__(self,
                 name: str,
                 params: RemoteServerParams,
                 storage: TokenStorage | None = None):
        """Store the configuration and prepare the OAuth context if requested.

        Args:
            name: Client name; also sent as the OAuth ``client_name``.
            params: Connection settings for the remote server.
            storage: Optional token storage that replaces
                ``params.oauth_params.oauth_token_storage`` when OAuth is enabled.
        """
        self._name = name
        self._params = params
        self._session: ClientSession | None = None
        self._exit_stack: AsyncExitStack | None = None
        # Install the caller-supplied storage BEFORE building the OAuth
        # provider: _init_oauth() reads oauth_token_storage once and hands that
        # object to OAuthClientProvider, so a later assignment would be ignored.
        if self._params.oauth_params is not None and storage is not None:
            self._params.oauth_params.oauth_token_storage = storage
        self._oauth_context: OAuthContext | None = self._init_oauth()

    @property
    def name(self) -> str:
        """Name given to this client at construction."""
        return self._name

    def _init_http_headers(self) -> dict[str, str] | None:
        """Merge static headers and the bearer token into one header dict.

        Returns:
            ``None`` when neither headers nor a bearer token are configured,
            otherwise a new dict. The bearer token wins over any
            ``Authorization`` entry in ``http_headers``.
        """
        if self._params.http_headers is None and self._params.bearer_token is None:
            return None
        headers = {}
        if self._params.http_headers is not None:
            headers.update(self._params.http_headers)
        if self._params.bearer_token is not None:
            headers["Authorization"] = f"Bearer {self._params.bearer_token}"
        return headers

    def _init_oauth(self) -> OAuthContext | None:
        """Build the OAuth HTTP client and local callback server.

        Returns:
            ``None`` when ``oauth_params`` is not configured, otherwise an
            ``OAuthContext`` holding the client and the (not yet started) server.
        """
        if self._params.oauth_params is None:
            return None

        oauth_params = self._params.oauth_params

        server = LocalOAuthServer(timeout=oauth_params.oauth_timeout)
        scopes = None
        if oauth_params.oauth_scopes is not None:
            # The client metadata takes scopes as one space-separated string.
            scopes = " ".join(oauth_params.oauth_scopes)

        client_provider = OAuthClientProvider(
            server_url=self._params.url,
            client_metadata=OAuthClientMetadata(
                client_name=self._name,
                redirect_uris=[AnyUrl(server.callback_url)],
                grant_types=["authorization_code", "refresh_token"],
                response_types=["code"],
                scope=scopes,
                token_endpoint_auth_method="none",
            ),
            storage=oauth_params.oauth_token_storage,
            redirect_handler=self._handle_redirect,
            callback_handler=self._handle_oauth_callback,
        )
        client = httpx.AsyncClient(auth=client_provider,
                                   headers=self._init_http_headers(),
                                   follow_redirects=True)
        return OAuthContext(client, server)

    async def _handle_redirect(self, url: str) -> None:
        """Log the authorization URL and try to open it in a browser."""
        logger.info("[OAuth] Authentication required, opening browser...")
        logger.info(f"[OAuth] If browser does not open automatically, copy and open the following link: \n{url}\n")
        try:
            webbrowser.open(url)
        except Exception as e:
            # Best effort: the URL was already logged for manual use.
            logger.error(f"[OAuth] Not able to open browser: {e}")

    async def _handle_oauth_callback(self) -> OAuthCode:
        """Wait for the local callback server to receive the authorization code.

        Raises:
            ValueError: If the client was built without OAuth parameters.
        """
        if self._oauth_context is None:
            raise ValueError("OAuth context not initialized")
        return await self._oauth_context.server.wait_for_code()

    async def connect(self):
        """Open the HTTP transport and initialize the MCP session.

        Any failure tears down whatever was opened (via ``disconnect()``) and
        re-raises the original exception.

        NOTE(review): ``disconnect()`` closes the OAuth HTTP client, and nothing
        here recreates it, so reconnecting an OAuth client after a disconnect
        presumably fails -- confirm whether reconnect is a supported use.
        """
        self._exit_stack = AsyncExitStack()
        try:
            if self._oauth_context:
                http_client = self._oauth_context.client
                # Started inside the try so a failed start is cleaned up too.
                await self._oauth_context.server.start()
            else:
                http_client = await self._exit_stack.enter_async_context(
                    httpx.AsyncClient(headers=self._init_http_headers(), follow_redirects=True))

            read_stream, write_stream, _ = await self._exit_stack.enter_async_context(
                streamable_http_client(self._params.url, http_client=http_client)
            )
            self._session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await self._session.initialize()
        except Exception:
            await self.disconnect()
            raise

    async def list_tools(self) -> list[Tool]:
        """Return the tools advertised by the server.

        Raises:
            McpSessionNotEstablishedError: If ``connect()`` has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        result = await self._session.list_tools()
        return result.tools

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        """Invoke ``tool_name`` on the server and wrap the response.

        Raises:
            McpSessionNotEstablishedError: If ``connect()`` has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        response = await self._session.call_tool(tool_name, arguments)
        return ToolResult(response.isError, response.content)

    async def disconnect(self):
        """Close the session/transport and release OAuth resources.

        Errors from closing the exit stack propagate; errors from the OAuth
        client/server teardown are logged and otherwise ignored (best effort).
        """
        try:
            if self._exit_stack:
                await self._exit_stack.aclose()
        finally:
            self._session = None
            self._exit_stack = None

            if self._oauth_context:
                try:
                    await self._oauth_context.client.aclose()
                except Exception as e:
                    logger.warning(f"[OAuth] Failed to close HTTP client: {e}")
                try:
                    await self._oauth_context.server.stop()
                except Exception as e:
                    logger.warning(f"[OAuth] Failed to stop callback server: {e}")
|
dais_sdk/param_parser.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any, TYPE_CHECKING
|
|
3
|
+
from litellm.types.utils import LlmProviders
|
|
4
|
+
from .tool.prepare import prepare_tools
|
|
5
|
+
from .types.message import ToolMessage
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from .types.request_params import LlmRequestParams
|
|
9
|
+
|
|
10
|
+
ParsedParams = dict[str, Any]
|
|
11
|
+
|
|
12
|
+
class ParamParser:
    """Translate ``LlmRequestParams`` into keyword arguments for a litellm call."""

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        """Remember the provider and endpoint credentials used for every request."""
        self._provider, self._base_url, self._api_key = provider, base_url, api_key

    def _parse(self, params: LlmRequestParams) -> ParsedParams:
        """Build the arguments shared by streaming and non-streaming calls.

        Tool messages that have neither a result nor an error yet are left out
        of the message list. Entries in ``params.extra_args`` are applied last
        and therefore override the keys built here.
        """
        tool_likes = params.extract_tools()
        # Falsy tool collections (None / empty) are passed through unchanged.
        prepared_tools = tool_likes and prepare_tools(tool_likes)

        def _is_unresolved_tool_message(candidate) -> bool:
            return (type(candidate) is ToolMessage
                    and candidate.result is None
                    and candidate.error is None)

        litellm_messages = [
            candidate.to_litellm_message()
            for candidate in params.messages
            if not _is_unresolved_tool_message(candidate)
        ]

        parsed: ParsedParams = {
            "model": f"{self._provider.value}/{params.model}",
            "messages": litellm_messages,
            "base_url": self._base_url,
            "api_key": self._api_key,
            "tools": prepared_tools,
            "tool_choice": params.tool_choice,
            "timeout": params.timeout_sec,
            "extra_headers": params.headers,
        }
        parsed.update(params.extra_args or {})
        return parsed

    def parse_nonstream(self, params: LlmRequestParams) -> ParsedParams:
        """Return call arguments with streaming switched off."""
        return {**self._parse(params), "stream": False}

    def parse_stream(self, params: LlmRequestParams) -> ParsedParams:
        """Return call arguments with streaming on and usage reporting requested."""
        return {
            **self._parse(params),
            "stream": True,
            "stream_options": {"include_usage": True},
        }
|
dais_sdk/stream.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from litellm import ChatCompletionAssistantToolCall
|
|
3
|
+
from litellm.types.utils import (ChatCompletionDeltaToolCall,
|
|
4
|
+
ModelResponseStream as LiteLlmModelResponseStream)
|
|
5
|
+
from .types.message import AssistantMessage
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass
class ToolCallTemp:
    """Mutable accumulator for one tool call while its stream chunks arrive."""
    # Tool-call id; stays None until a chunk carries one.
    id: str | None = None
    # Function name, built up by concatenating name fragments.
    name: str = ""
    # Raw argument text, built up by concatenating argument fragments.
    arguments: str = ""
|
|
12
|
+
|
|
13
|
+
class ToolCallCollector:
    """Reassemble streamed tool-call fragments into complete tool calls.

    Fragments are grouped by the chunk's ``index``; names and arguments are
    concatenated in arrival order, and the most recent non-empty id is kept.
    """

    def __init__(self):
        # Maps a tool call's stream index to its partially assembled state.
        self.tool_call_map: dict[int, ToolCallTemp] = {}

    def collect(self, tool_call_chunk: ChatCompletionDeltaToolCall):
        """Fold one streamed tool-call fragment into the matching accumulator."""
        if tool_call_chunk.index not in self.tool_call_map:
            self.tool_call_map[tool_call_chunk.index] = ToolCallTemp()
        temp_tool_call = self.tool_call_map[tool_call_chunk.index]

        # Bind each fragment once instead of re-reading the attribute behind an
        # ``assert`` (asserts are stripped under ``python -O``).
        call_id = tool_call_chunk.get("id")
        if call_id:
            temp_tool_call.id = call_id

        name_part = tool_call_chunk.function.get("name")
        if name_part:
            temp_tool_call.name += name_part

        arguments_part = tool_call_chunk.function.get("arguments")
        if arguments_part:
            temp_tool_call.arguments += arguments_part

    def get_tool_calls(self) -> list[ChatCompletionAssistantToolCall]:
        """Return the assembled tool calls ordered by their stream index."""
        ordered = sorted(self.tool_call_map.items(), key=lambda item: item[0])
        return [{
            "id": tool_call.id,
            "function": {
                "name": tool_call.name,
                "arguments": tool_call.arguments,
            },
            "type": "function"
        } for _, tool_call in ordered]
|
|
40
|
+
|
|
41
|
+
class AssistantMessageCollector:
    """Accumulate streamed response chunks into a single ``AssistantMessage``."""

    def __init__(self):
        self.message_buf = AssistantMessage(content=None)
        self.tool_call_collector = ToolCallCollector()

    def collect(self, chunk: LiteLlmModelResponseStream):
        """Merge the first choice's delta of ``chunk`` into the buffered message.

        Text fields (``content``, ``reasoning_content``) and ``images`` are
        appended; ``audio`` is replaced by the latest value; tool-call fragments
        are forwarded to the tool-call collector. Chunks without any choice are
        ignored.
        """
        if not chunk.choices:
            # Nothing to merge (presumably a usage-only chunk -- verify against
            # the provider's streaming format).
            return
        delta = chunk.choices[0].delta

        content = delta.get("content")
        if content:
            if self.message_buf.content is None:
                self.message_buf.content = ""
            self.message_buf.content += content

        reasoning_content = delta.get("reasoning_content")
        if reasoning_content:
            if self.message_buf.reasoning_content is None:
                self.message_buf.reasoning_content = ""
            self.message_buf.reasoning_content += reasoning_content

        tool_calls = delta.get("tool_calls")
        if tool_calls:
            for tool_call_chunk in tool_calls:
                self.tool_call_collector.collect(tool_call_chunk)

        images = delta.get("images")
        if images:
            if self.message_buf.images is None:
                self.message_buf.images = []
            # Append rather than overwrite, so images from earlier chunks are
            # kept (the list is initialized above for exactly this purpose).
            self.message_buf.images.extend(images)

        audio = delta.get("audio")
        if audio:
            self.message_buf.audio = audio

    def get_message(self) -> AssistantMessage:
        """Attach the assembled tool calls and return the buffered message."""
        self.message_buf.tool_calls = self.tool_call_collector.get_tool_calls()
        return self.message_buf

    def clear(self):
        """
        This class will be created for each stream, so we don't need to clear it.
        """
|
|
File without changes
|
dais_sdk/tool/execute.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from functools import singledispatch
|
|
4
|
+
from typing import Any, Awaitable, Callable, cast
|
|
5
|
+
from types import FunctionType, MethodType, CoroutineType
|
|
6
|
+
from ..types.tool import ToolDef, ToolLike
|
|
7
|
+
|
|
8
|
+
async def _coroutine_wrapper(awaitable: Awaitable[Any]) -> Any:
    """Await ``awaitable`` inside a coroutine and return its result.

    Used by the synchronous executors so that the value handed to
    ``asyncio.run`` is always a coroutine object.
    """
    return await awaitable
|
|
10
|
+
|
|
11
|
+
def _arguments_normalizer(arguments: str | dict) -> dict:
|
|
12
|
+
if isinstance(arguments, str):
|
|
13
|
+
parsed = json.loads(arguments)
|
|
14
|
+
return cast(dict, parsed)
|
|
15
|
+
elif isinstance(arguments, dict):
|
|
16
|
+
return arguments
|
|
17
|
+
else:
|
|
18
|
+
raise ValueError(f"Invalid arguments type: {type(arguments)}")
|
|
19
|
+
|
|
20
|
+
def _result_normalizer(result: Any) -> str:
    """Render a tool's return value as text.

    Strings pass through untouched; every other value is JSON-encoded with
    non-ASCII characters kept as they are.
    """
    return result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
|
|
24
|
+
|
|
25
|
+
@singledispatch
def execute_tool_sync(tool: ToolLike, arguments: str | dict) -> str:
    """Run ``tool`` synchronously and return its result as text.

    Dispatches on the type of ``tool``: plain functions, bound methods and
    ``ToolDef`` instances are supported. Awaitable results are driven to
    completion with ``asyncio.run``, so this must not be called from inside a
    running event loop.

    Args:
        tool: The tool to execute.
        arguments: Keyword arguments as a dict or as a JSON object string.

    Raises:
        ValueError: If ``tool`` is of an unsupported type.
    """
    raise ValueError(f"Invalid tool type: {type(tool)}")


@execute_tool_sync.register(FunctionType)
@execute_tool_sync.register(MethodType)
def _(toolfn: Callable, arguments: str | dict) -> str:
    """Execute a plain function or bound method."""
    import inspect

    arguments = _arguments_normalizer(arguments)
    result = toolfn(**arguments)
    # Inspect the returned value rather than the callable: this also covers
    # sync wrappers that return a coroutine.
    if inspect.isawaitable(result):
        result = asyncio.run(_coroutine_wrapper(result))
    return _result_normalizer(result)


@execute_tool_sync.register(ToolDef)
def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Execute the ``execute`` callable of a ``ToolDef``."""
    import inspect

    arguments = _arguments_normalizer(arguments)
    result = tooldef.execute(**arguments)
    if inspect.isawaitable(result):
        result = asyncio.run(_coroutine_wrapper(result))
    return _result_normalizer(result)
|
|
45
|
+
|
|
46
|
+
@singledispatch
async def execute_tool(tool: ToolLike, arguments: str | dict) -> str:
    """Run ``tool`` from async code and return its result as text.

    Dispatches on the type of ``tool``: plain functions, bound methods and
    ``ToolDef`` instances are supported. Awaitable results are awaited; other
    results are used directly.

    Args:
        tool: The tool to execute.
        arguments: Keyword arguments as a dict or as a JSON object string.

    Raises:
        ValueError: If ``tool`` is of an unsupported type.
    """
    raise ValueError(f"Invalid tool type: {type(tool)}")


@execute_tool.register(FunctionType)
@execute_tool.register(MethodType)
async def _(toolfn: Callable, arguments: str | dict) -> str:
    """Execute a plain function or bound method."""
    import inspect

    arguments = _arguments_normalizer(arguments)
    result = toolfn(**arguments)
    # Inspect the returned value rather than the callable: this also covers
    # sync wrappers that return a coroutine.
    if inspect.isawaitable(result):
        result = await result
    return _result_normalizer(result)


@execute_tool.register(ToolDef)
async def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Execute the ``execute`` callable of a ``ToolDef``."""
    import inspect

    arguments = _arguments_normalizer(arguments)
    result = tooldef.execute(**arguments)
    if inspect.isawaitable(result):
        result = await result
    return _result_normalizer(result)
|
dais_sdk/tool/prepare.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""
|
|
2
|
+
source: https://github.com/mozilla-ai/any-llm/blob/main/src/any_llm/tools.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
import enum
|
|
7
|
+
import inspect
|
|
8
|
+
import types as _types
|
|
9
|
+
from collections.abc import Mapping, Sequence
|
|
10
|
+
from datetime import date, datetime, time
|
|
11
|
+
from typing import (Annotated as _Annotated, Literal as _Literal,
|
|
12
|
+
is_typeddict as _is_typeddict, Any, get_args,
|
|
13
|
+
get_origin, get_type_hints)
|
|
14
|
+
from pydantic import BaseModel as PydanticBaseModel
|
|
15
|
+
from ..types.tool import ToolFn, ToolDef, RawToolDef, ToolLike
|
|
16
|
+
|
|
17
|
+
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
|
|
18
|
+
"""Convert Python type annotation to a JSON Schema for a parameter.
|
|
19
|
+
|
|
20
|
+
Supported mappings (subset tailored for LLM tool schemas):
|
|
21
|
+
- Primitives: str/int/float/bool -> string/integer/number/boolean
|
|
22
|
+
- bytes -> string with contentEncoding base64
|
|
23
|
+
- datetime/date/time -> string with format date-time/date/time
|
|
24
|
+
- list[T] / Sequence[T] / set[T] / frozenset[T] -> array with items=schema(T)
|
|
25
|
+
- set/frozenset include uniqueItems=true
|
|
26
|
+
- list without type args defaults items to string
|
|
27
|
+
- dict[K,V] / Mapping[K,V] -> object with additionalProperties=schema(V)
|
|
28
|
+
- dict without type args defaults additionalProperties to string
|
|
29
|
+
- tuple[T1, T2, ...] -> array with prefixItems per element and min/maxItems
|
|
30
|
+
- tuple[T, ...] -> array with items=schema(T)
|
|
31
|
+
- Union[X, Y] and X | Y -> oneOf=[schema(X), schema(Y)] (without top-level type)
|
|
32
|
+
- Optional[T] (Union[T, None]) -> schema(T) (nullability not encoded)
|
|
33
|
+
- Literal[...]/Enum -> enum with appropriate type inference when uniform
|
|
34
|
+
- TypedDict -> object with properties/required per annotations
|
|
35
|
+
- dataclass/Pydantic BaseModel -> object with nested properties inferred from fields
|
|
36
|
+
"""
|
|
37
|
+
origin = get_origin(python_type)
|
|
38
|
+
args = get_args(python_type)
|
|
39
|
+
|
|
40
|
+
if _Annotated is not None and origin is _Annotated and len(args) >= 1:
|
|
41
|
+
python_type = args[0]
|
|
42
|
+
origin = get_origin(python_type)
|
|
43
|
+
args = get_args(python_type)
|
|
44
|
+
|
|
45
|
+
if python_type is Any:
|
|
46
|
+
return {"type": "string"}
|
|
47
|
+
|
|
48
|
+
primitive_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
|
|
49
|
+
if python_type in primitive_map:
|
|
50
|
+
return {"type": primitive_map[python_type]}
|
|
51
|
+
|
|
52
|
+
if python_type is bytes:
|
|
53
|
+
return {"type": "string", "contentEncoding": "base64"}
|
|
54
|
+
if python_type is datetime:
|
|
55
|
+
return {"type": "string", "format": "date-time"}
|
|
56
|
+
if python_type is date:
|
|
57
|
+
return {"type": "string", "format": "date"}
|
|
58
|
+
if python_type is time:
|
|
59
|
+
return {"type": "string", "format": "time"}
|
|
60
|
+
|
|
61
|
+
if python_type is list:
|
|
62
|
+
return {"type": "array", "items": {"type": "string"}}
|
|
63
|
+
if python_type is dict:
|
|
64
|
+
return {"type": "object", "additionalProperties": {"type": "string"}}
|
|
65
|
+
|
|
66
|
+
if origin is _Literal:
|
|
67
|
+
literal_values = list(args)
|
|
68
|
+
schema_lit: dict[str, Any] = {"enum": literal_values}
|
|
69
|
+
if all(isinstance(v, bool) for v in literal_values):
|
|
70
|
+
schema_lit["type"] = "boolean"
|
|
71
|
+
elif all(isinstance(v, str) for v in literal_values):
|
|
72
|
+
schema_lit["type"] = "string"
|
|
73
|
+
elif all(isinstance(v, int) and not isinstance(v, bool) for v in literal_values):
|
|
74
|
+
schema_lit["type"] = "integer"
|
|
75
|
+
elif all(isinstance(v, int | float) and not isinstance(v, bool) for v in literal_values):
|
|
76
|
+
schema_lit["type"] = "number"
|
|
77
|
+
return schema_lit
|
|
78
|
+
|
|
79
|
+
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
|
|
80
|
+
enum_values = [e.value for e in python_type]
|
|
81
|
+
value_types = {type(v) for v in enum_values}
|
|
82
|
+
schema: dict[str, Any] = {"enum": enum_values}
|
|
83
|
+
if value_types == {str}:
|
|
84
|
+
schema["type"] = "string"
|
|
85
|
+
elif value_types == {int}:
|
|
86
|
+
schema["type"] = "integer"
|
|
87
|
+
elif value_types <= {int, float}:
|
|
88
|
+
schema["type"] = "number"
|
|
89
|
+
elif value_types == {bool}:
|
|
90
|
+
schema["type"] = "boolean"
|
|
91
|
+
return schema
|
|
92
|
+
|
|
93
|
+
if _is_typeddict(python_type):
|
|
94
|
+
annotations: dict[str, Any] = getattr(python_type, "__annotations__", {}) or {}
|
|
95
|
+
required_keys = set(getattr(python_type, "__required_keys__", set()))
|
|
96
|
+
td_properties: dict[str, Any] = {}
|
|
97
|
+
td_required: list[str] = []
|
|
98
|
+
for field_name, field_type in annotations.items():
|
|
99
|
+
td_properties[field_name] = _python_type_to_json_schema(field_type)
|
|
100
|
+
if field_name in required_keys:
|
|
101
|
+
td_required.append(field_name)
|
|
102
|
+
schema_td: dict[str, Any] = {
|
|
103
|
+
"type": "object",
|
|
104
|
+
"properties": td_properties,
|
|
105
|
+
}
|
|
106
|
+
if td_required:
|
|
107
|
+
schema_td["required"] = td_required
|
|
108
|
+
return schema_td
|
|
109
|
+
|
|
110
|
+
if inspect.isclass(python_type) and dataclasses.is_dataclass(python_type):
|
|
111
|
+
type_hints = get_type_hints(python_type)
|
|
112
|
+
dc_properties: dict[str, Any] = {}
|
|
113
|
+
dc_required: list[str] = []
|
|
114
|
+
for field in dataclasses.fields(python_type):
|
|
115
|
+
field_type = type_hints.get(field.name, Any)
|
|
116
|
+
dc_properties[field.name] = _python_type_to_json_schema(field_type)
|
|
117
|
+
if (
|
|
118
|
+
field.default is dataclasses.MISSING
|
|
119
|
+
and getattr(field, "default_factory", dataclasses.MISSING) is dataclasses.MISSING
|
|
120
|
+
):
|
|
121
|
+
dc_required.append(field.name)
|
|
122
|
+
schema_dc: dict[str, Any] = {"type": "object", "properties": dc_properties}
|
|
123
|
+
if dc_required:
|
|
124
|
+
schema_dc["required"] = dc_required
|
|
125
|
+
return schema_dc
|
|
126
|
+
|
|
127
|
+
if inspect.isclass(python_type) and issubclass(python_type, PydanticBaseModel):
|
|
128
|
+
model_type_hints = get_type_hints(python_type)
|
|
129
|
+
pd_properties: dict[str, Any] = {}
|
|
130
|
+
pd_required: list[str] = []
|
|
131
|
+
model_fields = getattr(python_type, "model_fields", {})
|
|
132
|
+
for name, field_info in model_fields.items():
|
|
133
|
+
pd_properties[name] = _python_type_to_json_schema(model_type_hints.get(name, Any))
|
|
134
|
+
is_required = getattr(field_info, "is_required", None)
|
|
135
|
+
if callable(is_required) and is_required():
|
|
136
|
+
pd_required.append(name)
|
|
137
|
+
schema_pd: dict[str, Any] = {"type": "object", "properties": pd_properties}
|
|
138
|
+
if pd_required:
|
|
139
|
+
schema_pd["required"] = pd_required
|
|
140
|
+
return schema_pd
|
|
141
|
+
|
|
142
|
+
if origin in (list, Sequence, set, frozenset):
|
|
143
|
+
item_type = args[0] if args else Any
|
|
144
|
+
item_schema = _python_type_to_json_schema(item_type)
|
|
145
|
+
schema_arr: dict[str, Any] = {"type": "array", "items": item_schema or {"type": "string"}}
|
|
146
|
+
if origin in (set, frozenset):
|
|
147
|
+
schema_arr["uniqueItems"] = True
|
|
148
|
+
return schema_arr
|
|
149
|
+
if origin is tuple:
|
|
150
|
+
if not args:
|
|
151
|
+
return {"type": "array", "items": {"type": "string"}}
|
|
152
|
+
if len(args) == 2 and args[1] is Ellipsis:
|
|
153
|
+
return {"type": "array", "items": _python_type_to_json_schema(args[0])}
|
|
154
|
+
prefix_items = [_python_type_to_json_schema(a) for a in args]
|
|
155
|
+
return {
|
|
156
|
+
"type": "array",
|
|
157
|
+
"prefixItems": prefix_items,
|
|
158
|
+
"minItems": len(prefix_items),
|
|
159
|
+
"maxItems": len(prefix_items),
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
if origin in (dict, Mapping):
|
|
163
|
+
value_type = args[1] if len(args) >= 2 else Any
|
|
164
|
+
value_schema = _python_type_to_json_schema(value_type)
|
|
165
|
+
return {"type": "object", "additionalProperties": value_schema or {"type": "string"}}
|
|
166
|
+
|
|
167
|
+
typing_union = getattr(__import__("typing"), "Union", None)
|
|
168
|
+
if origin in (typing_union, _types.UnionType):
|
|
169
|
+
non_none_args = [a for a in args if a is not type(None)]
|
|
170
|
+
if len(non_none_args) > 1:
|
|
171
|
+
schemas = [_python_type_to_json_schema(arg) for arg in non_none_args]
|
|
172
|
+
return {"oneOf": schemas}
|
|
173
|
+
if non_none_args:
|
|
174
|
+
return _python_type_to_json_schema(non_none_args[0])
|
|
175
|
+
return {"type": "string"}
|
|
176
|
+
|
|
177
|
+
return {"type": "string"}
|
|
178
|
+
|
|
179
|
+
def _parse_callable_properties(func: ToolFn) -> tuple[dict[str, dict[str, Any]], list[str]]:
    """Derive JSON Schema properties and the required list from a signature.

    ``*args`` / ``**kwargs`` parameters are skipped. A parameter without a
    type hint is described as ``str``; a parameter without a default value is
    reported as required.

    Args:
        func: The callable whose signature is inspected.

    Returns:
        ``(properties, required)`` where ``properties`` maps each parameter
        name to its schema (plus a generated description) and ``required``
        lists the names of parameters that have no default.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)

    properties: dict[str, dict[str, Any]] = {}
    required: list[str] = []

    for param_name, param in sig.parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        annotated_type = type_hints.get(param_name, str)
        param_schema = _python_type_to_json_schema(annotated_type)

        type_name = getattr(annotated_type, "__name__", str(annotated_type))
        properties[param_name] = {
            **param_schema,
            "description": f"Parameter {param_name} of type {type_name}",
        }

        # Identity test against the sentinel: ``==`` would call the default
        # value's own ``__eq__``, which may not return a plain bool.
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    return properties, required
|
|
203
|
+
|
|
204
|
+
def generate_tool_definition_from_callable(func: ToolFn) -> dict[str, Any]:
    """Describe a Python callable in the OpenAI tools format.

    The tool name is the callable's ``__name__``, the description is its
    cleaned-up docstring, and the parameter schema is derived from its
    signature and type hints.

    Args:
        func: The callable to expose as a tool.

    Returns:
        A ``{"type": "function", "function": {...}}`` dictionary.

    Raises:
        ValueError: If ``func`` has no docstring.

    Example:
        >>> def get_weather(location: str, unit: str = "celsius") -> str:
        ...     '''Get weather information for a location.'''
        ...     return f"Weather in {location}: sunny ({unit})"
        >>>
        >>> tool = generate_tool_definition_from_callable(get_weather)
        >>> tool["function"]["name"]
        'get_weather'
    """
    docstring = func.__doc__
    if not docstring:
        raise ValueError(f"Function {func.__name__} must have a docstring")

    properties, required = _parse_callable_properties(func)
    parameters_schema = {"type": "object", "properties": properties, "required": required}
    function_spec = {
        "name": func.__name__,
        "description": inspect.cleandoc(docstring),
        "parameters": parameters_schema,
    }
    return {"type": "function", "function": function_spec}
|
|
238
|
+
|
|
239
|
+
def generate_tool_definition_from_tool_def(tool_def: ToolDef) -> dict[str, Any]:
    """Convert a ToolDef to OpenAI tools format.

    When ``tool_def.parameters`` is set (truthy) it is used verbatim as the
    parameter schema; only otherwise is the schema derived from the signature
    of ``tool_def.execute``.

    Args:
        tool_def: A ToolDef to convert to a tool

    Returns:
        Dictionary in OpenAI tools format

    Example:
        >>> tool_def = ToolDef(
        ...     name="get_weather",
        ...     description="Get weather information for a location.",
        ...     execute=SomeFunction(),
        ... )
        >>> tool = generate_tool_definition_from_tool_def(tool_def)
        >>> # Returns OpenAI tools format dict
    """
    parameters = tool_def.parameters
    if not parameters:
        # Introspect the callable only when its result is needed: with an
        # explicit schema the signature is irrelevant, and introspection could
        # fail for reasons that should not block the tool.
        properties, required = _parse_callable_properties(tool_def.execute)
        parameters = {"type": "object", "properties": properties, "required": required}

    return {
        "type": "function",
        "function": {
            "name": tool_def.name,
            "description": tool_def.description,
            "parameters": parameters,
        },
    }
|
|
267
|
+
|
|
268
|
+
def generate_tool_definition_from_raw_tool_def(raw_tool_def: RawToolDef) -> dict[str, Any]:
    """Wrap an already-built function description in the OpenAI tools envelope.

    ``raw_tool_def`` is embedded as is (not copied) under the ``"function"`` key.
    """
    envelope: dict[str, Any] = {"type": "function"}
    envelope["function"] = raw_tool_def
    return envelope
|
|
273
|
+
|
|
274
|
+
def prepare_tools(tools: Sequence[ToolLike]) -> list[dict]:
    """Convert every tool-like object into an OpenAI tools-format dict.

    Each item is classified in this order: anything callable is described
    from its signature and docstring, a ``ToolDef`` is converted from its
    fields, and everything else is treated as a raw function description.
    The output keeps the input order.
    """
    def _definition_for(tool_like) -> dict:
        if callable(tool_like):
            return generate_tool_definition_from_callable(tool_like)
        if isinstance(tool_like, ToolDef):
            return generate_tool_definition_from_tool_def(tool_like)
        return generate_tool_definition_from_raw_tool_def(tool_like)

    return [_definition_for(tool_like) for tool_like in tools]
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .toolset import Toolset
|
|
2
|
+
from .python_toolset import PythonToolset, python_tool
|
|
3
|
+
from .mcp_toolset import (
|
|
4
|
+
McpToolset,
|
|
5
|
+
LocalMcpToolset,
|
|
6
|
+
RemoteMcpToolset,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"Toolset",
|
|
11
|
+
|
|
12
|
+
"PythonToolset",
|
|
13
|
+
"python_tool",
|
|
14
|
+
|
|
15
|
+
"McpToolset",
|
|
16
|
+
"LocalMcpToolset",
|
|
17
|
+
"RemoteMcpToolset",
|
|
18
|
+
]
|