dais-sdk 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dais_sdk/__init__.py +320 -0
- dais_sdk/debug.py +4 -0
- dais_sdk/logger.py +22 -0
- dais_sdk/mcp_client/__init__.py +15 -0
- dais_sdk/mcp_client/base_mcp_client.py +38 -0
- dais_sdk/mcp_client/local_mcp_client.py +55 -0
- dais_sdk/mcp_client/oauth_server.py +100 -0
- dais_sdk/mcp_client/remote_mcp_client.py +157 -0
- dais_sdk/param_parser.py +55 -0
- dais_sdk/stream.py +82 -0
- dais_sdk/tool/__init__.py +0 -0
- dais_sdk/tool/execute.py +65 -0
- dais_sdk/tool/prepare.py +283 -0
- dais_sdk/tool/toolset/__init__.py +18 -0
- dais_sdk/tool/toolset/mcp_toolset.py +94 -0
- dais_sdk/tool/toolset/python_toolset.py +31 -0
- dais_sdk/tool/toolset/toolset.py +13 -0
- dais_sdk/tool/utils.py +11 -0
- dais_sdk/types/__init__.py +20 -0
- dais_sdk/types/exceptions.py +27 -0
- dais_sdk/types/message.py +211 -0
- dais_sdk/types/request_params.py +55 -0
- dais_sdk/types/tool.py +47 -0
- dais_sdk-0.6.0.dist-info/METADATA +100 -0
- dais_sdk-0.6.0.dist-info/RECORD +27 -0
- dais_sdk-0.6.0.dist-info/WHEEL +4 -0
- dais_sdk-0.6.0.dist-info/licenses/LICENSE +21 -0
dais_sdk/__init__.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import queue
|
|
3
|
+
from typing import cast
|
|
4
|
+
from collections.abc import AsyncGenerator, Generator
|
|
5
|
+
from litellm import CustomStreamWrapper, completion, acompletion
|
|
6
|
+
from litellm.utils import ProviderConfigManager
|
|
7
|
+
from litellm.types.utils import (
|
|
8
|
+
LlmProviders,
|
|
9
|
+
ModelResponse as LiteLlmModelResponse,
|
|
10
|
+
ModelResponseStream as LiteLlmModelResponseStream
|
|
11
|
+
)
|
|
12
|
+
from .debug import enable_debugging
|
|
13
|
+
from .param_parser import ParamParser
|
|
14
|
+
from .stream import AssistantMessageCollector
|
|
15
|
+
from .tool.execute import execute_tool_sync, execute_tool
|
|
16
|
+
from .tool.toolset import (
|
|
17
|
+
Toolset,
|
|
18
|
+
python_tool,
|
|
19
|
+
PythonToolset,
|
|
20
|
+
McpToolset,
|
|
21
|
+
LocalMcpToolset,
|
|
22
|
+
RemoteMcpToolset,
|
|
23
|
+
)
|
|
24
|
+
from .tool.utils import find_tool_by_name
|
|
25
|
+
from .mcp_client import (
|
|
26
|
+
McpClient,
|
|
27
|
+
McpTool,
|
|
28
|
+
McpToolResult,
|
|
29
|
+
LocalMcpClient,
|
|
30
|
+
RemoteMcpClient,
|
|
31
|
+
LocalServerParams,
|
|
32
|
+
RemoteServerParams,
|
|
33
|
+
OAuthParams,
|
|
34
|
+
)
|
|
35
|
+
from .types import (
|
|
36
|
+
GenerateTextResponse,
|
|
37
|
+
StreamTextResponseSync, StreamTextResponseAsync,
|
|
38
|
+
FullMessageQueueSync, FullMessageQueueAsync,
|
|
39
|
+
)
|
|
40
|
+
from .types.request_params import LlmRequestParams
|
|
41
|
+
from .types.tool import ToolFn, ToolDef, RawToolDef, ToolLike
|
|
42
|
+
from .types.exceptions import (
|
|
43
|
+
AuthenticationError,
|
|
44
|
+
PermissionDeniedError,
|
|
45
|
+
RateLimitError,
|
|
46
|
+
ContextWindowExceededError,
|
|
47
|
+
BadRequestError,
|
|
48
|
+
InvalidRequestError,
|
|
49
|
+
InternalServerError,
|
|
50
|
+
ServiceUnavailableError,
|
|
51
|
+
ContentPolicyViolationError,
|
|
52
|
+
APIError,
|
|
53
|
+
Timeout,
|
|
54
|
+
)
|
|
55
|
+
from .types.message import (
|
|
56
|
+
ChatMessage, UserMessage, SystemMessage, AssistantMessage, ToolMessage,
|
|
57
|
+
MessageChunk, TextChunk, UsageChunk, ReasoningChunk, AudioChunk, ImageChunk, ToolCallChunk,
|
|
58
|
+
openai_chunk_normalizer
|
|
59
|
+
)
|
|
60
|
+
from .logger import logger, enable_logging
|
|
61
|
+
|
|
62
|
+
class LLM:
    """
    Chat client for one provider endpoint, built on litellm's ``completion`` /
    ``acompletion``, that returns this SDK's message types.

    The `generate_text` and `stream_text` API will return ToolMessage in the returned sequence
    only if `params.execute_tools` is True.

    - - -

    Possible exceptions raised by `generate_text` and `stream_text`:
    - AuthenticationError
    - PermissionDeniedError
    - RateLimitError
    - ContextWindowExceededError
    - BadRequestError
    - InvalidRequestError
    - InternalServerError
    - ServiceUnavailableError
    - ContentPolicyViolationError
    - APIError
    - Timeout
    """

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        """
        Args:
            provider: litellm provider to talk to.
            base_url: API base URL of the provider endpoint.
            api_key: API key for the provider endpoint.
        """
        self.provider = provider
        self.base_url = base_url
        self.api_key = api_key
        # Turns LlmRequestParams into keyword arguments for litellm's completion calls.
        self._param_parser = ParamParser(self.provider, self.base_url, self.api_key)

    @staticmethod
    def _find_tool_or_raise(params: LlmRequestParams, incomplete_tool_message: ToolMessage):
        """Return the tool definition named by the message, or raise ToolDoesNotExistError."""
        name = incomplete_tool_message.name
        tool_def = params.find_tool(name)
        if tool_def is None:
            raise LlmRequestParams.ToolDoesNotExistError(name)
        return tool_def

    @staticmethod
    def _format_tool_error(error: Exception) -> str:
        """Render an exception raised by a tool as ``"<ExceptionType>: <message>"``."""
        return f"{type(error).__name__}: {error}"

    @staticmethod
    def _complete_tool_message(
        incomplete_tool_message: ToolMessage,
        result: str | None,
        error: str | None,
    ) -> ToolMessage:
        """Build the finished ToolMessage for a tool call from its result/error pair."""
        return ToolMessage(
            tool_call_id=incomplete_tool_message.tool_call_id,
            name=incomplete_tool_message.name,
            arguments=incomplete_tool_message.arguments,
            result=result,
            error=error)

    @staticmethod
    async def execute_tool_call(
        params: LlmRequestParams,
        incomplete_tool_message: ToolMessage
    ) -> tuple[str | None, str | None]:
        """
        Receive an incomplete tool message, execute the tool call it describes and
        return a ``(result, error)`` tuple.

        On success ``result`` is whatever ``execute_tool`` returned and ``error`` is
        None; if the tool raised, ``result`` is None and ``error`` is
        ``"<ExceptionType>: <message>"``.

        Raises:
            LlmRequestParams.ToolDoesNotExistError: If ``params`` has no tool with
                the message's name.
        """
        tool_def = LLM._find_tool_or_raise(params, incomplete_tool_message)
        try:
            return await execute_tool(tool_def, incomplete_tool_message.arguments), None
        except Exception as e:
            # Tool failures are reported back as data instead of aborting the request.
            return None, LLM._format_tool_error(e)

    @staticmethod
    def execute_tool_call_sync(
        params: LlmRequestParams,
        incomplete_tool_message: ToolMessage
    ) -> tuple[str | None, str | None]:
        """
        Synchronous version of `execute_tool_call`.

        Raises:
            LlmRequestParams.ToolDoesNotExistError: If ``params`` has no tool with
                the message's name.
        """
        tool_def = LLM._find_tool_or_raise(params, incomplete_tool_message)
        try:
            return execute_tool_sync(tool_def, incomplete_tool_message.arguments), None
        except Exception as e:
            # Tool failures are reported back as data instead of aborting the request.
            return None, LLM._format_tool_error(e)

    def _resolve_tool_calls_sync(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> Generator[ToolMessage]:
        """
        Execute the assistant message's pending tool calls and yield one finished
        ToolMessage per call. Yields nothing unless ``params.execute_tools`` is set.
        Calls naming an unknown tool are logged and skipped.
        """
        if not params.execute_tools:
            return
        incomplete_tool_messages = assistant_message.get_incomplete_tool_messages()
        if incomplete_tool_messages is None:
            return
        for incomplete_tool_message in incomplete_tool_messages:
            try:
                result, error = LLM.execute_tool_call_sync(params, incomplete_tool_message)
            except LlmRequestParams.ToolDoesNotExistError as e:
                logger.warning("%s Skipping this tool call.", e.message)
                continue
            yield LLM._complete_tool_message(incomplete_tool_message, result, error)

    async def _resolve_tool_calls(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> AsyncGenerator[ToolMessage]:
        """Async version of `_resolve_tool_calls_sync`."""
        if not params.execute_tools:
            return
        incomplete_tool_messages = assistant_message.get_incomplete_tool_messages()
        if incomplete_tool_messages is None:
            return
        for incomplete_tool_message in incomplete_tool_messages:
            try:
                result, error = await LLM.execute_tool_call(params, incomplete_tool_message)
            except LlmRequestParams.ToolDoesNotExistError as e:
                logger.warning("%s Skipping this tool call.", e.message)
                continue
            yield LLM._complete_tool_message(incomplete_tool_message, result, error)

    def list_models(self) -> list[str]:
        """
        Return the model names reported by the configured provider.

        Raises:
            ValueError: If litellm has no model-info config for ``self.provider``.
            Any exception raised by the provider's ``get_models`` call propagates
            unchanged.
        """
        provider_config = ProviderConfigManager.get_provider_model_info(
            model=None,
            provider=self.provider,
        )

        if provider_config is None:
            raise ValueError(f"The '{self.provider}' provider is not supported to list models.")

        return provider_config.get_models(
            api_key=self.api_key,
            api_base=self.base_url
        )

    def generate_text_sync(self, params: LlmRequestParams) -> GenerateTextResponse:
        """
        Send a non-streaming request and return the assistant message, followed by
        the ToolMessages produced by executing its tool calls (if enabled).
        """
        response = completion(**self._param_parser.parse_nonstream(params))
        response = cast(LiteLlmModelResponse, response)
        assistant_message = AssistantMessage.from_litellm_message(response)
        result: GenerateTextResponse = [assistant_message]
        result.extend(self._resolve_tool_calls_sync(params, assistant_message))
        return result

    async def generate_text(self, params: LlmRequestParams) -> GenerateTextResponse:
        """Async version of `generate_text_sync`."""
        response = await acompletion(**self._param_parser.parse_nonstream(params))
        response = cast(LiteLlmModelResponse, response)
        assistant_message = AssistantMessage.from_litellm_message(response)
        result: GenerateTextResponse = [assistant_message]
        async for tool_message in self._resolve_tool_calls(params, assistant_message):
            result.append(tool_message)
        return result

    def stream_text_sync(self, params: LlmRequestParams) -> StreamTextResponseSync:
        """
        Returns:
            - stream: Generator yielding `MessageChunk` objects
            - full_message_queue: Queue containing complete `AssistantMessage`, `ToolMessage` (or `None` when done)

        The queue is filled only while `stream` is being iterated. Once iteration
        has started, `None` is always put last, including when the stream raises
        or is closed early. A consumer blocked on the queue is therefore never left
        waiting forever; in those cases it may receive `None` without an
        `AssistantMessage`.
        """
        response = completion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        full_message_queue = FullMessageQueueSync()

        def stream(response: CustomStreamWrapper) -> Generator[MessageChunk]:
            try:
                for chunk in response:
                    chunk = cast(LiteLlmModelResponseStream, chunk)
                    yield from openai_chunk_normalizer(chunk)
                    message_collector.collect(chunk)

                message = message_collector.get_message()
                full_message_queue.put(message)

                for tool_message in self._resolve_tool_calls_sync(params, message):
                    full_message_queue.put(tool_message)
            finally:
                # End-of-stream sentinel; put unconditionally so queue readers never hang.
                full_message_queue.put(None)

        returned_stream = stream(cast(CustomStreamWrapper, response))
        return returned_stream, full_message_queue

    async def stream_text(self, params: LlmRequestParams) -> StreamTextResponseAsync:
        """
        Returns:
            - stream: AsyncGenerator yielding `MessageChunk` objects
            - full_message_queue: Queue containing complete `AssistantMessage`, `ToolMessage` (or `None` when done)

        The queue is filled only while `stream` is being iterated. Once iteration
        has started, `None` is always put last, including when the stream raises
        or is closed early, so a consumer awaiting the queue is never left waiting
        forever.
        """
        response = await acompletion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        full_message_queue = FullMessageQueueAsync()

        async def stream(response: CustomStreamWrapper) -> AsyncGenerator[MessageChunk]:
            try:
                async for chunk in response:
                    chunk = cast(LiteLlmModelResponseStream, chunk)
                    for normalized_chunk in openai_chunk_normalizer(chunk):
                        yield normalized_chunk
                    message_collector.collect(chunk)

                message = message_collector.get_message()
                await full_message_queue.put(message)
                async for tool_message in self._resolve_tool_calls(params, message):
                    await full_message_queue.put(tool_message)
            finally:
                # End-of-stream sentinel; put unconditionally so queue readers never hang.
                # NOTE(review): assumes the queue is unbounded so put() does not suspend,
                # which matters when this runs during aclose() of the generator.
                await full_message_queue.put(None)

        returned_stream = stream(cast(CustomStreamWrapper, response))
        return returned_stream, full_message_queue
|
|
256
|
+
|
|
257
|
+
# Public API of the package: names re-exported by `from dais_sdk import *`.
__all__ = [
    # Client and request configuration
    "LLM",
    "LlmProviders",
    "LlmRequestParams",

    # Toolsets
    "Toolset",
    "python_tool",
    "PythonToolset",
    "McpToolset",
    "LocalMcpToolset",
    "RemoteMcpToolset",

    # MCP clients and their parameters
    "McpClient",
    "McpTool",
    "McpToolResult",
    "LocalMcpClient",
    "RemoteMcpClient",
    "LocalServerParams",
    "RemoteServerParams",
    "OAuthParams",

    # Tool definitions and execution helpers
    "ToolFn",
    "ToolDef",
    "RawToolDef",
    "ToolLike",
    "execute_tool",
    "execute_tool_sync",

    # Complete messages
    "ChatMessage",
    "UserMessage",
    "SystemMessage",
    "AssistantMessage",
    "ToolMessage",

    # Streaming chunks
    "MessageChunk",
    "TextChunk",
    "UsageChunk",
    "ReasoningChunk",
    "AudioChunk",
    "ImageChunk",
    "ToolCallChunk",

    # Return types of the LLM text APIs
    "GenerateTextResponse",
    "StreamTextResponseSync",
    "StreamTextResponseAsync",
    "FullMessageQueueSync",
    "FullMessageQueueAsync",

    # Diagnostics
    "enable_debugging",
    "enable_logging",

    # Exceptions
    "AuthenticationError",
    "PermissionDeniedError",
    "RateLimitError",
    "ContextWindowExceededError",
    "BadRequestError",
    "InvalidRequestError",
    "InternalServerError",
    "ServiceUnavailableError",
    "ContentPolicyViolationError",
    "APIError",
    "Timeout",
]
|
dais_sdk/debug.py
ADDED
dais_sdk/logger.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger("LiteAI-SDK")
|
|
5
|
+
logger.addHandler(logging.NullHandler())
|
|
6
|
+
|
|
7
|
+
def enable_logging(level=logging.INFO):
|
|
8
|
+
"""
|
|
9
|
+
Enable logging for the LiteAI SDK.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
level: The logging level (default: logging.INFO).
|
|
13
|
+
|
|
14
|
+
Common values: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR
|
|
15
|
+
"""
|
|
16
|
+
logger.setLevel(level)
|
|
17
|
+
|
|
18
|
+
if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
|
|
19
|
+
handler = logging.StreamHandler(sys.stderr)
|
|
20
|
+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s")
|
|
21
|
+
handler.setFormatter(formatter)
|
|
22
|
+
logger.addHandler(handler)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .base_mcp_client import McpClient, McpTool, ToolResult as McpToolResult
|
|
2
|
+
from .local_mcp_client import LocalMcpClient, LocalServerParams
|
|
3
|
+
from .remote_mcp_client import RemoteMcpClient, RemoteServerParams, OAuthParams
|
|
4
|
+
|
|
5
|
+
# Public API of the mcp_client subpackage.
__all__ = [
    # Abstract client interface and its result/tool types
    "McpClient",
    "McpTool",
    "McpToolResult",

    # stdio (local subprocess) client
    "LocalMcpClient",
    "LocalServerParams",
    # Remote client and its connection/OAuth parameters
    "RemoteMcpClient",
    "RemoteServerParams",
    "OAuthParams",
]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, NamedTuple
|
|
3
|
+
from mcp import Tool as McpTool
|
|
4
|
+
from mcp.types import ContentBlock as ToolResultBlock
|
|
5
|
+
|
|
6
|
+
# The MCP SDK's tool description type (`mcp.Tool`), under the name used by the
# client interface below (`list_tools` returns `list[Tool]`).
Tool = McpTool
|
|
7
|
+
|
|
8
|
+
class ToolResult(NamedTuple):
    """Outcome of an MCP tool invocation, as returned by ``McpClient.call_tool``."""

    # True when the server flagged the call result as an error.
    is_error: bool
    # Content blocks returned by the server for the call.
    content: list[ToolResultBlock]
|
|
11
|
+
|
|
12
|
+
class McpClient(ABC):
    """
    Abstract interface for a client talking to a single MCP server.

    Subclasses implement the connection lifecycle (``connect`` / ``disconnect``)
    and the two MCP operations this SDK uses: listing tools and calling a tool.

    Instances can also be used as an async context manager, which calls
    ``connect()`` on entry and ``disconnect()`` on exit::

        async with client:
            tools = await client.list_tools()
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name identifying this client, as supplied by the subclass."""

    @abstractmethod
    async def connect(self):
        """Open the connection to the server and establish the MCP session."""

    @abstractmethod
    async def disconnect(self):
        """Tear down whatever ``connect()`` opened."""

    @abstractmethod
    async def list_tools(self) -> list[Tool]:
        """
        Return the tools advertised by the connected server.

        Raises:
            McpSessionNotEstablishedError: If the session is not established.
        """

    @abstractmethod
    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        """
        Invoke ``tool_name`` on the server with the given ``arguments``.

        Raises:
            McpSessionNotEstablishedError: If the session is not established.
        """

    async def __aenter__(self) -> "McpClient":
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Always disconnect; returning None lets any exception propagate.
        await self.disconnect()
|
|
35
|
+
|
|
36
|
+
class McpSessionNotEstablishedError(RuntimeError):
    """Raised when an MCP operation is attempted before ``connect()`` set up a session."""

    def __init__(self):
        message = "MCP Session not established, please call connect() first"
        super().__init__(message)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from contextlib import AsyncExitStack
|
|
2
|
+
from typing import Any
|
|
3
|
+
from mcp import ClientSession, StdioServerParameters as StdioServerParams
|
|
4
|
+
from mcp.client.stdio import stdio_client
|
|
5
|
+
from .base_mcp_client import McpClient, Tool, ToolResult, McpSessionNotEstablishedError
|
|
6
|
+
|
|
7
|
+
# Launch parameters for a local MCP server; this is the MCP SDK's
# StdioServerParameters under an SDK-specific name (consumed by stdio_client).
LocalServerParams = StdioServerParams
|
|
8
|
+
|
|
9
|
+
class LocalMcpClient(McpClient):
    """MCP client that talks to a local server over stdio via ``stdio_client``."""

    def __init__(self, name: str, params: LocalServerParams):
        self._name: str = name
        self._params: LocalServerParams = params
        # Both are set by connect() and cleared by disconnect().
        self._session: ClientSession | None = None
        self._exit_stack: AsyncExitStack | None = None

    @property
    def name(self) -> str:
        return self._name

    async def connect(self):
        """
        Open the stdio transport, create a ClientSession on it and run the MCP
        ``initialize`` handshake.

        If the client is already connected, the existing connection is closed
        first so its transport is not leaked. If any step fails, whatever was
        opened is released and the exception is re-raised.
        """
        if self._exit_stack is not None:
            await self.disconnect()

        self._exit_stack = AsyncExitStack()

        try:
            read_stream, write_stream = await self._exit_stack.enter_async_context(
                stdio_client(self._params)
            )
            self._session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            await self._session.initialize()
        except Exception:
            await self.disconnect()
            raise

    async def list_tools(self) -> list[Tool]:
        """
        Raises:
            McpSessionNotEstablishedError: If connect() has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        result = await self._session.list_tools()
        return result.tools

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        """
        Raises:
            McpSessionNotEstablishedError: If connect() has not succeeded.
        """
        if not self._session:
            raise McpSessionNotEstablishedError()

        response = await self._session.call_tool(tool_name, arguments)
        return ToolResult(response.isError, response.content)

    async def disconnect(self) -> None:
        """
        Close the session and transport. Safe to call when not connected.
        The client's state is cleared even if closing raises.
        """
        exit_stack, self._exit_stack = self._exit_stack, None
        self._session = None
        if exit_stack is not None:
            await exit_stack.aclose()
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import socket
|
|
3
|
+
import uvicorn
|
|
4
|
+
from starlette.applications import Starlette
|
|
5
|
+
from starlette.responses import HTMLResponse
|
|
6
|
+
from starlette.routing import Route
|
|
7
|
+
from starlette.requests import Request
|
|
8
|
+
from mcp.client.auth import TokenStorage
|
|
9
|
+
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
10
|
+
|
|
11
|
+
CALLBACK_PAGE = """
|
|
12
|
+
<html>
|
|
13
|
+
<body style="font-family: sans-serif; text-align: center; padding-top: 50px;">
|
|
14
|
+
<h2 style="color: green;">Login Successful!</h2>
|
|
15
|
+
<p>You can close this window now.</p>
|
|
16
|
+
<script>window.close();</script>
|
|
17
|
+
</body>
|
|
18
|
+
</html>
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
# (authorization code, state) received on the OAuth redirect; state may be None
# when the provider did not send it.
OAuthCode = tuple[str, str | None]
# Future that the callback handler resolves with an OAuthCode (or fails).
OAuthCallbackFuture = asyncio.Future[OAuthCode]
|
|
23
|
+
|
|
24
|
+
class InMemoryTokenStorage(TokenStorage):
    """
    TokenStorage that keeps the OAuth tokens and client registration in plain
    instance attributes. Nothing is persisted; the data lives only as long as
    this object.
    """

    def __init__(self):
        # Populated by the MCP OAuth flow through the setters below.
        self.tokens: OAuthToken | None = None
        self.client_info: OAuthClientInformationFull | None = None

    # --- readers -----------------------------------------------------------

    async def get_tokens(self) -> OAuthToken | None:
        """Return the stored tokens, or None if none were saved yet."""
        return self.tokens

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client registration, or None if none was saved yet."""
        return self.client_info

    # --- writers -----------------------------------------------------------

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Replace the stored tokens."""
        self.tokens = tokens

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Replace the stored client registration."""
        self.client_info = client_info
|
|
40
|
+
|
|
41
|
+
def _find_free_port() -> int:
|
|
42
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
43
|
+
s.bind(("0.0.0.0", 0))
|
|
44
|
+
port = s.getsockname()[1]
|
|
45
|
+
return port
|
|
46
|
+
|
|
47
|
+
class LocalOAuthServer:
    """
    Short-lived HTTP server on 127.0.0.1 that receives the OAuth authorization
    redirect at ``/callback`` and hands the ``(code, state)`` pair to a
    coroutine waiting in ``wait_for_code()``.
    """

    def __init__(self, timeout: int):
        """
        Args:
            timeout: Seconds ``wait_for_code()`` waits for the redirect before
                raising ``asyncio.TimeoutError``.
        """
        self._port = _find_free_port()
        self._timeout = timeout
        # NOTE(review): creating a Future with no running event loop is deprecated
        # in recent Python versions; presumably this class is constructed inside
        # async code — confirm at call sites.
        self._future = OAuthCallbackFuture()
        self._server: uvicorn.Server | None = None
        self._server_task: asyncio.Task | None = None

    @property
    def callback_url(self) -> str:
        """Redirect URI to register with the authorization server."""
        # NOTE(review): the server binds 127.0.0.1 while this URL says "localhost";
        # fine as long as localhost resolves to (or falls back to) IPv4 loopback.
        return f"http://localhost:{self._port}/callback"

    async def _handle_callback(self, request: Request):
        """
        Handler for ``GET /callback``.

        Resolves the pending future with ``(code, state)``, or fails it with a
        RuntimeError when the redirect carries an ``error`` parameter. Only the
        first callback settles the future; later ones do not change it.
        """
        import html  # only needed to escape the reflected error value

        params = request.query_params

        # check if error
        if "error" in params:
            error = params.get("error")
            desc = params.get("error_description", "Unknown error")
            if not self._future.done():
                self._future.set_exception(RuntimeError(f"OAuth Error: {error} - {desc}"))
            # Query parameters are attacker-controllable: escape before reflecting into HTML.
            return HTMLResponse(f"<h3>Auth failed: {html.escape(error or '')}</h3>", status_code=400)

        code = params.get("code")
        state = params.get("state")

        if not code:
            return HTMLResponse("<h3>Missing 'code' parameter</h3>", status_code=400)

        if not self._future.done():
            self._future.set_result((code, state))

        return HTMLResponse(CALLBACK_PAGE)

    async def wait_for_code(self) -> OAuthCode:
        """
        Wait up to ``timeout`` seconds for the redirect and return ``(code, state)``.

        A fresh future is installed afterwards whether the wait succeeded, timed
        out or failed. A timeout or OAuth error therefore does not poison later
        calls; previously the cancelled or failed future was kept and every
        subsequent wait failed immediately.

        Raises:
            asyncio.TimeoutError: If no redirect arrived within ``timeout`` seconds.
            RuntimeError: If the redirect carried an ``error`` parameter.
        """
        try:
            return await asyncio.wait_for(self._future, timeout=self._timeout)
        finally:
            self._future = OAuthCallbackFuture()

    async def start(self):
        """Start serving ``GET /callback`` on 127.0.0.1 in a background task."""
        routes = [
            Route("/callback", self._handle_callback, methods=["GET"])
        ]
        app = Starlette(routes=routes)

        config = uvicorn.Config(app=app, host="127.0.0.1", port=self._port, log_level="error")
        self._server = uvicorn.Server(config)
        self._server_task = asyncio.create_task(self._server.serve())

    async def stop(self):
        """
        Ask uvicorn to exit and wait for the serving task to finish.
        A no-op if the server was never started or has already been stopped.
        """
        server, self._server = self._server, None
        server_task, self._server_task = self._server_task, None
        if server:
            server.should_exit = True
        if server_task:
            await server_task
|