dais-sdk 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dais_sdk-0.6.0/LICENSE +21 -0
- dais_sdk-0.6.0/PKG-INFO +100 -0
- dais_sdk-0.6.0/README.md +67 -0
- dais_sdk-0.6.0/pyproject.toml +102 -0
- dais_sdk-0.6.0/src/dais_sdk/__init__.py +320 -0
- dais_sdk-0.6.0/src/dais_sdk/debug.py +4 -0
- dais_sdk-0.6.0/src/dais_sdk/logger.py +22 -0
- dais_sdk-0.6.0/src/dais_sdk/mcp_client/__init__.py +15 -0
- dais_sdk-0.6.0/src/dais_sdk/mcp_client/base_mcp_client.py +38 -0
- dais_sdk-0.6.0/src/dais_sdk/mcp_client/local_mcp_client.py +55 -0
- dais_sdk-0.6.0/src/dais_sdk/mcp_client/oauth_server.py +100 -0
- dais_sdk-0.6.0/src/dais_sdk/mcp_client/remote_mcp_client.py +157 -0
- dais_sdk-0.6.0/src/dais_sdk/param_parser.py +55 -0
- dais_sdk-0.6.0/src/dais_sdk/stream.py +82 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/__init__.py +0 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/execute.py +65 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/prepare.py +283 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/toolset/__init__.py +18 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/toolset/mcp_toolset.py +94 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/toolset/python_toolset.py +31 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/toolset/toolset.py +13 -0
- dais_sdk-0.6.0/src/dais_sdk/tool/utils.py +11 -0
- dais_sdk-0.6.0/src/dais_sdk/types/__init__.py +20 -0
- dais_sdk-0.6.0/src/dais_sdk/types/exceptions.py +27 -0
- dais_sdk-0.6.0/src/dais_sdk/types/message.py +211 -0
- dais_sdk-0.6.0/src/dais_sdk/types/request_params.py +55 -0
- dais_sdk-0.6.0/src/dais_sdk/types/tool.py +47 -0
dais_sdk-0.6.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 BHznJNs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
dais_sdk-0.6.0/PKG-INFO
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dais-sdk
|
|
3
|
+
Version: 0.6.0
|
|
4
|
+
Summary: A wrapper of LiteLLM
|
|
5
|
+
Author-email: BHznJNs <bhznjns@outlook.com>
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: litellm>=1.80.0
|
|
17
|
+
Requires-Dist: pydantic>=2.0.0
|
|
18
|
+
Requires-Dist: httpx==0.28.1
|
|
19
|
+
Requires-Dist: mcp==1.25.0
|
|
20
|
+
Requires-Dist: starlette==0.50.0
|
|
21
|
+
Requires-Dist: uvicorn==0.40.0
|
|
22
|
+
Requires-Dist: python-dotenv>=1.2.1 ; extra == "dev"
|
|
23
|
+
Requires-Dist: pytest-cov ; extra == "test"
|
|
24
|
+
Requires-Dist: pytest-mock ; extra == "test"
|
|
25
|
+
Requires-Dist: pytest-runner ; extra == "test"
|
|
26
|
+
Requires-Dist: pytest ; extra == "test"
|
|
27
|
+
Requires-Dist: pytest-github-actions-annotate-failures ; extra == "test"
|
|
28
|
+
Project-URL: Source, https://github.com/Dais-Project/Dais-SDK
|
|
29
|
+
Project-URL: Tracker, https://github.com/Dais-Project/Dais-SDK/issues
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Provides-Extra: test
|
|
32
|
+
|
|
33
|
+
# Dais-SDK
|
|
34
|
+
|
|
35
|
+
Dais-SDK is a wrapper around LiteLLM that provides a more intuitive API and an [AI SDK](https://github.com/vercel/ai)-like developer experience (DX).
|
|
36
|
+
|
|
37
|
+
## Installation
|
|
38
|
+
|
|
39
|
+
```
|
|
40
|
+
pip install dais_sdk
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## Examples
|
|
44
|
+
|
|
45
|
+
Below is a simple example of a single API call:
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
import os
|
|
49
|
+
from dotenv import load_dotenv
|
|
50
|
+
from dais_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
51
|
+
|
|
52
|
+
load_dotenv()
|
|
53
|
+
|
|
54
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
55
|
+
api_key=os.getenv("API_KEY", ""),
|
|
56
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
57
|
+
|
|
58
|
+
response = llm.generate_text_sync( # sync API of generate_text
|
|
59
|
+
LlmRequestParams(
|
|
60
|
+
model="deepseek-v3.1",
|
|
61
|
+
messages=[UserMessage(content="Hello.")]))
|
|
62
|
+
print(response)
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
Below is an example that shows automatic tool calling:
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import os
|
|
69
|
+
from dotenv import load_dotenv
|
|
70
|
+
from dais_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
71
|
+
|
|
72
|
+
load_dotenv()
|
|
73
|
+
|
|
74
|
+
def example_tool():
|
|
75
|
+
"""
|
|
76
|
+
This is a test tool that is used to test the tool calling functionality.
|
|
77
|
+
"""
|
|
78
|
+
print("The example tool is called.")
|
|
79
|
+
return "Hello World"
|
|
80
|
+
|
|
81
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
82
|
+
api_key=os.getenv("API_KEY", ""),
|
|
83
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
84
|
+
|
|
85
|
+
params = LlmRequestParams(
|
|
86
|
+
model="deepseek-v3.1",
|
|
87
|
+
tools=[example_tool],
|
|
88
|
+
execute_tools=True,
|
|
89
|
+
messages=[UserMessage(content="Please call the tool example_tool.")])
|
|
90
|
+
|
|
91
|
+
print("User: ", "Please call the tool example_tool.")
|
|
92
|
+
messages = llm.generate_text_sync(params)
|
|
93
|
+
for message in messages:
|
|
94
|
+
match message.role:
|
|
95
|
+
case "assistant":
|
|
96
|
+
print("Assistant: ", message.content)
|
|
97
|
+
case "tool":
|
|
98
|
+
print("Tool: ", message.result)
|
|
99
|
+
```
|
|
100
|
+
|
dais_sdk-0.6.0/README.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Dais-SDK
|
|
2
|
+
|
|
3
|
+
Dais-SDK is a wrapper around LiteLLM that provides a more intuitive API and an [AI SDK](https://github.com/vercel/ai)-like developer experience (DX).
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```
|
|
8
|
+
pip install dais_sdk
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Examples
|
|
12
|
+
|
|
13
|
+
Below is a simple example of a single API call:
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
import os
|
|
17
|
+
from dotenv import load_dotenv
|
|
18
|
+
from dais_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
19
|
+
|
|
20
|
+
load_dotenv()
|
|
21
|
+
|
|
22
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
23
|
+
api_key=os.getenv("API_KEY", ""),
|
|
24
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
25
|
+
|
|
26
|
+
response = llm.generate_text_sync( # sync API of generate_text
|
|
27
|
+
LlmRequestParams(
|
|
28
|
+
model="deepseek-v3.1",
|
|
29
|
+
messages=[UserMessage(content="Hello.")]))
|
|
30
|
+
print(response)
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
Below is an example that shows automatic tool calling:
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
import os
|
|
37
|
+
from dotenv import load_dotenv
|
|
38
|
+
from dais_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
39
|
+
|
|
40
|
+
load_dotenv()
|
|
41
|
+
|
|
42
|
+
def example_tool():
|
|
43
|
+
"""
|
|
44
|
+
This is a test tool that is used to test the tool calling functionality.
|
|
45
|
+
"""
|
|
46
|
+
print("The example tool is called.")
|
|
47
|
+
return "Hello World"
|
|
48
|
+
|
|
49
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
50
|
+
api_key=os.getenv("API_KEY", ""),
|
|
51
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
52
|
+
|
|
53
|
+
params = LlmRequestParams(
|
|
54
|
+
model="deepseek-v3.1",
|
|
55
|
+
tools=[example_tool],
|
|
56
|
+
execute_tools=True,
|
|
57
|
+
messages=[UserMessage(content="Please call the tool example_tool.")])
|
|
58
|
+
|
|
59
|
+
print("User: ", "Please call the tool example_tool.")
|
|
60
|
+
messages = llm.generate_text_sync(params)
|
|
61
|
+
for message in messages:
|
|
62
|
+
match message.role:
|
|
63
|
+
case "assistant":
|
|
64
|
+
print("Assistant: ", message.content)
|
|
65
|
+
case "tool":
|
|
66
|
+
print("Tool: ", message.result)
|
|
67
|
+
```
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["flit_core >=2,<4"]
|
|
3
|
+
build-backend = "flit_core.buildapi"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "dais-sdk"
|
|
7
|
+
authors = [{ name = "BHznJNs", email = "bhznjns@outlook.com" }]
|
|
8
|
+
description = "A wrapper of LiteLLM"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
classifiers = [
|
|
11
|
+
"Development Status :: 3 - Alpha",
|
|
12
|
+
"Intended Audience :: Developers",
|
|
13
|
+
"License :: OSI Approved :: MIT License",
|
|
14
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
15
|
+
"Programming Language :: Python :: 3.10",
|
|
16
|
+
"Programming Language :: Python :: 3.11",
|
|
17
|
+
"Programming Language :: Python :: 3.12",
|
|
18
|
+
]
|
|
19
|
+
requires-python = ">=3.10"
|
|
20
|
+
version = "0.6.0"
|
|
21
|
+
|
|
22
|
+
dependencies = [
|
|
23
|
+
"litellm>=1.80.0",
|
|
24
|
+
"pydantic>=2.0.0",
|
|
25
|
+
|
|
26
|
+
# MCP client dependencies
|
|
27
|
+
"httpx==0.28.1",
|
|
28
|
+
"mcp==1.25.0",
|
|
29
|
+
"starlette==0.50.0",
|
|
30
|
+
"uvicorn==0.40.0",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
dev = [
|
|
35
|
+
"python-dotenv>=1.2.1"
|
|
36
|
+
]
|
|
37
|
+
test = [
|
|
38
|
+
"pytest-cov",
|
|
39
|
+
"pytest-mock",
|
|
40
|
+
"pytest-runner",
|
|
41
|
+
"pytest",
|
|
42
|
+
"pytest-github-actions-annotate-failures",
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
[project.urls]
|
|
46
|
+
Source = "https://github.com/Dais-Project/Dais-SDK"
|
|
47
|
+
Tracker = "https://github.com/Dais-Project/Dais-SDK/issues"
|
|
48
|
+
|
|
49
|
+
[tool.flit.module]
|
|
50
|
+
name = "dais_sdk"
|
|
51
|
+
|
|
52
|
+
[tool.bandit]
|
|
53
|
+
exclude_dirs = ["build", "dist", "tests", "scripts"]
|
|
54
|
+
number = 4
|
|
55
|
+
recursive = true
|
|
56
|
+
targets = "src"
|
|
57
|
+
|
|
58
|
+
[tool.black]
|
|
59
|
+
line-length = 120
|
|
60
|
+
fast = true
|
|
61
|
+
|
|
62
|
+
[tool.coverage.run]
|
|
63
|
+
branch = true
|
|
64
|
+
|
|
65
|
+
[tool.coverage.report]
|
|
66
|
+
fail_under = 100
|
|
67
|
+
|
|
68
|
+
[tool.flake8]
|
|
69
|
+
max-line-length = 120
|
|
70
|
+
select = "F,E,W,B,B901,B902,B903"
|
|
71
|
+
exclude = [
|
|
72
|
+
".eggs",
|
|
73
|
+
".git",
|
|
74
|
+
".tox",
|
|
75
|
+
"nssm",
|
|
76
|
+
"obj",
|
|
77
|
+
"out",
|
|
78
|
+
"packages",
|
|
79
|
+
"pywin32",
|
|
80
|
+
"tests",
|
|
81
|
+
"swagger_client",
|
|
82
|
+
]
|
|
83
|
+
ignore = ["E722", "B001", "W503", "E203"]
|
|
84
|
+
|
|
85
|
+
[tool.pyright]
|
|
86
|
+
include = ["src"]
|
|
87
|
+
pythonVersion = "3.11"
|
|
88
|
+
typeCheckingMode = "standard"
|
|
89
|
+
|
|
90
|
+
[tool.pytest.ini_options]
|
|
91
|
+
addopts = "--cov-report xml:coverage.xml --cov src --cov-fail-under 0 --cov-append -m 'not integration'"
|
|
92
|
+
pythonpath = ["src"]
|
|
93
|
+
testpaths = "tests"
|
|
94
|
+
junit_family = "xunit2"
|
|
95
|
+
markers = [
|
|
96
|
+
"integration: marks as integration test",
|
|
97
|
+
"notebooks: marks as notebook test",
|
|
98
|
+
"gpu: marks as gpu test",
|
|
99
|
+
"spark: marks tests which need Spark",
|
|
100
|
+
"slow: marks tests as slow",
|
|
101
|
+
"unit: fast offline tests",
|
|
102
|
+
]
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import queue
|
|
3
|
+
from typing import cast
|
|
4
|
+
from collections.abc import AsyncGenerator, Generator
|
|
5
|
+
from litellm import CustomStreamWrapper, completion, acompletion
|
|
6
|
+
from litellm.utils import ProviderConfigManager
|
|
7
|
+
from litellm.types.utils import (
|
|
8
|
+
LlmProviders,
|
|
9
|
+
ModelResponse as LiteLlmModelResponse,
|
|
10
|
+
ModelResponseStream as LiteLlmModelResponseStream
|
|
11
|
+
)
|
|
12
|
+
from .debug import enable_debugging
|
|
13
|
+
from .param_parser import ParamParser
|
|
14
|
+
from .stream import AssistantMessageCollector
|
|
15
|
+
from .tool.execute import execute_tool_sync, execute_tool
|
|
16
|
+
from .tool.toolset import (
|
|
17
|
+
Toolset,
|
|
18
|
+
python_tool,
|
|
19
|
+
PythonToolset,
|
|
20
|
+
McpToolset,
|
|
21
|
+
LocalMcpToolset,
|
|
22
|
+
RemoteMcpToolset,
|
|
23
|
+
)
|
|
24
|
+
from .tool.utils import find_tool_by_name
|
|
25
|
+
from .mcp_client import (
|
|
26
|
+
McpClient,
|
|
27
|
+
McpTool,
|
|
28
|
+
McpToolResult,
|
|
29
|
+
LocalMcpClient,
|
|
30
|
+
RemoteMcpClient,
|
|
31
|
+
LocalServerParams,
|
|
32
|
+
RemoteServerParams,
|
|
33
|
+
OAuthParams,
|
|
34
|
+
)
|
|
35
|
+
from .types import (
|
|
36
|
+
GenerateTextResponse,
|
|
37
|
+
StreamTextResponseSync, StreamTextResponseAsync,
|
|
38
|
+
FullMessageQueueSync, FullMessageQueueAsync,
|
|
39
|
+
)
|
|
40
|
+
from .types.request_params import LlmRequestParams
|
|
41
|
+
from .types.tool import ToolFn, ToolDef, RawToolDef, ToolLike
|
|
42
|
+
from .types.exceptions import (
|
|
43
|
+
AuthenticationError,
|
|
44
|
+
PermissionDeniedError,
|
|
45
|
+
RateLimitError,
|
|
46
|
+
ContextWindowExceededError,
|
|
47
|
+
BadRequestError,
|
|
48
|
+
InvalidRequestError,
|
|
49
|
+
InternalServerError,
|
|
50
|
+
ServiceUnavailableError,
|
|
51
|
+
ContentPolicyViolationError,
|
|
52
|
+
APIError,
|
|
53
|
+
Timeout,
|
|
54
|
+
)
|
|
55
|
+
from .types.message import (
|
|
56
|
+
ChatMessage, UserMessage, SystemMessage, AssistantMessage, ToolMessage,
|
|
57
|
+
MessageChunk, TextChunk, UsageChunk, ReasoningChunk, AudioChunk, ImageChunk, ToolCallChunk,
|
|
58
|
+
openai_chunk_normalizer
|
|
59
|
+
)
|
|
60
|
+
from .logger import logger, enable_logging
|
|
61
|
+
|
|
62
|
+
class LLM:
    """
    Client that issues LiteLLM requests for one provider, base URL and API key.

    The `generate_text` and `stream_text` APIs include `ToolMessage` objects in the
    returned sequence only if `params.execute_tools` is True.

    - - -

    Exceptions that `generate_text` and `stream_text` may raise:
        - AuthenticationError
        - PermissionDeniedError
        - RateLimitError
        - ContextWindowExceededError
        - BadRequestError
        - InvalidRequestError
        - InternalServerError
        - ServiceUnavailableError
        - ContentPolicyViolationError
        - APIError
        - Timeout
    """

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        """
        Store the connection settings and build the request parameter parser.

        Args:
            provider: The LiteLLM provider to send requests to.
            base_url: Base URL handed to the parameter parser; also used as
                `api_base` by `list_models`.
            api_key: API key handed to the parameter parser; also used by
                `list_models`.
        """
        self.provider = provider
        self.base_url = base_url
        self.api_key = api_key
        self._param_parser = ParamParser(self.provider, self.base_url, self.api_key)

    @staticmethod
    def _require_tool(params: LlmRequestParams, incomplete_tool_message: ToolMessage):
        """Return the tool definition named by the message, or raise `ToolDoesNotExistError`."""
        tool_def = params.find_tool(incomplete_tool_message.name)
        if tool_def is None:
            raise LlmRequestParams.ToolDoesNotExistError(incomplete_tool_message.name)
        return tool_def

    @staticmethod
    def _format_tool_error(exc: Exception) -> str:
        """Render an exception raised by a tool as `"<ExceptionType>: <message>"`."""
        return f"{type(exc).__name__}: {exc!s}"

    @staticmethod
    def _complete_tool_message(
        incomplete_tool_message: ToolMessage,
        result: str | None,
        error: str | None
    ) -> ToolMessage:
        """Build a `ToolMessage` that copies the call identity and carries the outcome."""
        return ToolMessage(
            tool_call_id=incomplete_tool_message.tool_call_id,
            name=incomplete_tool_message.name,
            arguments=incomplete_tool_message.arguments,
            result=result,
            error=error)

    @staticmethod
    async def execute_tool_call(
        params: LlmRequestParams,
        incomplete_tool_message: ToolMessage
    ) -> tuple[str | None, str | None]:
        """
        Receive an incomplete tool message, execute the tool call and
        return the `(result, error)` tuple.

        Raises:
            LlmRequestParams.ToolDoesNotExistError: If `params` has no tool
                with the requested name.
        """
        tool_def = LLM._require_tool(params, incomplete_tool_message)

        result, error = None, None
        try:
            result = await execute_tool(tool_def, incomplete_tool_message.arguments)
        except Exception as e:
            # Deliberately broad: a failing tool is reported through `error`
            # instead of aborting the whole request.
            error = LLM._format_tool_error(e)
        return result, error

    @staticmethod
    def execute_tool_call_sync(
        params: LlmRequestParams,
        incomplete_tool_message: ToolMessage
    ) -> tuple[str | None, str | None]:
        """
        Synchronous version of `execute_tool_call`.

        Raises:
            LlmRequestParams.ToolDoesNotExistError: If `params` has no tool
                with the requested name.
        """
        tool_def = LLM._require_tool(params, incomplete_tool_message)

        result, error = None, None
        try:
            result = execute_tool_sync(tool_def, incomplete_tool_message.arguments)
        except Exception as e:
            # Deliberately broad: see `execute_tool_call`.
            error = LLM._format_tool_error(e)
        return result, error

    def _resolve_tool_calls_sync(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> Generator[ToolMessage]:
        """
        Execute the tool calls requested by `assistant_message` and yield one
        completed `ToolMessage` per call.

        Yields nothing when `params.execute_tools` is falsy or the assistant
        message has no incomplete tool messages. Calls to unknown tools are
        logged and skipped.
        """
        if not params.execute_tools:
            return
        incomplete_tool_messages = assistant_message.get_incomplete_tool_messages()
        if incomplete_tool_messages is None:
            return
        for incomplete_tool_message in incomplete_tool_messages:
            try:
                result, error = LLM.execute_tool_call_sync(params, incomplete_tool_message)
            except LlmRequestParams.ToolDoesNotExistError as e:
                logger.warning(f"{e.message} Skipping this tool call.")
                continue
            yield LLM._complete_tool_message(incomplete_tool_message, result, error)

    async def _resolve_tool_calls(self, params: LlmRequestParams, assistant_message: AssistantMessage) -> AsyncGenerator[ToolMessage]:
        """Asynchronous version of `_resolve_tool_calls_sync`."""
        if not params.execute_tools:
            return
        incomplete_tool_messages = assistant_message.get_incomplete_tool_messages()
        if incomplete_tool_messages is None:
            return
        for incomplete_tool_message in incomplete_tool_messages:
            try:
                result, error = await LLM.execute_tool_call(params, incomplete_tool_message)
            except LlmRequestParams.ToolDoesNotExistError as e:
                logger.warning(f"{e.message} Skipping this tool call.")
                continue
            yield LLM._complete_tool_message(incomplete_tool_message, result, error)

    def list_models(self) -> list[str]:
        """
        Return the models reported by the provider's LiteLLM model-info config.

        Raises:
            ValueError: If LiteLLM has no model-info config for the provider.

        Any exception raised while fetching the models propagates unchanged.
        """
        provider_config = ProviderConfigManager.get_provider_model_info(
            model=None,
            provider=self.provider,
        )

        if provider_config is None:
            raise ValueError(f"The '{self.provider}' provider is not supported to list models.")

        return provider_config.get_models(
            api_key=self.api_key,
            api_base=self.base_url
        )

    def generate_text_sync(self, params: LlmRequestParams) -> GenerateTextResponse:
        """
        Synchronous version of `generate_text`.

        Returns:
            A list starting with the `AssistantMessage`, followed by the
            `ToolMessage` objects produced by `_resolve_tool_calls_sync`.
        """
        response = completion(**self._param_parser.parse_nonstream(params))
        response = cast(LiteLlmModelResponse, response)
        assistant_message = AssistantMessage.from_litellm_message(response)
        result: GenerateTextResponse = [assistant_message]
        result.extend(self._resolve_tool_calls_sync(params, assistant_message))
        return result

    async def generate_text(self, params: LlmRequestParams) -> GenerateTextResponse:
        """
        Send a non-streaming request and collect the resulting messages.

        Returns:
            A list starting with the `AssistantMessage`, followed by the
            `ToolMessage` objects produced by `_resolve_tool_calls`.
        """
        response = await acompletion(**self._param_parser.parse_nonstream(params))
        response = cast(LiteLlmModelResponse, response)
        assistant_message = AssistantMessage.from_litellm_message(response)
        result: GenerateTextResponse = [assistant_message]
        async for tool_message in self._resolve_tool_calls(params, assistant_message):
            result.append(tool_message)
        return result

    def stream_text_sync(self, params: LlmRequestParams) -> StreamTextResponseSync:
        """
        Returns:
            - stream: Generator yielding `MessageChunk` objects
            - full_message_queue: Queue containing complete `AssistantMessage`, `ToolMessage` (or `None` when done)

        The queue is filled only after the returned stream has been fully
        consumed; tool calls are executed at that point as well.
        """
        def stream(response: CustomStreamWrapper) -> Generator[MessageChunk]:
            # `message_collector` and `full_message_queue` are closed over from
            # the enclosing scope; they are bound before the first iteration.
            for chunk in response:
                chunk = cast(LiteLlmModelResponseStream, chunk)
                yield from openai_chunk_normalizer(chunk)
                message_collector.collect(chunk)

            message = message_collector.get_message()
            full_message_queue.put(message)

            for tool_message in self._resolve_tool_calls_sync(params, message):
                full_message_queue.put(tool_message)
            # `None` marks the end of the queue.
            full_message_queue.put(None)

        response = completion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        returned_stream = stream(cast(CustomStreamWrapper, response))
        full_message_queue = FullMessageQueueSync()
        return returned_stream, full_message_queue

    async def stream_text(self, params: LlmRequestParams) -> StreamTextResponseAsync:
        """
        Returns:
            - stream: Async generator yielding `MessageChunk` objects
            - full_message_queue: Queue containing complete `AssistantMessage`, `ToolMessage` (or `None` when done)

        The queue is filled only after the returned stream has been fully
        consumed; tool calls are executed at that point as well.
        """
        async def stream(response: CustomStreamWrapper) -> AsyncGenerator[MessageChunk]:
            # `message_collector` and `full_message_queue` are closed over from
            # the enclosing scope; they are bound before the first iteration.
            async for chunk in response:
                chunk = cast(LiteLlmModelResponseStream, chunk)
                for normalized_chunk in openai_chunk_normalizer(chunk):
                    yield normalized_chunk
                message_collector.collect(chunk)

            message = message_collector.get_message()
            await full_message_queue.put(message)
            async for tool_message in self._resolve_tool_calls(params, message):
                await full_message_queue.put(tool_message)
            # `None` marks the end of the queue.
            await full_message_queue.put(None)

        response = await acompletion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        returned_stream = stream(cast(CustomStreamWrapper, response))
        full_message_queue = FullMessageQueueAsync()
        return returned_stream, full_message_queue
|
|
256
|
+
|
|
257
|
+
# Public API of the package, grouped by area.
__all__ = [
    # Core client and request parameters
    "LLM", "LlmProviders", "LlmRequestParams",

    # Toolsets
    "Toolset", "python_tool", "PythonToolset",
    "McpToolset", "LocalMcpToolset", "RemoteMcpToolset",

    # MCP clients and their parameters
    "McpClient", "McpTool", "McpToolResult",
    "LocalMcpClient", "RemoteMcpClient",
    "LocalServerParams", "RemoteServerParams", "OAuthParams",

    # Tool types and executors
    "ToolFn", "ToolDef", "RawToolDef", "ToolLike",
    "execute_tool", "execute_tool_sync",

    # Messages
    "ChatMessage", "UserMessage", "SystemMessage", "AssistantMessage", "ToolMessage",

    # Streaming chunks
    "MessageChunk", "TextChunk", "UsageChunk", "ReasoningChunk",
    "AudioChunk", "ImageChunk", "ToolCallChunk",

    # Response and queue types
    "GenerateTextResponse",
    "StreamTextResponseSync", "StreamTextResponseAsync",
    "FullMessageQueueSync", "FullMessageQueueAsync",

    # Debugging and logging switches
    "enable_debugging", "enable_logging",

    # Exceptions
    "AuthenticationError", "PermissionDeniedError", "RateLimitError",
    "ContextWindowExceededError", "BadRequestError", "InvalidRequestError",
    "InternalServerError", "ServiceUnavailableError", "ContentPolicyViolationError",
    "APIError", "Timeout",
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
# Package logger; the NullHandler keeps it silent until `enable_logging` is called.
logger = logging.getLogger("LiteAI-SDK")
logger.addHandler(logging.NullHandler())


def enable_logging(level=logging.INFO):
    """
    Turn on log output for the SDK logger.

    Sets the logger level and, unless the logger already has a
    `logging.StreamHandler`, attaches one that writes to stderr.

    Args:
        level: The logging level (default: logging.INFO).

    Common values: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR
    """
    logger.setLevel(level)

    # Repeated calls must not stack additional stream handlers.
    for existing in logger.handlers:
        if isinstance(existing, logging.StreamHandler):
            return

    stream_handler = logging.StreamHandler(sys.stderr)
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(message)s")
    )
    logger.addHandler(stream_handler)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .base_mcp_client import McpClient, McpTool, ToolResult as McpToolResult
|
|
2
|
+
from .local_mcp_client import LocalMcpClient, LocalServerParams
|
|
3
|
+
from .remote_mcp_client import RemoteMcpClient, RemoteServerParams, OAuthParams
|
|
4
|
+
|
|
5
|
+
# Names re-exported by the `mcp_client` package.
__all__ = [
    # Shared base types
    "McpClient", "McpTool", "McpToolResult",

    # Local and remote clients with their parameter types
    "LocalMcpClient", "LocalServerParams",
    "RemoteMcpClient", "RemoteServerParams", "OAuthParams",
]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, NamedTuple
|
|
3
|
+
from mcp import Tool as McpTool
|
|
4
|
+
from mcp.types import ContentBlock as ToolResultBlock
|
|
5
|
+
|
|
6
|
+
Tool = McpTool
|
|
7
|
+
|
|
8
|
+
class ToolResult(NamedTuple):
    """Outcome of an MCP tool call: an error flag plus the returned content blocks."""

    # True when the tool call is reported as failed.
    is_error: bool
    # Content blocks returned by the tool call.
    content: list[ToolResultBlock]
|
|
11
|
+
|
|
12
|
+
class McpClient(ABC):
    """Abstract interface that every MCP client implementation must provide."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Return this client's name."""

    @abstractmethod
    async def connect(self):
        """Establish the MCP session."""

    @abstractmethod
    async def disconnect(self):
        """Counterpart of `connect`."""

    @abstractmethod
    async def list_tools(self) -> list[Tool]:
        """
        Return the tools available through this client.

        Raises:
            McpSessionNotEstablishedError: If the session is not established.
        """

    @abstractmethod
    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> ToolResult:
        """
        Call the tool named `tool_name` with the optional `arguments` mapping.

        Raises:
            McpSessionNotEstablishedError: If the session is not established.
        """
|
|
35
|
+
|
|
36
|
+
class McpSessionNotEstablishedError(RuntimeError):
    """Signals that an MCP operation was attempted without an established session."""

    def __init__(self):
        # Fixed message: the error takes no arguments.
        message = "MCP Session not established, please call connect() first"
        super().__init__(message)
|