dais-sdk 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dais_sdk/__init__.py ADDED
@@ -0,0 +1,320 @@
+ import asyncio
+ import queue
+ from typing import cast
+ from collections.abc import AsyncGenerator, Generator
+ from litellm import CustomStreamWrapper, completion, acompletion
+ from litellm.utils import ProviderConfigManager
+ from litellm.types.utils import (
+     LlmProviders,
+     ModelResponse as LiteLlmModelResponse,
+     ModelResponseStream as LiteLlmModelResponseStream
+ )
+ from .debug import enable_debugging
+ from .param_parser import ParamParser
+ from .stream import AssistantMessageCollector
+ from .tool.execute import execute_tool_sync, execute_tool
+ from .tool.toolset import (
+     Toolset,
+     python_tool,
+     PythonToolset,
+     McpToolset,
+     LocalMcpToolset,
+     RemoteMcpToolset,
+ )
+ from .tool.utils import find_tool_by_name
+ from .mcp_client import (
+     McpClient,
+     McpTool,
+     McpToolResult,
+     LocalMcpClient,
+     RemoteMcpClient,
+     LocalServerParams,
+     RemoteServerParams,
+     OAuthParams,
+ )
+ from .types import (
+     GenerateTextResponse,
+     StreamTextResponseSync, StreamTextResponseAsync,
+     FullMessageQueueSync, FullMessageQueueAsync,
+ )
+ from .types.request_params import LlmRequestParams
+ from .types.tool import ToolFn, ToolDef, RawToolDef, ToolLike
+ from .types.exceptions import (
+     AuthenticationError,
+     PermissionDeniedError,
+     RateLimitError,
+     ContextWindowExceededError,
+     BadRequestError,
+     InvalidRequestError,
+     InternalServerError,
+     ServiceUnavailableError,
+     ContentPolicyViolationError,
+     APIError,
+     Timeout,
+ )
+ from .types.message import (
+     ChatMessage, UserMessage, SystemMessage, AssistantMessage, ToolMessage,
+     MessageChunk, TextChunk, UsageChunk, ReasoningChunk, AudioChunk, ImageChunk, ToolCallChunk,
+     openai_chunk_normalizer
+ )
+ from .logger import logger, enable_logging
+
+ class LLM:
+     """
+     The `generate_text` and `stream_text` APIs include `ToolMessage` entries in the
+     returned sequence only if `params.execute_tools` is True.
+
+     - - -
+
+     Possible exceptions raised by `generate_text` and `stream_text`:
+     - AuthenticationError
+     - PermissionDeniedError
+     - RateLimitError
+     - ContextWindowExceededError
+     - BadRequestError
+     - InvalidRequestError
+     - InternalServerError
+     - ServiceUnavailableError
+     - ContentPolicyViolationError
+     - APIError
+     - Timeout
+     """
+
+     def __init__(self,
+                  provider: LlmProviders,
+                  base_url: str,
+                  api_key: str):
+         self.provider = provider
+         self.base_url = base_url
+         self.api_key = api_key
+         self._param_parser = ParamParser(self.provider, self.base_url, self.api_key)
+
+     @staticmethod
+     async def execute_tool_call(
+         params: LlmRequestParams,
+         incomplete_tool_message: ToolMessage
+     ) -> tuple[str | None, str | None]:
+         """
+         Execute the tool call described by an incomplete tool message and
+         return a (result, error) tuple.
+         """
+         name, arguments = incomplete_tool_message.name, incomplete_tool_message.arguments
+         tool_def = params.find_tool(name)
+         if tool_def is None:
+             raise LlmRequestParams.ToolDoesNotExistError(name)
+
+         result, error = None, None
+         try:
+             result = await execute_tool(tool_def, arguments)
+         except Exception as e:
+             error = f"{type(e).__name__}: {str(e)}"
+         return result, error
+
+     @staticmethod
+     def execute_tool_call_sync(
+         params: LlmRequestParams,
+         incomplete_tool_message: ToolMessage
+     ) -> tuple[str | None, str | None]:
+         """
+         Synchronous version of `execute_tool_call`.
+         """
+         name, arguments = incomplete_tool_message.name, incomplete_tool_message.arguments
+         tool_def = params.find_tool(name)
+         if tool_def is None:
+             raise LlmRequestParams.ToolDoesNotExistError(name)
+
+         result, error = None, None
+         try:
+             result = execute_tool_sync(tool_def, arguments)
+         except Exception as e:
+             error = f"{type(e).__name__}: {str(e)}"
+         return result, error
+
+     def _resolve_tool_calls_sync(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> Generator[ToolMessage]:
+         if not params.execute_tools: return
+         if (incomplete_tool_messages
+                 := assistant_message.get_incomplete_tool_messages()) is None:
+             return
+         for incomplete_tool_message in incomplete_tool_messages:
+             try:
+                 result, error = LLM.execute_tool_call_sync(params, incomplete_tool_message)
+             except LlmRequestParams.ToolDoesNotExistError as e:
+                 logger.warning(f"{e.message} Skipping this tool call.")
+                 continue
+             yield ToolMessage(
+                 tool_call_id=incomplete_tool_message.tool_call_id,
+                 name=incomplete_tool_message.name,
+                 arguments=incomplete_tool_message.arguments,
+                 result=result,
+                 error=error)
+
+     async def _resolve_tool_calls(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> AsyncGenerator[ToolMessage]:
+         if not params.execute_tools: return
+         if (incomplete_tool_messages :=
+                 assistant_message.get_incomplete_tool_messages()) is None:
+             return
+         for incomplete_tool_message in incomplete_tool_messages:
+             try:
+                 result, error = await LLM.execute_tool_call(params, incomplete_tool_message)
+             except LlmRequestParams.ToolDoesNotExistError as e:
+                 logger.warning(f"{e.message} Skipping this tool call.")
+                 continue
+             yield ToolMessage(
+                 tool_call_id=incomplete_tool_message.tool_call_id,
+                 name=incomplete_tool_message.name,
+                 arguments=incomplete_tool_message.arguments,
+                 result=result,
+                 error=error)
+
+     def list_models(self) -> list[str]:
+         provider_config = ProviderConfigManager.get_provider_model_info(
+             model=None,
+             provider=self.provider,
+         )
+
+         if provider_config is None:
+             raise ValueError(f"The '{self.provider}' provider does not support listing models.")
+
+         return provider_config.get_models(
+             api_key=self.api_key,
+             api_base=self.base_url
+         )
+
+     def generate_text_sync(self, params: LlmRequestParams) -> GenerateTextResponse:
+         response = completion(**self._param_parser.parse_nonstream(params))
+         response = cast(LiteLlmModelResponse, response)
+         assistant_message = AssistantMessage.from_litellm_message(response)
+         result: GenerateTextResponse = [assistant_message]
+         for tool_message in self._resolve_tool_calls_sync(params, assistant_message):
+             result.append(tool_message)
+         return result
+
+     async def generate_text(self, params: LlmRequestParams) -> GenerateTextResponse:
+         response = await acompletion(**self._param_parser.parse_nonstream(params))
+         response = cast(LiteLlmModelResponse, response)
+         assistant_message = AssistantMessage.from_litellm_message(response)
+         result: GenerateTextResponse = [assistant_message]
+         async for tool_message in self._resolve_tool_calls(params, assistant_message):
+             result.append(tool_message)
+         return result
+
+     def stream_text_sync(self, params: LlmRequestParams) -> StreamTextResponseSync:
+         """
+         Returns:
+         - stream: generator yielding `MessageChunk` objects
+         - full_message_queue: queue that receives the complete `AssistantMessage`,
+           then any `ToolMessage`s, then `None` when done
+         """
+         def stream(response: CustomStreamWrapper) -> Generator[MessageChunk]:
+             nonlocal message_collector
+             for chunk in response:
+                 chunk = cast(LiteLlmModelResponseStream, chunk)
+                 yield from openai_chunk_normalizer(chunk)
+                 message_collector.collect(chunk)
+
+             message = message_collector.get_message()
+             full_message_queue.put(message)
+
+             for tool_message in self._resolve_tool_calls_sync(params, message):
+                 full_message_queue.put(tool_message)
+             full_message_queue.put(None)
+
+         response = completion(**self._param_parser.parse_stream(params))
+         message_collector = AssistantMessageCollector()
+         returned_stream = stream(cast(CustomStreamWrapper, response))
+         full_message_queue = FullMessageQueueSync()
+         return returned_stream, full_message_queue
+
+     async def stream_text(self, params: LlmRequestParams) -> StreamTextResponseAsync:
+         """
+         Returns:
+         - stream: async generator yielding `MessageChunk` objects
+         - full_message_queue: queue that receives the complete `AssistantMessage`,
+           then any `ToolMessage`s, then `None` when done
+         """
+         async def stream(response: CustomStreamWrapper) -> AsyncGenerator[MessageChunk]:
+             nonlocal message_collector
+             async for chunk in response:
+                 chunk = cast(LiteLlmModelResponseStream, chunk)
+                 for normalized_chunk in openai_chunk_normalizer(chunk):
+                     yield normalized_chunk
+                 message_collector.collect(chunk)
+
+             message = message_collector.get_message()
+             await full_message_queue.put(message)
+             async for tool_message in self._resolve_tool_calls(params, message):
+                 await full_message_queue.put(tool_message)
+             await full_message_queue.put(None)
+
+         response = await acompletion(**self._param_parser.parse_stream(params))
+         message_collector = AssistantMessageCollector()
+         returned_stream = stream(cast(CustomStreamWrapper, response))
+         full_message_queue = FullMessageQueueAsync()
+         return returned_stream, full_message_queue
+
+ __all__ = [
+     "LLM",
+     "LlmProviders",
+     "LlmRequestParams",
+
+     "Toolset",
+     "python_tool",
+     "PythonToolset",
+     "McpToolset",
+     "LocalMcpToolset",
+     "RemoteMcpToolset",
+
+     "McpClient",
+     "McpTool",
+     "McpToolResult",
+     "LocalMcpClient",
+     "RemoteMcpClient",
+     "LocalServerParams",
+     "RemoteServerParams",
+     "OAuthParams",
+
+     "ToolFn",
+     "ToolDef",
+     "RawToolDef",
+     "ToolLike",
+     "execute_tool",
+     "execute_tool_sync",
+
+     "ChatMessage",
+     "UserMessage",
+     "SystemMessage",
+     "AssistantMessage",
+     "ToolMessage",
+
+     "MessageChunk",
+     "TextChunk",
+     "UsageChunk",
+     "ReasoningChunk",
+     "AudioChunk",
+     "ImageChunk",
+     "ToolCallChunk",
+
+     "GenerateTextResponse",
+     "StreamTextResponseSync",
+     "StreamTextResponseAsync",
+     "FullMessageQueueSync",
+     "FullMessageQueueAsync",
+
+     "enable_debugging",
+     "enable_logging",
+
+     # Exceptions
+     "AuthenticationError",
+     "PermissionDeniedError",
+     "RateLimitError",
+     "ContextWindowExceededError",
+     "BadRequestError",
+     "InvalidRequestError",
+     "InternalServerError",
+     "ServiceUnavailableError",
+     "ContentPolicyViolationError",
+     "APIError",
+     "Timeout",
+ ]
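For orientation, a minimal usage sketch of the class above. The endpoint, model name, and the `LlmRequestParams` fields (`model`, `messages`) are assumptions for illustration; the actual parameter schema lives in `dais_sdk/types/request_params.py`, which this diff does not show, and `FullMessageQueueSync` is assumed to expose a blocking `get()`.

```python
from dais_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage

llm = LLM(
    provider=LlmProviders.OPENAI,           # any litellm-supported provider
    base_url="https://api.example.com/v1",  # placeholder endpoint
    api_key="sk-...",                       # placeholder key
)

# Hypothetical params; field names assumed from the call sites above.
params = LlmRequestParams(
    model="gpt-4o-mini",
    messages=[UserMessage(content="Hello!")],
)

# Non-streaming: a list of messages, AssistantMessage first.
for message in llm.generate_text_sync(params):
    print(message)

# Streaming: consume chunks, then drain the queue until the None sentinel.
stream, full_message_queue = llm.stream_text_sync(params)
for chunk in stream:
    print(chunk)
while (message := full_message_queue.get()) is not None:
    print(message)
```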
dais_sdk/debug.py ADDED
@@ -0,0 +1,4 @@
+ import litellm
+
+ def enable_debugging():
+     litellm._turn_on_debug()
dais_sdk/logger.py ADDED
@@ -0,0 +1,22 @@
+ import sys
+ import logging
+
+ logger = logging.getLogger("LiteAI-SDK")
+ logger.addHandler(logging.NullHandler())
+
+ def enable_logging(level=logging.INFO):
+     """
+     Enable logging for the LiteAI SDK.
+
+     Args:
+         level: The logging level (default: logging.INFO).
+
+     Common values: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR
+     """
+     logger.setLevel(level)
+
+     if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
+         handler = logging.StreamHandler(sys.stderr)
+         formatter = logging.Formatter("%(asctime)s - %(name)s - %(message)s")
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
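A quick sketch of wiring this together with `debug.py` above:

```python
import logging
from dais_sdk import enable_logging, enable_debugging

enable_logging(logging.DEBUG)  # attach a stderr handler to the SDK logger
enable_debugging()             # additionally enable litellm's own debug output
```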
dais_sdk/mcp_client/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .base_mcp_client import McpClient, McpTool, ToolResult as McpToolResult
+ from .local_mcp_client import LocalMcpClient, LocalServerParams
+ from .remote_mcp_client import RemoteMcpClient, RemoteServerParams, OAuthParams
+
+ __all__ = [
+     "McpClient",
+     "McpTool",
+     "McpToolResult",
+
+     "LocalMcpClient",
+     "LocalServerParams",
+     "RemoteMcpClient",
+     "RemoteServerParams",
+     "OAuthParams",
+ ]
dais_sdk/mcp_client/base_mcp_client.py ADDED
@@ -0,0 +1,38 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, NamedTuple
+ from mcp import Tool as McpTool
+ from mcp.types import ContentBlock as ToolResultBlock
+
+ Tool = McpTool
+
+ class ToolResult(NamedTuple):
+     is_error: bool
+     content: list[ToolResultBlock]
+
+ class McpClient(ABC):
+     @property
+     @abstractmethod
+     def name(self) -> str: ...
+
+     @abstractmethod
+     async def connect(self): ...
+     @abstractmethod
+     async def disconnect(self): ...
+     @abstractmethod
+     async def list_tools(self) -> list[Tool]:
+         """
+         Raises:
+             McpSessionNotEstablishedError: If the session is not established.
+         """
+     @abstractmethod
+     async def call_tool(
+         self, tool_name: str, arguments: dict[str, Any] | None = None
+     ) -> ToolResult:
+         """
+         Raises:
+             McpSessionNotEstablishedError: If the session is not established.
+         """
+
+ class McpSessionNotEstablishedError(RuntimeError):
+     def __init__(self):
+         super().__init__("MCP Session not established, please call connect() first")
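To illustrate the contract, a minimal in-memory stub of this ABC (illustrative only, not part of the package) could look like:

```python
from typing import Any

class NullMcpClient(McpClient):
    """Illustrative stub: connects to nothing and exposes no tools."""
    @property
    def name(self) -> str:
        return "null"

    async def connect(self): ...
    async def disconnect(self): ...

    async def list_tools(self) -> list[Tool]:
        return []

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        # A real client would raise McpSessionNotEstablishedError before connect().
        return ToolResult(is_error=True, content=[])
```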
dais_sdk/mcp_client/local_mcp_client.py ADDED
@@ -0,0 +1,55 @@
+ from contextlib import AsyncExitStack
+ from typing import Any
+ from mcp import ClientSession, StdioServerParameters as StdioServerParams
+ from mcp.client.stdio import stdio_client
+ from .base_mcp_client import McpClient, Tool, ToolResult, McpSessionNotEstablishedError
+
+ LocalServerParams = StdioServerParams
+
+ class LocalMcpClient(McpClient):
+     def __init__(self, name: str, params: LocalServerParams):
+         self._name: str = name
+         self._params: LocalServerParams = params
+         self._session: ClientSession | None = None
+         self._exit_stack: AsyncExitStack | None = None
+
+     @property
+     def name(self) -> str:
+         return self._name
+
+     async def connect(self):
+         self._exit_stack = AsyncExitStack()
+
+         try:
+             read_stream, write_stream = await self._exit_stack.enter_async_context(
+                 stdio_client(self._params)
+             )
+             self._session = await self._exit_stack.enter_async_context(
+                 ClientSession(read_stream, write_stream)
+             )
+             await self._session.initialize()
+         except Exception:
+             await self.disconnect()
+             raise
+
+     async def list_tools(self) -> list[Tool]:
+         if not self._session:
+             raise McpSessionNotEstablishedError()
+
+         result = await self._session.list_tools()
+         return result.tools
+
+     async def call_tool(
+         self, tool_name: str, arguments: dict[str, Any] | None = None
+     ) -> ToolResult:
+         if not self._session:
+             raise McpSessionNotEstablishedError()
+
+         response = await self._session.call_tool(tool_name, arguments)
+         return ToolResult(response.isError, response.content)
+
+     async def disconnect(self) -> None:
+         if self._exit_stack:
+             await self._exit_stack.aclose()
+         self._session = None
+         self._exit_stack = None
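A connect/list/call/disconnect round trip might look like the sketch below; the server command and tool name are placeholders for whatever stdio MCP server you run.

```python
import asyncio
from dais_sdk import LocalMcpClient, LocalServerParams

async def main():
    client = LocalMcpClient(
        name="demo",
        # Placeholder command: any MCP server speaking stdio works here.
        params=LocalServerParams(command="my-mcp-server", args=[]),
    )
    await client.connect()
    try:
        tools = await client.list_tools()
        print("tools:", [tool.name for tool in tools])

        # "echo" is a hypothetical tool name; pick one from list_tools().
        result = await client.call_tool("echo", {"message": "hi"})
        print("is_error:", result.is_error)
        print("content:", result.content)
    finally:
        await client.disconnect()

asyncio.run(main())
```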
@@ -0,0 +1,100 @@
+ import asyncio
+ import socket
+ import uvicorn
+ from starlette.applications import Starlette
+ from starlette.responses import HTMLResponse
+ from starlette.routing import Route
+ from starlette.requests import Request
+ from mcp.client.auth import TokenStorage
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
+
+ CALLBACK_PAGE = """
+ <html>
+ <body style="font-family: sans-serif; text-align: center; padding-top: 50px;">
+ <h2 style="color: green;">Login Successful!</h2>
+ <p>You can close this window now.</p>
+ <script>window.close();</script>
+ </body>
+ </html>
+ """
+
+ OAuthCode = tuple[str, str | None]
+ OAuthCallbackFuture = asyncio.Future[OAuthCode]
+
+ class InMemoryTokenStorage(TokenStorage):
+     def __init__(self):
+         self.tokens: OAuthToken | None = None
+         self.client_info: OAuthClientInformationFull | None = None
+
+     async def get_tokens(self) -> OAuthToken | None:
+         return self.tokens
+
+     async def set_tokens(self, tokens: OAuthToken) -> None:
+         self.tokens = tokens
+
+     async def get_client_info(self) -> OAuthClientInformationFull | None:
+         return self.client_info
+
+     async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
+         self.client_info = client_info
+
+ def _find_free_port() -> int:
+     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+         s.bind(("0.0.0.0", 0))
+         port = s.getsockname()[1]
+         return port
+
+ class LocalOAuthServer:
+     def __init__(self, timeout: int):
+         self._port = _find_free_port()
+         self._timeout = timeout
+         self._future = OAuthCallbackFuture()
+         self._server: uvicorn.Server | None = None
+         self._server_task: asyncio.Task | None = None
+
+     @property
+     def callback_url(self) -> str:
+         return f"http://localhost:{self._port}/callback"
+
+     async def _handle_callback(self, request: Request):
+         params = request.query_params
+
+         # check if error
+         if "error" in params:
+             error = params.get("error")
+             desc = params.get("error_description", "Unknown error")
+             if not self._future.done():
+                 self._future.set_exception(RuntimeError(f"OAuth Error: {error} - {desc}"))
+             return HTMLResponse(f"<h3>Auth failed: {error}</h3>", status_code=400)
+
+         code = params.get("code")
+         state = params.get("state")
+
+         if not code:
+             return HTMLResponse("<h3>Missing 'code' parameter</h3>", status_code=400)
+
+         if not self._future.done():
+             self._future.set_result((code, state))
+
+         return HTMLResponse(CALLBACK_PAGE)
+
+     async def wait_for_code(self) -> OAuthCode:
+         auth_code = await asyncio.wait_for(self._future, timeout=self._timeout)
+         self._future = OAuthCallbackFuture()
+         return auth_code
+
+     async def start(self):
+         routes = [
+             Route("/callback", self._handle_callback, methods=["GET"])
+         ]
+         app = Starlette(routes=routes)
+
+         config = uvicorn.Config(app=app, host="127.0.0.1", port=self._port, log_level="error")
+         self._server = uvicorn.Server(config)
+         self._server_task = asyncio.create_task(self._server.serve())
+
+     async def stop(self):
+         if self._server:
+             self._server.should_exit = True
+         if self._server_task:
+             await self._server_task
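A sketch of the intended flow, assuming the caller builds and opens the provider's authorization URL itself (this file only hosts the redirect endpoint) and that `LocalOAuthServer` is imported from wherever this module lands in the package (its file header is missing from the diff):

```python
import asyncio

async def demo():
    server = LocalOAuthServer(timeout=120)
    await server.start()
    try:
        # Register server.callback_url as the OAuth redirect URI, then open
        # the provider's authorization URL in a browser (not shown here).
        print("redirect URI:", server.callback_url)
        code, state = await server.wait_for_code()
        print("authorization code:", code, "state:", state)
    finally:
        await server.stop()

asyncio.run(demo())
```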