liteai-sdk 0.3.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liteai_sdk/__init__.py +264 -0
- liteai_sdk/debug.py +4 -0
- liteai_sdk/logger.py +22 -0
- liteai_sdk/param_parser.py +48 -0
- liteai_sdk/stream.py +82 -0
- liteai_sdk/tool/__init__.py +310 -0
- liteai_sdk/tool/execute.py +65 -0
- liteai_sdk/tool/utils.py +11 -0
- liteai_sdk/types/__init__.py +32 -0
- liteai_sdk/types/exceptions.py +27 -0
- liteai_sdk/types/message.py +242 -0
- liteai_sdk-0.3.21.dist-info/METADATA +100 -0
- liteai_sdk-0.3.21.dist-info/RECORD +15 -0
- liteai_sdk-0.3.21.dist-info/WHEEL +4 -0
- liteai_sdk-0.3.21.dist-info/licenses/LICENSE +21 -0
liteai_sdk/__init__.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import queue
|
|
3
|
+
from typing import cast
|
|
4
|
+
from collections.abc import AsyncGenerator, Generator
|
|
5
|
+
from litellm import CustomStreamWrapper, completion, acompletion
|
|
6
|
+
from litellm.exceptions import (
|
|
7
|
+
AuthenticationError,
|
|
8
|
+
PermissionDeniedError,
|
|
9
|
+
RateLimitError,
|
|
10
|
+
ContextWindowExceededError,
|
|
11
|
+
BadRequestError,
|
|
12
|
+
InvalidRequestError,
|
|
13
|
+
InternalServerError,
|
|
14
|
+
ServiceUnavailableError,
|
|
15
|
+
ContentPolicyViolationError,
|
|
16
|
+
APIError,
|
|
17
|
+
Timeout,
|
|
18
|
+
)
|
|
19
|
+
from litellm.utils import get_valid_models
|
|
20
|
+
from litellm.types.utils import LlmProviders,\
|
|
21
|
+
ModelResponse as LiteLlmModelResponse,\
|
|
22
|
+
ModelResponseStream as LiteLlmModelResponseStream,\
|
|
23
|
+
Choices as LiteLlmModelResponseChoices
|
|
24
|
+
from .debug import enable_debugging
|
|
25
|
+
from .param_parser import ParamParser
|
|
26
|
+
from .stream import AssistantMessageCollector
|
|
27
|
+
from .tool import ToolFn, ToolDef, RawToolDef, ToolLike
|
|
28
|
+
from .tool.execute import execute_tool_sync, execute_tool
|
|
29
|
+
from .tool.utils import find_tool_by_name
|
|
30
|
+
from .types import LlmRequestParams, GenerateTextResponse, StreamTextResponseSync, StreamTextResponseAsync
|
|
31
|
+
from .types.exceptions import *
|
|
32
|
+
from .types.message import ChatMessage, UserMessage, SystemMessage, AssistantMessage, ToolMessage,\
|
|
33
|
+
MessageChunk, TextChunk, ReasoningChunk, AudioChunk, ImageChunk, ToolCallChunk,\
|
|
34
|
+
ToolCallTuple, openai_chunk_normalizer
|
|
35
|
+
from .logger import logger, enable_logging
|
|
36
|
+
|
|
37
|
+
class LLM:
    """
    Chat client that sends requests through LiteLLM to a single provider endpoint.

    The `generate_text` and `stream_text` API will return ToolMessage in the returned sequence
    only if `params.execute_tools` is True.

    - - -

    Possible exceptions raised by `generate_text` and `stream_text`:
    - AuthenticationError
    - PermissionDeniedError
    - RateLimitError
    - ContextWindowExceededError
    - BadRequestError
    - InvalidRequestError
    - InternalServerError
    - ServiceUnavailableError
    - ContentPolicyViolationError
    - APIError
    - Timeout
    """

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        """
        Args:
            provider: LiteLLM provider; its `.value` is used when listing models and
                when building request parameters.
            base_url: API base URL used for every request made by this client.
            api_key: API key used for every request made by this client.
        """
        self.provider = provider
        self.base_url = base_url
        self.api_key = api_key
        self._param_parser = ParamParser(self.provider, self.base_url, self.api_key)

    @staticmethod
    def _should_resolve_tool_calls(
        params: LlmRequestParams,
        message: AssistantMessage,
    ) -> tuple[list[ToolLike],
               list[ToolCallTuple]] | None:
        """Return `(tools, tool_calls)` when the tool calls in `message` should be executed.

        Execution requires a truthy `params.execute_tools`, a non-None `params.tools`,
        and a non-None result from `message.parse_tool_calls()`. Otherwise returns None.
        """
        # Only parse tool calls when execution was actually requested.
        if not params.execute_tools or params.tools is None:
            return None
        parsed_tool_calls = message.parse_tool_calls()
        if parsed_tool_calls is None:
            return None
        return params.tools, parsed_tool_calls

    @staticmethod
    def _iter_executable_tool_calls(
        tools: list[ToolLike],
        tool_call_tuples: list[ToolCallTuple],
    ) -> Generator[tuple]:
        """Yield `(call_id, function_name, function_arguments, tool)` for runnable tool calls.

        Calls whose name matches no tool, or that resolve to a raw (dict) tool
        definition, are logged as warnings and skipped.
        """
        for call_id, function_name, function_arguments in tool_call_tuples:
            target_tool = find_tool_by_name(tools, function_name)
            if target_tool is None:
                logger.warning('Tool "%s" not found, skipping execution.', function_name)
                continue
            if isinstance(target_tool, dict):
                logger.warning('Tool "%s" is a raw tool, skipping execution.', function_name)
                continue
            yield call_id, function_name, function_arguments, target_tool

    @staticmethod
    def _format_tool_error(error: Exception) -> str:
        """Render a tool exception as `"<ExceptionType>: <message>"` for ToolMessage.error."""
        return f"{type(error).__name__}: {str(error)}"

    @staticmethod
    async def _execute_tool_calls(
        tools: list[ToolLike],
        tool_call_tuples: list[ToolCallTuple]
    ) -> list[ToolMessage]:
        """Run each executable tool call through `execute_tool` and collect ToolMessages.

        A failing tool does not abort the batch: its exception is recorded in the
        message's `error` field and `result` stays None.
        """
        results: list[ToolMessage] = []
        for call_id, function_name, function_arguments, target_tool in \
                LLM._iter_executable_tool_calls(tools, tool_call_tuples):
            result, error = None, None
            try:
                result = await execute_tool(target_tool, function_arguments)
            except Exception as e:  # deliberate boundary: tool errors are reported, not raised
                error = LLM._format_tool_error(e)
            results.append(ToolMessage(
                id=call_id,
                name=function_name,
                arguments=function_arguments,
                result=result,
                error=error).with_tool_def(target_tool))
        return results

    @staticmethod
    def _execute_tool_calls_sync(
        tools: list[ToolLike],
        tool_call_tuples: list[ToolCallTuple]
    ) -> list[ToolMessage]:
        """Synchronous counterpart of `_execute_tool_calls`, using `execute_tool_sync`."""
        results: list[ToolMessage] = []
        for call_id, function_name, function_arguments, target_tool in \
                LLM._iter_executable_tool_calls(tools, tool_call_tuples):
            result, error = None, None
            try:
                result = execute_tool_sync(target_tool, function_arguments)
            except Exception as e:  # deliberate boundary: tool errors are reported, not raised
                error = LLM._format_tool_error(e)
            results.append(ToolMessage(
                id=call_id,
                name=function_name,
                arguments=function_arguments,
                result=result,
                error=error).with_tool_def(target_tool))
        return results

    @staticmethod
    def _assistant_message_from_response(
        response: object,
        params: LlmRequestParams,
    ) -> AssistantMessage:
        """Convert the first choice of a non-streaming LiteLLM response into an AssistantMessage."""
        response = cast(LiteLlmModelResponse, response)
        choices = cast(list[LiteLlmModelResponseChoices], response.choices)
        return (
            AssistantMessage
            .from_litellm_message(choices[0].message)
            .with_request_params(params)
        )

    def list_models(self) -> list[str]:
        """Return the model names LiteLLM reports for this provider, querying its endpoint."""
        return get_valid_models(
            custom_llm_provider=self.provider.value,
            check_provider_endpoint=True,
            api_base=self.base_url,
            api_key=self.api_key)

    def generate_text_sync(self, params: LlmRequestParams) -> GenerateTextResponse:
        """Send a blocking, non-streaming request.

        Returns the assistant message, followed by any ToolMessages produced by
        executing its tool calls (see `_should_resolve_tool_calls`).
        """
        response = completion(**self._param_parser.parse_nonstream(params))
        assistant_message = self._assistant_message_from_response(response, params)
        result: GenerateTextResponse = [assistant_message]
        if (tools_and_tool_calls := self._should_resolve_tool_calls(params, assistant_message)):
            tools, tool_calls = tools_and_tool_calls
            result += self._execute_tool_calls_sync(tools, tool_calls)
        return result

    async def generate_text(self, params: LlmRequestParams) -> GenerateTextResponse:
        """Async counterpart of `generate_text_sync`."""
        response = await acompletion(**self._param_parser.parse_nonstream(params))
        assistant_message = self._assistant_message_from_response(response, params)
        result: GenerateTextResponse = [assistant_message]
        if (tools_and_tool_calls := self._should_resolve_tool_calls(params, assistant_message)):
            tools, tool_calls = tools_and_tool_calls
            result += await self._execute_tool_calls(tools, tool_calls)
        return result

    def stream_text_sync(self, params: LlmRequestParams) -> StreamTextResponseSync:
        """Start a streaming request and return `(chunk_stream, full_message_queue)`.

        Iterating `chunk_stream` yields normalized chunks as they arrive. After the
        last chunk, the assembled AssistantMessage is put on `full_message_queue`,
        then any ToolMessages from executed tool calls. A terminating None is always
        put on the queue once iteration has started and the generator finishes:
        normally, by raising, or by being closed early. Consumers blocked on the
        queue are therefore never left waiting forever.
        """
        response = completion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        full_message_queue = queue.Queue[AssistantMessage | ToolMessage | None]()

        def stream(response: CustomStreamWrapper) -> Generator[MessageChunk]:
            try:
                for chunk in response:
                    chunk = cast(LiteLlmModelResponseStream, chunk)
                    yield from openai_chunk_normalizer(chunk)
                    message_collector.collect(chunk)

                message = message_collector.get_message().with_request_params(params)
                full_message_queue.put(message)
                if (tools_and_tool_calls := self._should_resolve_tool_calls(params, message)):
                    tools, tool_calls = tools_and_tool_calls
                    for tool_message in self._execute_tool_calls_sync(tools, tool_calls):
                        full_message_queue.put(tool_message)
            finally:
                full_message_queue.put(None)

        returned_stream = stream(cast(CustomStreamWrapper, response))
        return returned_stream, full_message_queue

    async def stream_text(self, params: LlmRequestParams) -> StreamTextResponseAsync:
        """Async counterpart of `stream_text_sync`; the queue is an `asyncio.Queue`."""
        response = await acompletion(**self._param_parser.parse_stream(params))
        message_collector = AssistantMessageCollector()
        full_message_queue = asyncio.Queue[AssistantMessage | ToolMessage | None]()

        async def stream(response: CustomStreamWrapper) -> AsyncGenerator[TextChunk | ReasoningChunk | AudioChunk | ImageChunk | ToolCallChunk]:
            try:
                async for chunk in response:
                    chunk = cast(LiteLlmModelResponseStream, chunk)
                    for normalized_chunk in openai_chunk_normalizer(chunk):
                        yield normalized_chunk
                    message_collector.collect(chunk)

                message = message_collector.get_message().with_request_params(params)
                await full_message_queue.put(message)
                if (tools_and_tool_calls := self._should_resolve_tool_calls(params, message)):
                    tools, tool_calls = tools_and_tool_calls
                    for tool_message in await self._execute_tool_calls(tools, tool_calls):
                        await full_message_queue.put(tool_message)
            finally:
                # The queue is unbounded (no maxsize), so put_nowait never suspends
                # or raises QueueFull. That keeps it safe while the generator is closing.
                full_message_queue.put_nowait(None)

        returned_stream = stream(cast(CustomStreamWrapper, response))
        return returned_stream, full_message_queue
|
|
220
|
+
|
|
221
|
+
# Public API of the package, grouped by concern (order matches the import layout).
__all__ = [
    # Exceptions
    "AuthenticationError", "PermissionDeniedError", "RateLimitError",
    "ContextWindowExceededError", "BadRequestError", "InvalidRequestError",
    "InternalServerError", "ServiceUnavailableError",
    "ContentPolicyViolationError", "APIError", "Timeout",

    # Client and request parameters
    "LLM", "LlmRequestParams",

    # Tool definitions and execution helpers
    "ToolFn", "ToolDef", "RawToolDef", "ToolLike",
    "execute_tool", "execute_tool_sync",

    # Chat messages
    "ChatMessage", "UserMessage", "SystemMessage", "AssistantMessage", "ToolMessage",

    # Streamed message chunks
    "MessageChunk", "TextChunk", "ReasoningChunk", "AudioChunk", "ImageChunk",
    "ToolCallChunk",

    # Response types
    "GenerateTextResponse", "StreamTextResponseSync", "StreamTextResponseAsync",

    # Debugging / logging switches
    "enable_debugging", "enable_logging",
]
|
liteai_sdk/debug.py
ADDED
liteai_sdk/logger.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger("LiteAI-SDK")
# Standard library practice: a NullHandler prevents logging's "last resort"
# stderr output when the application has not configured logging itself.
logger.addHandler(logging.NullHandler())


def enable_logging(level=logging.INFO):
    """
    Enable logging for the LiteAI SDK.

    Sets the SDK logger's level. If the logger has no StreamHandler yet, it also
    attaches one that writes timestamped records to stderr. Calling this again
    after that only changes the level.

    Args:
        level: The logging level (default: logging.INFO).

            Common values: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR
    """
    logger.setLevel(level)

    # NOTE(review): FileHandler subclasses StreamHandler, so a user-attached
    # FileHandler also suppresses the stderr handler here.
    has_stream_handler = any(
        isinstance(existing, logging.StreamHandler) for existing in logger.handlers
    )
    if has_stream_handler:
        return

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(message)s"))
    logger.addHandler(stderr_handler)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from litellm.types.utils import LlmProviders
|
|
3
|
+
from .tool import prepare_tools
|
|
4
|
+
from .types import LlmRequestParams, ToolMessage
|
|
5
|
+
|
|
6
|
+
ParsedParams = dict[str, Any]


class ParamParser:
    """Build keyword arguments for litellm `completion` / `acompletion` from LlmRequestParams.

    Provider, base URL and API key are fixed per parser; everything else is
    taken from the request parameters on each call.
    """

    def __init__(self,
                 provider: LlmProviders,
                 base_url: str,
                 api_key: str):
        self._provider = provider
        self._base_url = base_url
        self._api_key = api_key

    @staticmethod
    def _is_pending_tool_message(message) -> bool:
        """True for an exact ToolMessage that has neither a result nor an error yet."""
        if type(message) is not ToolMessage:
            return False
        return message.result is None and message.error is None

    def _parse(self, params: LlmRequestParams) -> ParsedParams:
        # Falsy tool lists (None / []) are passed through unchanged.
        prepared_tools = prepare_tools(params.tools) if params.tools else params.tools

        # Tool messages that were never resolved are left out of the conversation.
        litellm_messages = [
            message.to_litellm_message()
            for message in params.messages
            if not self._is_pending_tool_message(message)
        ]

        return {
            "model": f"{self._provider.value}/{params.model}",
            "messages": litellm_messages,
            "base_url": self._base_url,
            "api_key": self._api_key,
            "tools": prepared_tools,
            "tool_choice": params.tool_choice,
            "timeout": params.timeout_sec,
            "extra_headers": params.headers,
            # extra_args are spread last, so they can override any key above.
            **(params.extra_args or {})
        }

    def _parse_with_stream_flag(self, params: LlmRequestParams, stream: bool) -> ParsedParams:
        """Parse `params` and set the `stream` flag on the result."""
        parsed = self._parse(params)
        parsed["stream"] = stream
        return parsed

    def parse_nonstream(self, params: LlmRequestParams) -> ParsedParams:
        return self._parse_with_stream_flag(params, False)

    def parse_stream(self, params: LlmRequestParams) -> ParsedParams:
        return self._parse_with_stream_flag(params, True)
|
liteai_sdk/stream.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from litellm import ChatCompletionAssistantToolCall
|
|
3
|
+
from litellm.types.utils import ChatCompletionDeltaToolCall,\
|
|
4
|
+
ModelResponseStream as LiteLlmModelResponseStream
|
|
5
|
+
from .types.message import AssistantMessage
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass
|
|
8
|
+
class ToolCallTemp:
|
|
9
|
+
id: str | None = None
|
|
10
|
+
name: str = ""
|
|
11
|
+
arguments: str = ""
|
|
12
|
+
|
|
13
|
+
class ToolCallCollector:
    """Reassemble streamed tool-call fragments into complete tool calls.

    Fragments are grouped by their `index`. Names and arguments are concatenated
    in arrival order, and the most recent non-empty id is kept.
    """

    def __init__(self):
        # index -> partially assembled tool call, kept in first-seen order
        self.tool_call_map: dict[int, ToolCallTemp] = {}

    def collect(self, tool_call_chunk: ChatCompletionDeltaToolCall):
        """Merge one streamed fragment into the tool call at its index."""
        index = tool_call_chunk.index
        pending = self.tool_call_map.get(index)
        if pending is None:
            pending = ToolCallTemp()
            self.tool_call_map[index] = pending

        if tool_call_chunk.get("id"):
            pending.id = tool_call_chunk.id

        function = tool_call_chunk.function
        if function.get("name"):
            assert function.name is not None
            pending.name += function.name
        if function.get("arguments"):
            assert function.arguments is not None
            pending.arguments += function.arguments

    @staticmethod
    def _as_assistant_tool_call(pending: ToolCallTemp) -> ChatCompletionAssistantToolCall:
        """Render one accumulated tool call as an OpenAI-style assistant tool call dict."""
        return {
            "id": pending.id,
            "function": {
                "name": pending.name,
                "arguments": pending.arguments,
            },
            "type": "function"
        }

    def get_tool_calls(self) -> list[ChatCompletionAssistantToolCall]:
        """Return every assembled tool call, in first-seen index order."""
        return [self._as_assistant_tool_call(pending) for pending in self.tool_call_map.values()]
|
|
40
|
+
|
|
41
|
+
class AssistantMessageCollector:
    """Accumulate streamed LiteLLM response chunks into a single AssistantMessage.

    Only the first choice of each chunk is read. Content and reasoning deltas are
    concatenated, tool-call fragments go to a ToolCallCollector, image deltas are
    appended, and the audio delta replaces any previous one.
    """

    def __init__(self):
        self.message_buf = AssistantMessage(content=None)
        self.tool_call_collector = ToolCallCollector()

    def collect(self, chunk: LiteLlmModelResponseStream):
        """Merge the first choice's delta from `chunk` into the buffered message."""
        # A chunk without choices carries no delta to merge; skip it instead of
        # failing with IndexError on choices[0].
        if not chunk.choices:
            return
        delta = chunk.choices[0].delta

        if delta.get("content"):
            assert delta.content is not None
            if self.message_buf.content is None:
                self.message_buf.content = ""
            self.message_buf.content += delta.content

        if delta.get("reasoning_content"):
            assert delta.reasoning_content is not None
            if self.message_buf.reasoning_content is None:
                self.message_buf.reasoning_content = ""
            self.message_buf.reasoning_content += delta.reasoning_content

        if delta.get("tool_calls"):
            assert delta.tool_calls is not None
            for tool_call_chunk in delta.tool_calls:
                self.tool_call_collector.collect(tool_call_chunk)

        if delta.get("images"):
            assert delta.images is not None
            # Append, so images delivered across several chunks are all kept.
            existing_images = self.message_buf.images or []
            self.message_buf.images = [*existing_images, *delta.images]

        if delta.get("audio"):
            assert delta.audio is not None
            # NOTE(review): audio is replaced rather than merged, so only the last
            # chunk's audio payload survives. Confirm this matches how the provider
            # streams audio.
            self.message_buf.audio = delta.audio

    def get_message(self) -> AssistantMessage:
        """Attach the assembled tool calls and return the buffered message."""
        self.message_buf.tool_calls = self.tool_call_collector.get_tool_calls()
        return self.message_buf

    def clear(self):
        """
        This class will be created for each stream, so we don't need to clear it.
        """
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""
|
|
2
|
+
source: https://github.com/mozilla-ai/any-llm/blob/main/src/any_llm/tools.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import dataclasses
|
|
6
|
+
import enum
|
|
7
|
+
import inspect
|
|
8
|
+
import types as _types
|
|
9
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
10
|
+
from datetime import date, datetime, time
|
|
11
|
+
from typing import Annotated as _Annotated, Literal as _Literal, is_typeddict as _is_typeddict,\
|
|
12
|
+
Any, Awaitable, get_args, get_origin, get_type_hints
|
|
13
|
+
from pydantic import BaseModel as PydanticBaseModel
|
|
14
|
+
|
|
15
|
+
ToolFn = Callable[..., Any] | Callable[..., Awaitable[Any]]
|
|
16
|
+
|
|
17
|
+
"""
|
|
18
|
+
RawToolDef example:
|
|
19
|
+
{
|
|
20
|
+
"name": "get_current_weather",
|
|
21
|
+
"description": "Get the current weather in a given location",
|
|
22
|
+
"parameters": {
|
|
23
|
+
"type": "object",
|
|
24
|
+
"properties": {
|
|
25
|
+
"location": {
|
|
26
|
+
"type": "string",
|
|
27
|
+
"description": "The city and state, e.g. San Francisco, CA",
|
|
28
|
+
},
|
|
29
|
+
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
|
30
|
+
},
|
|
31
|
+
"required": ["location"],
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
"""
|
|
35
|
+
RawToolDef = dict[str, Any]
|
|
36
|
+
|
|
37
|
+
@dataclasses.dataclass
class ToolDef:
    """A tool with an explicit name and description.

    The parameter schema is derived from the signature of `execute`.
    """

    # Tool name exposed to the model.
    name: str
    # Tool description exposed to the model.
    description: str
    # Callable (sync or async) that implements the tool.
    execute: ToolFn


# Anything accepted where a tool is expected: a ToolDef, a raw dict definition,
# or a bare callable.
ToolLike = ToolDef | RawToolDef | ToolFn
|
|
44
|
+
|
|
45
|
+
def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
|
|
46
|
+
"""Convert Python type annotation to a JSON Schema for a parameter.
|
|
47
|
+
|
|
48
|
+
Supported mappings (subset tailored for LLM tool schemas):
|
|
49
|
+
- Primitives: str/int/float/bool -> string/integer/number/boolean
|
|
50
|
+
- bytes -> string with contentEncoding base64
|
|
51
|
+
- datetime/date/time -> string with format date-time/date/time
|
|
52
|
+
- list[T] / Sequence[T] / set[T] / frozenset[T] -> array with items=schema(T)
|
|
53
|
+
- set/frozenset include uniqueItems=true
|
|
54
|
+
- list without type args defaults items to string
|
|
55
|
+
- dict[K,V] / Mapping[K,V] -> object with additionalProperties=schema(V)
|
|
56
|
+
- dict without type args defaults additionalProperties to string
|
|
57
|
+
- tuple[T1, T2, ...] -> array with prefixItems per element and min/maxItems
|
|
58
|
+
- tuple[T, ...] -> array with items=schema(T)
|
|
59
|
+
- Union[X, Y] and X | Y -> oneOf=[schema(X), schema(Y)] (without top-level type)
|
|
60
|
+
- Optional[T] (Union[T, None]) -> schema(T) (nullability not encoded)
|
|
61
|
+
- Literal[...]/Enum -> enum with appropriate type inference when uniform
|
|
62
|
+
- TypedDict -> object with properties/required per annotations
|
|
63
|
+
- dataclass/Pydantic BaseModel -> object with nested properties inferred from fields
|
|
64
|
+
"""
|
|
65
|
+
origin = get_origin(python_type)
|
|
66
|
+
args = get_args(python_type)
|
|
67
|
+
|
|
68
|
+
if _Annotated is not None and origin is _Annotated and len(args) >= 1:
|
|
69
|
+
python_type = args[0]
|
|
70
|
+
origin = get_origin(python_type)
|
|
71
|
+
args = get_args(python_type)
|
|
72
|
+
|
|
73
|
+
if python_type is Any:
|
|
74
|
+
return {"type": "string"}
|
|
75
|
+
|
|
76
|
+
primitive_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
|
|
77
|
+
if python_type in primitive_map:
|
|
78
|
+
return {"type": primitive_map[python_type]}
|
|
79
|
+
|
|
80
|
+
if python_type is bytes:
|
|
81
|
+
return {"type": "string", "contentEncoding": "base64"}
|
|
82
|
+
if python_type is datetime:
|
|
83
|
+
return {"type": "string", "format": "date-time"}
|
|
84
|
+
if python_type is date:
|
|
85
|
+
return {"type": "string", "format": "date"}
|
|
86
|
+
if python_type is time:
|
|
87
|
+
return {"type": "string", "format": "time"}
|
|
88
|
+
|
|
89
|
+
if python_type is list:
|
|
90
|
+
return {"type": "array", "items": {"type": "string"}}
|
|
91
|
+
if python_type is dict:
|
|
92
|
+
return {"type": "object", "additionalProperties": {"type": "string"}}
|
|
93
|
+
|
|
94
|
+
if origin is _Literal:
|
|
95
|
+
literal_values = list(args)
|
|
96
|
+
schema_lit: dict[str, Any] = {"enum": literal_values}
|
|
97
|
+
if all(isinstance(v, bool) for v in literal_values):
|
|
98
|
+
schema_lit["type"] = "boolean"
|
|
99
|
+
elif all(isinstance(v, str) for v in literal_values):
|
|
100
|
+
schema_lit["type"] = "string"
|
|
101
|
+
elif all(isinstance(v, int) and not isinstance(v, bool) for v in literal_values):
|
|
102
|
+
schema_lit["type"] = "integer"
|
|
103
|
+
elif all(isinstance(v, int | float) and not isinstance(v, bool) for v in literal_values):
|
|
104
|
+
schema_lit["type"] = "number"
|
|
105
|
+
return schema_lit
|
|
106
|
+
|
|
107
|
+
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
|
|
108
|
+
enum_values = [e.value for e in python_type]
|
|
109
|
+
value_types = {type(v) for v in enum_values}
|
|
110
|
+
schema: dict[str, Any] = {"enum": enum_values}
|
|
111
|
+
if value_types == {str}:
|
|
112
|
+
schema["type"] = "string"
|
|
113
|
+
elif value_types == {int}:
|
|
114
|
+
schema["type"] = "integer"
|
|
115
|
+
elif value_types <= {int, float}:
|
|
116
|
+
schema["type"] = "number"
|
|
117
|
+
elif value_types == {bool}:
|
|
118
|
+
schema["type"] = "boolean"
|
|
119
|
+
return schema
|
|
120
|
+
|
|
121
|
+
if _is_typeddict(python_type):
|
|
122
|
+
annotations: dict[str, Any] = getattr(python_type, "__annotations__", {}) or {}
|
|
123
|
+
required_keys = set(getattr(python_type, "__required_keys__", set()))
|
|
124
|
+
td_properties: dict[str, Any] = {}
|
|
125
|
+
td_required: list[str] = []
|
|
126
|
+
for field_name, field_type in annotations.items():
|
|
127
|
+
td_properties[field_name] = _python_type_to_json_schema(field_type)
|
|
128
|
+
if field_name in required_keys:
|
|
129
|
+
td_required.append(field_name)
|
|
130
|
+
schema_td: dict[str, Any] = {
|
|
131
|
+
"type": "object",
|
|
132
|
+
"properties": td_properties,
|
|
133
|
+
}
|
|
134
|
+
if td_required:
|
|
135
|
+
schema_td["required"] = td_required
|
|
136
|
+
return schema_td
|
|
137
|
+
|
|
138
|
+
if inspect.isclass(python_type) and dataclasses.is_dataclass(python_type):
|
|
139
|
+
type_hints = get_type_hints(python_type)
|
|
140
|
+
dc_properties: dict[str, Any] = {}
|
|
141
|
+
dc_required: list[str] = []
|
|
142
|
+
for field in dataclasses.fields(python_type):
|
|
143
|
+
field_type = type_hints.get(field.name, Any)
|
|
144
|
+
dc_properties[field.name] = _python_type_to_json_schema(field_type)
|
|
145
|
+
if (
|
|
146
|
+
field.default is dataclasses.MISSING
|
|
147
|
+
and getattr(field, "default_factory", dataclasses.MISSING) is dataclasses.MISSING
|
|
148
|
+
):
|
|
149
|
+
dc_required.append(field.name)
|
|
150
|
+
schema_dc: dict[str, Any] = {"type": "object", "properties": dc_properties}
|
|
151
|
+
if dc_required:
|
|
152
|
+
schema_dc["required"] = dc_required
|
|
153
|
+
return schema_dc
|
|
154
|
+
|
|
155
|
+
if inspect.isclass(python_type) and issubclass(python_type, PydanticBaseModel):
|
|
156
|
+
model_type_hints = get_type_hints(python_type)
|
|
157
|
+
pd_properties: dict[str, Any] = {}
|
|
158
|
+
pd_required: list[str] = []
|
|
159
|
+
model_fields = getattr(python_type, "model_fields", {})
|
|
160
|
+
for name, field_info in model_fields.items():
|
|
161
|
+
pd_properties[name] = _python_type_to_json_schema(model_type_hints.get(name, Any))
|
|
162
|
+
is_required = getattr(field_info, "is_required", None)
|
|
163
|
+
if callable(is_required) and is_required():
|
|
164
|
+
pd_required.append(name)
|
|
165
|
+
schema_pd: dict[str, Any] = {"type": "object", "properties": pd_properties}
|
|
166
|
+
if pd_required:
|
|
167
|
+
schema_pd["required"] = pd_required
|
|
168
|
+
return schema_pd
|
|
169
|
+
|
|
170
|
+
if origin in (list, Sequence, set, frozenset):
|
|
171
|
+
item_type = args[0] if args else Any
|
|
172
|
+
item_schema = _python_type_to_json_schema(item_type)
|
|
173
|
+
schema_arr: dict[str, Any] = {"type": "array", "items": item_schema or {"type": "string"}}
|
|
174
|
+
if origin in (set, frozenset):
|
|
175
|
+
schema_arr["uniqueItems"] = True
|
|
176
|
+
return schema_arr
|
|
177
|
+
if origin is tuple:
|
|
178
|
+
if not args:
|
|
179
|
+
return {"type": "array", "items": {"type": "string"}}
|
|
180
|
+
if len(args) == 2 and args[1] is Ellipsis:
|
|
181
|
+
return {"type": "array", "items": _python_type_to_json_schema(args[0])}
|
|
182
|
+
prefix_items = [_python_type_to_json_schema(a) for a in args]
|
|
183
|
+
return {
|
|
184
|
+
"type": "array",
|
|
185
|
+
"prefixItems": prefix_items,
|
|
186
|
+
"minItems": len(prefix_items),
|
|
187
|
+
"maxItems": len(prefix_items),
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
if origin in (dict, Mapping):
|
|
191
|
+
value_type = args[1] if len(args) >= 2 else Any
|
|
192
|
+
value_schema = _python_type_to_json_schema(value_type)
|
|
193
|
+
return {"type": "object", "additionalProperties": value_schema or {"type": "string"}}
|
|
194
|
+
|
|
195
|
+
typing_union = getattr(__import__("typing"), "Union", None)
|
|
196
|
+
if origin in (typing_union, _types.UnionType):
|
|
197
|
+
non_none_args = [a for a in args if a is not type(None)]
|
|
198
|
+
if len(non_none_args) > 1:
|
|
199
|
+
schemas = [_python_type_to_json_schema(arg) for arg in non_none_args]
|
|
200
|
+
return {"oneOf": schemas}
|
|
201
|
+
if non_none_args:
|
|
202
|
+
return _python_type_to_json_schema(non_none_args[0])
|
|
203
|
+
return {"type": "string"}
|
|
204
|
+
|
|
205
|
+
return {"type": "string"}
|
|
206
|
+
|
|
207
|
+
def _parse_callable_properties(func: ToolFn) -> tuple[dict[str, dict[str, Any]], list[str]]:
    """Build JSON Schema `properties` and the `required` list from a callable's signature.

    `*args` / `**kwargs` parameters are skipped. A parameter without a type hint is
    treated as `str`. A parameter is required when it has no default value.

    Returns:
        `(properties, required)`: a mapping of parameter name to schema (including a
        generated description), and the names of the required parameters in
        signature order.
    """
    sig = inspect.signature(func)
    type_hints = get_type_hints(func)

    properties: dict[str, dict[str, Any]] = {}
    required: list[str] = []

    for param_name, param in sig.parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        annotated_type = type_hints.get(param_name, str)
        param_schema = _python_type_to_json_schema(annotated_type)

        type_name = getattr(annotated_type, "__name__", str(annotated_type))
        properties[param_name] = {
            **param_schema,
            "description": f"Parameter {param_name} of type {type_name}",
        }

        # Identity check against the sentinel. `==` would call the default
        # value's __eq__, which can raise or return a non-bool (e.g. array-likes).
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    return properties, required
|
|
231
|
+
|
|
232
|
+
def generate_tool_definition_from_callable(func: ToolFn) -> dict[str, Any]:
    """Convert a Python callable to OpenAI tools format.

    The callable's `__name__` becomes the tool name. Its docstring, with common
    indentation removed, becomes the description. Its signature and type hints
    become the JSON Schema of the parameters.

    Args:
        func: A Python callable (function) to convert to a tool

    Returns:
        Dictionary in OpenAI tools format

    Raises:
        ValueError: If the function has no docstring, or its docstring is blank.

    Example:
        >>> def get_weather(location: str, unit: str = "celsius") -> str:
        ...     '''Get weather information for a location.'''
        ...     return f"Weather in {location} is sunny, 25°{unit[0].upper()}"
        >>>
        >>> tool = generate_tool_definition_from_callable(get_weather)
        >>> # Returns OpenAI tools format dict

    """
    doc = func.__doc__
    # cleandoc strips the indentation of continuation lines, which would otherwise
    # leak into the description sent to the model. The final strip() also removes
    # whitespace-only trailing lines.
    description = inspect.cleandoc(doc).strip() if doc else ""
    if not description:
        msg = f"Function {func.__name__} must have a docstring"
        raise ValueError(msg)

    properties, required = _parse_callable_properties(func)
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": description,
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }
|
|
266
|
+
|
|
267
|
+
def generate_tool_definition_from_tool_def(tool_def: ToolDef) -> dict[str, Any]:
    """Convert a ToolDef to OpenAI tools format.

    The name and description are taken from the ToolDef itself (no docstring is
    required). The parameter schema is derived from `tool_def.execute`.

    Args:
        tool_def: A ToolDef to convert to a tool

    Returns:
        Dictionary in OpenAI tools format

    Example:
        >>> tool_def = ToolDef(
        ...     name="get_weather",
        ...     description="Get weather information for a location.",
        ...     execute=SomeFunction(),
        ... )
        >>> tool = generate_tool_definition_from_tool_def(tool_def)
        >>> # Returns OpenAI tools format dict
    """
    properties, required = _parse_callable_properties(tool_def.execute)
    parameters = {"type": "object", "properties": properties, "required": required}
    function_spec = {
        "name": tool_def.name,
        "description": tool_def.description,
        "parameters": parameters,
    }
    return {"type": "function", "function": function_spec}
|
|
294
|
+
|
|
295
|
+
def generate_tool_definition_from_raw_tool_def(raw_tool_def: RawToolDef) -> dict[str, Any]:
    """Wrap a raw function-definition dict in the OpenAI tools envelope.

    The dict is embedded as-is (the same object, not a copy).
    """
    return dict(type="function", function=raw_tool_def)
|
|
300
|
+
|
|
301
|
+
def prepare_tools(tools: list[ToolLike]) -> list[dict]:
    """Convert user-supplied tools into OpenAI ``tools`` entries, in order.

    Dispatch order matters: anything callable is converted as a plain
    function first, then :class:`ToolDef` instances, and every remaining
    item is handed to the raw-definition converter.
    """
    def _convert(tool):
        # One tool -> one OpenAI tools-format dict.
        if callable(tool):
            return generate_tool_definition_from_callable(tool)
        if isinstance(tool, ToolDef):
            return generate_tool_definition_from_tool_def(tool)
        return generate_tool_definition_from_raw_tool_def(tool)

    return [_convert(tool) for tool in tools]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import asyncio
import inspect
import json
from functools import singledispatch
from typing import Any, Awaitable, Callable, cast
from types import FunctionType, MethodType, CoroutineType

from . import ToolDef
|
|
7
|
+
|
|
8
|
+
async def _coroutine_wrapper(awaitable: Awaitable[Any]) -> CoroutineType:
|
|
9
|
+
return await awaitable
|
|
10
|
+
|
|
11
|
+
def _arguments_normalizer(arguments: str | dict) -> dict:
|
|
12
|
+
if isinstance(arguments, str):
|
|
13
|
+
parsed = json.loads(arguments)
|
|
14
|
+
return cast(dict, parsed)
|
|
15
|
+
elif isinstance(arguments, dict):
|
|
16
|
+
return arguments
|
|
17
|
+
else:
|
|
18
|
+
raise ValueError(f"Invalid arguments type: {type(arguments)}")
|
|
19
|
+
|
|
20
|
+
def _result_normalizer(result: Any) -> str:
|
|
21
|
+
if isinstance(result, str):
|
|
22
|
+
return result
|
|
23
|
+
return json.dumps(result, ensure_ascii=False)
|
|
24
|
+
|
|
25
|
+
@singledispatch
|
|
26
|
+
def execute_tool_sync(tool, arguments: str | dict) -> str:
|
|
27
|
+
raise ValueError(f"Invalid tool type: {type(tool)}")
|
|
28
|
+
|
|
29
|
+
@execute_tool_sync.register(FunctionType)
@execute_tool_sync.register(MethodType)
def _(toolfn: Callable, arguments: str | dict) -> str:
    """Synchronously run a plain function or bound method as a tool.

    The function is called exactly once with the normalized arguments. If
    the call returns an awaitable (an ``async def`` function, or a plain
    wrapper/decorator that returns a coroutine, which
    ``iscoroutinefunction`` cannot detect), it is driven to completion with
    ``asyncio.run``. Note that ``asyncio.run`` raises ``RuntimeError`` when
    the current thread already has a running event loop; use
    ``execute_tool`` from async code instead.

    Returns:
        The tool's result rendered as text by ``_result_normalizer``.
    """
    kwargs = _arguments_normalizer(arguments)
    result = toolfn(**kwargs)
    # Inspect the result rather than the function so any awaitable is
    # awaited instead of leaking an un-awaited coroutine into json.dumps.
    if inspect.isawaitable(result):
        result = asyncio.run(_coroutine_wrapper(result))
    return _result_normalizer(result)
|
|
37
|
+
|
|
38
|
+
@execute_tool_sync.register(ToolDef)
def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Synchronously run ``tooldef.execute`` as a tool.

    ``tooldef.execute`` is called exactly once. If the call returns an
    awaitable, it is driven to completion with ``asyncio.run``. This covers
    ``async def`` functions and also callable objects whose ``__call__`` is
    ``async``, which ``iscoroutinefunction`` does not recognize. As with
    any ``asyncio.run`` call, this raises ``RuntimeError`` inside a running
    event loop.

    Returns:
        The tool's result rendered as text by ``_result_normalizer``.
    """
    kwargs = _arguments_normalizer(arguments)
    result = tooldef.execute(**kwargs)
    # Decide based on the returned value so every awaitable gets awaited.
    if inspect.isawaitable(result):
        result = asyncio.run(_coroutine_wrapper(result))
    return _result_normalizer(result)
|
|
45
|
+
|
|
46
|
+
@singledispatch
|
|
47
|
+
async def execute_tool(tool, arguments: str | dict) -> str:
|
|
48
|
+
raise ValueError(f"Invalid tool type: {type(tool)}")
|
|
49
|
+
|
|
50
|
+
@execute_tool.register(FunctionType)
@execute_tool.register(MethodType)
async def _(toolfn: Callable, arguments: str | dict) -> str:
    """Asynchronously run a plain function or bound method as a tool.

    The function is called exactly once. If the call returns an awaitable
    (an ``async def`` function, or a plain wrapper/decorator returning a
    coroutine), it is awaited. Synchronous tools run directly on the event
    loop's thread and block it while they execute.

    Returns:
        The tool's result rendered as text by ``_result_normalizer``.
    """
    kwargs = _arguments_normalizer(arguments)
    result = toolfn(**kwargs)
    # Await based on the returned value, not iscoroutinefunction(), which
    # misses wrappers that return coroutines.
    if inspect.isawaitable(result):
        result = await result
    return _result_normalizer(result)
|
|
58
|
+
|
|
59
|
+
@execute_tool.register(ToolDef)
async def _(tooldef: ToolDef, arguments: str | dict) -> str:
    """Asynchronously run ``tooldef.execute`` as a tool.

    ``tooldef.execute`` is called exactly once. If the call returns an
    awaitable, it is awaited. This covers ``async def`` functions and
    callable objects with an ``async`` ``__call__``, which
    ``iscoroutinefunction`` does not detect. Synchronous tools run directly
    on the event loop's thread.

    Returns:
        The tool's result rendered as text by ``_result_normalizer``.
    """
    kwargs = _arguments_normalizer(arguments)
    result = tooldef.execute(**kwargs)
    if inspect.isawaitable(result):
        result = await result
    return _result_normalizer(result)
|
liteai_sdk/tool/utils.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from . import ToolFn, ToolDef, ToolLike
|
|
2
|
+
|
|
3
|
+
def find_tool_by_name(tools: list[ToolLike], name: str) -> ToolLike | None:
    """Return the first tool in *tools* whose name equals *name*, else ``None``.

    A tool's name is its ``__name__`` when callable, its ``.name`` when it
    is a :class:`ToolDef`, and its ``"name"`` key when it is a raw dict.
    Checks are tried in that order; a failed check falls through to the next.
    """
    def _matches(tool) -> bool:
        return (
            (callable(tool) and tool.__name__ == name)
            or (isinstance(tool, ToolDef) and tool.name == name)
            or (isinstance(tool, dict) and tool.get("name") == name)
        )

    return next((tool for tool in tools if _matches(tool)), None)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import dataclasses
|
|
3
|
+
import queue
|
|
4
|
+
from typing import Any, Generator, Literal
|
|
5
|
+
from collections.abc import AsyncGenerator, Generator
|
|
6
|
+
from ..tool import ToolLike
|
|
7
|
+
from .message import ChatMessage, AssistantMessage, ToolMessage, MessageChunk
|
|
8
|
+
|
|
9
|
+
@dataclasses.dataclass
class LlmRequestParams:
    """Parameters describing one chat request made through the SDK.

    The README examples construct it with ``model``, ``messages`` and,
    optionally, ``tools``/``execute_tools``, then pass it to
    ``LLM.generate_text_sync``.
    """
    # Model identifier, e.g. "deepseek-v3.1" in the README examples.
    model: str
    # Conversation history sent with the request.
    messages: list[ChatMessage]
    # Tools offered to the model; assistant tool calls are matched against
    # this list by name when attaching tool definitions to tool messages.
    tools: list[ToolLike] | None = None
    # NOTE(review): presumably forwarded as the provider's tool_choice — confirm.
    tool_choice: Literal["auto", "required", "none"] = "auto"
    # When True, the README example shows the SDK running the requested
    # tools itself and returning tool messages with their results.
    execute_tools: bool = False

    # NOTE(review): the following are presumably passed through to the
    # LiteLLM completion call (timeout in seconds, sampling temperature,
    # token limit, extra HTTP headers) — confirm against the request builder.
    timeout_sec: float | None = None
    temperature: float | None = None
    max_tokens: int | None = None
    headers: dict[str, str] | None = None

    # NOTE(review): presumably extra keyword arguments merged into the
    # underlying LiteLLM call — confirm.
    extra_args: dict[str, Any] | None = None
|
|
23
|
+
|
|
24
|
+
# --- --- --- --- --- ---
|
|
25
|
+
|
|
26
|
+
# Non-streaming result: the assistant and tool messages produced for a request.
GenerateTextResponse = list[AssistantMessage | ToolMessage]
# Streaming results pair a generator of incremental MessageChunk deltas with a
# queue of complete messages. NOTE(review): the `None` member of the queue type
# presumably marks end-of-stream — confirm against the producer.
StreamTextResponseSync = tuple[
    Generator[MessageChunk],
    queue.Queue[AssistantMessage | ToolMessage | None]]
# Async counterpart of StreamTextResponseSync (async generator + asyncio.Queue).
StreamTextResponseAsync = tuple[
    AsyncGenerator[MessageChunk],
    asyncio.Queue[AssistantMessage | ToolMessage | None]]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from litellm.exceptions import (
|
|
2
|
+
AuthenticationError,
|
|
3
|
+
PermissionDeniedError,
|
|
4
|
+
RateLimitError,
|
|
5
|
+
ContextWindowExceededError,
|
|
6
|
+
BadRequestError,
|
|
7
|
+
InvalidRequestError,
|
|
8
|
+
InternalServerError,
|
|
9
|
+
ServiceUnavailableError,
|
|
10
|
+
ContentPolicyViolationError,
|
|
11
|
+
APIError,
|
|
12
|
+
Timeout,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Public API of this module: LiteLLM's exception classes imported above,
# re-exported so callers can import them from liteai_sdk.types.exceptions.
__all__ = [
    "AuthenticationError",
    "PermissionDeniedError",
    "RateLimitError",
    "ContextWindowExceededError",
    "BadRequestError",
    "InvalidRequestError",
    "InternalServerError",
    "ServiceUnavailableError",
    "ContentPolicyViolationError",
    "APIError",
    "Timeout",
]
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import dataclasses
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, cast
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator
|
|
8
|
+
from litellm.types.utils import Message as LiteLlmMessage,\
|
|
9
|
+
ModelResponseStream as LiteLlmModelResponseStream,\
|
|
10
|
+
ChatCompletionAudioResponse,\
|
|
11
|
+
ChatCompletionMessageToolCall,\
|
|
12
|
+
ChatCompletionDeltaToolCall
|
|
13
|
+
from litellm.types.llms.openai import (
|
|
14
|
+
AllMessageValues,
|
|
15
|
+
OpenAIMessageContent,
|
|
16
|
+
ChatCompletionAssistantToolCall,
|
|
17
|
+
ImageURLListItem as ChatCompletionImageURL,
|
|
18
|
+
|
|
19
|
+
ChatCompletionUserMessage,
|
|
20
|
+
ChatCompletionAssistantMessage,
|
|
21
|
+
ChatCompletionToolMessage,
|
|
22
|
+
ChatCompletionSystemMessage,
|
|
23
|
+
)
|
|
24
|
+
from ..tool import ToolLike
|
|
25
|
+
from ..tool.utils import find_tool_by_name
|
|
26
|
+
from ..logger import logger
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from . import LlmRequestParams
|
|
30
|
+
|
|
31
|
+
class ChatMessage(BaseModel, ABC):
    # Abstract base of every message type. Subclasses are pydantic models
    # whose attribute assignments are re-validated (validate_assignment) and
    # which may declare non-pydantic field types (arbitrary_types_allowed).
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    @abstractmethod
    def to_litellm_message(self) -> AllMessageValues:
        """Convert this message into the dict shape LiteLLM expects."""
|
|
39
|
+
|
|
40
|
+
class UserMessage(ChatMessage):
    # A turn authored by the end user; `content` is the user's input, e.g. a
    # plain string as in the README examples.
    content: OpenAIMessageContent
    role: Literal["user"] = "user"

    def to_litellm_message(self) -> ChatCompletionUserMessage:
        """Render this message as a LiteLLM user-message dict."""
        payload = ChatCompletionUserMessage(role=self.role, content=self.content)
        return payload
|
|
46
|
+
|
|
47
|
+
class ToolMessage(ChatMessage):
    """
    The outcome of one tool call requested by an assistant message.

    `result` holds the tool's output (non-string values are JSON-encoded on
    validation) and `error` holds a failure description; at least one of them
    must be set before the message can be converted for LiteLLM.

    The `tool_def` property refers to the tool targeted by the call. It is
    None until a tool is attached via `with_tool_def`. That happens when the
    target tool is not found among the request's tools, and also when no tool
    list was available to search (e.g. no request params were linked).
    """
    id: str
    name: str
    arguments: str
    result: str | None = None
    error: str | None = None
    role: Literal["tool"] = "tool"

    _tool_def: ToolLike | None = PrivateAttr(default=None)

    @field_validator("result", mode="before")
    @classmethod
    def validate_result(cls, v: Any) -> Any:
        """Keep None and strings as-is; JSON-encode any other result value."""
        if v is None:
            return v
        if isinstance(v, str):
            return v
        return json.dumps(v, ensure_ascii=False)

    @property
    def tool_def(self) -> ToolLike | None:
        """The tool this call targeted, if one has been attached."""
        return self._tool_def

    def with_tool_def(self, tool_def: ToolLike) -> "ToolMessage":
        """Attach *tool_def* in place and return ``self`` for chaining."""
        self._tool_def = tool_def
        return self

    def to_litellm_message(self) -> ChatCompletionToolMessage:
        """Render as a LiteLLM tool message.

        An error takes precedence and is sent as ``{"error": ...}`` JSON.

        Raises:
            ValueError: If neither ``result`` nor ``error`` is set.
        """
        if self.result is None and self.error is None:
            raise ValueError(f"ToolMessage({self.id}, {self.name}) is incomplete, "
                             "result and error cannot be both None")

        if self.error is not None:
            content = json.dumps({"error": self.error}, ensure_ascii=False)
        else:
            assert self.result is not None
            content = self.result

        return ChatCompletionToolMessage(
            role=self.role,
            content=content,
            tool_call_id=self.id)
|
|
90
|
+
|
|
91
|
+
ToolCallTuple = tuple[str, str, str]  # (tool call id, function name, arguments string)


class AssistantMessage(ChatMessage):
    # A model-produced turn: text and/or reasoning, tool-call requests, audio
    # and images. `_request_params_ref` optionally links back to the request
    # that produced it (set via with_request_params) so tool calls can be
    # matched against the tools that were offered.
    content: str | None = None
    reasoning_content: str | None = None
    tool_calls: list[ChatCompletionAssistantToolCall] | None = None
    audio: ChatCompletionAudioResponse | None = None
    images: list[ChatCompletionImageURL] | None = None
    role: Literal["assistant"] = "assistant"

    _request_params_ref: LlmRequestParams | None = PrivateAttr(default=None)

    @classmethod
    def from_litellm_message(cls, message: LiteLlmMessage) -> "AssistantMessage":
        """Build an AssistantMessage from a LiteLLM response message.

        Uses ``model_construct``, so the copied values are not validated.
        """
        tool_calls: list[ChatCompletionAssistantToolCall] | None = None
        if (message_tool_calls := message.get("tool_calls")) is not None:
            tool_calls = [ChatCompletionAssistantToolCall(
                id=tool_call.id,
                function={
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
                type="function",
            ) for tool_call in cast(list[ChatCompletionMessageToolCall], message_tool_calls)]

        return cls.model_construct(
            content=message.get("content"),
            reasoning_content=message.get("reasoning_content"),
            tool_calls=tool_calls,
            audio=message.get("audio"),
            images=message.get("images"),
        )

    def with_request_params(self, request_params: LlmRequestParams) -> "AssistantMessage":
        """Link the originating request params in place; returns ``self``."""
        self._request_params_ref = request_params
        return self

    def to_litellm_message(self) -> ChatCompletionAssistantMessage:
        """Render as a LiteLLM assistant message.

        Only role, content (``None`` becomes ``""``), reasoning_content and
        tool_calls are included; audio and images are not sent back.
        """
        return ChatCompletionAssistantMessage(role=self.role,
                                              content=self.content or "",
                                              reasoning_content=self.reasoning_content,
                                              tool_calls=self.tool_calls)

    def parse_tool_calls(self) -> list[ToolCallTuple] | None:
        """Flatten ``tool_calls`` into ``(id, name, arguments)`` tuples.

        Returns:
            ``None`` if there are no tool calls, or if any tool call lacks its
            id, function, function name or arguments (all-or-nothing);
            otherwise one tuple per tool call, in order.
        """
        if self.tool_calls is None:
            return None
        results: list[ToolCallTuple] = []
        for tool_call in self.tool_calls:
            call_id = tool_call.get("id")
            function = tool_call.get("function")
            # Check `function` before dereferencing it so a malformed tool
            # call yields None instead of an AttributeError.
            if call_id is None or function is None:
                return None
            function_name = function.get("name")
            function_arguments = function.get("arguments")
            if function_name is None or function_arguments is None:
                return None
            results.append((call_id, function_name, function_arguments))
        return results

    def get_partial_tool_messages(self) -> list[ToolMessage] | None:
        """
        Get partial tool messages from the assistant message.

        Each returned tool message is incomplete: it only contains the tool
        call id, name and arguments (result and error are None). When the
        linked request params carry a tool list, the matching tool is
        attached via ``with_tool_def``; unmatched names are logged.
        Returns None if there is no (well-formed) tool call in the message.
        """
        request_params = self._request_params_ref
        available_tools = request_params.tools if request_params is not None else None
        if available_tools is None:
            logger.warning("AssistantMessage.get_partial_tool_messages() called without request params. "
                           "Call with_request_params() first to enable auto tool_def attachment feature.")

        parsed_tool_calls = self.parse_tool_calls()
        if parsed_tool_calls is None:
            return None

        results = []
        for call_id, name, arguments in parsed_tool_calls:
            tool_message = ToolMessage(
                id=call_id,
                name=name,
                arguments=arguments,
                result=None,
                error=None)

            # Compare with None rather than truthiness: an empty tool list is
            # valid and must simply yield "not found", not an AssertionError.
            if available_tools is not None:
                target_tool = find_tool_by_name(available_tools, name)
                if target_tool is not None:
                    tool_message = tool_message.with_tool_def(target_tool)
                else:
                    logger.warning(f"Tool {name} not found in request params, "
                                   "tool_def will not be attached to the tool message")

            results.append(tool_message)
        return results
|
|
187
|
+
|
|
188
|
+
class SystemMessage(ChatMessage):
    # System/instruction prompt, rendered as a LiteLLM system message.
    content: str
    role: Literal["system"] = "system"

    def to_litellm_message(self) -> ChatCompletionSystemMessage:
        """Render this message as a LiteLLM system-message dict."""
        rendered = ChatCompletionSystemMessage(role=self.role, content=self.content)
        return rendered
|
|
194
|
+
|
|
195
|
+
@dataclasses.dataclass
class TextChunk:
    """Streamed fragment of the assistant's text (from a delta's `content`)."""
    content: str

@dataclasses.dataclass
class ReasoningChunk:
    """Streamed fragment of reasoning text (from a delta's `reasoning_content`)."""
    content: str

@dataclasses.dataclass
class AudioChunk:
    """Streamed audio payload (from a delta's `audio`)."""
    data: ChatCompletionAudioResponse

@dataclasses.dataclass
class ImageChunk:
    """Streamed images (from a delta's `images`)."""
    data: list[ChatCompletionImageURL]

@dataclasses.dataclass
class ToolCallChunk:
    """Streamed fragment of one tool call.

    `id` and `name` may be None (presumably on continuation fragments that
    only carry more `arguments` text — confirm); `index` is the tool call's
    index as reported by the delta, presumably identifying which call the
    fragment belongs to.
    """
    id: str | None
    name: str | None
    arguments: str
    index: int

# Union of every chunk type produced by openai_chunk_normalizer.
MessageChunk = TextChunk | ReasoningChunk | AudioChunk | ImageChunk | ToolCallChunk
|
|
219
|
+
|
|
220
|
+
def openai_chunk_normalizer(
    chunk: LiteLlmModelResponseStream
) -> list[MessageChunk]:
    """Split one streamed LiteLLM delta into SDK message chunks.

    Only the first choice is inspected. Each non-empty delta field becomes
    chunk(s) in this order: text, reasoning, audio, images, then one
    ToolCallChunk per tool-call delta. A chunk without choices yields [].
    """
    if len(chunk.choices) == 0:
        return []

    delta = chunk.choices[0].delta
    chunks: list[MessageChunk] = []

    if delta.get("content"):
        chunks.append(TextChunk(content=cast(str, delta.content)))
    if delta.get("reasoning_content"):
        chunks.append(ReasoningChunk(content=cast(str, delta.reasoning_content)))
    if delta.get("audio"):
        chunks.append(AudioChunk(data=cast(ChatCompletionAudioResponse, delta.audio)))
    if delta.get("images"):
        chunks.append(ImageChunk(data=cast(list[ChatCompletionImageURL], delta.images)))
    if delta.get("tool_calls"):
        chunks.extend(
            ToolCallChunk(
                id=call.id,
                name=call.function.name,
                arguments=call.function.arguments,
                index=call.index,
            )
            for call in cast(list[ChatCompletionDeltaToolCall], delta.tool_calls)
        )
    return chunks
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: liteai_sdk
|
|
3
|
+
Version: 0.3.21
|
|
4
|
+
Summary: A wrapper of LiteLLM
|
|
5
|
+
Author-email: BHznJNs <bhznjns@outlook.com>
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: litellm>=1.80.0
|
|
17
|
+
Requires-Dist: pydantic>=2.0.0
|
|
18
|
+
Requires-Dist: python-dotenv>=1.2.1 ; extra == "dev"
|
|
19
|
+
Requires-Dist: pytest-cov ; extra == "test"
|
|
20
|
+
Requires-Dist: pytest-mock ; extra == "test"
|
|
21
|
+
Requires-Dist: pytest-runner ; extra == "test"
|
|
22
|
+
Requires-Dist: pytest ; extra == "test"
|
|
23
|
+
Requires-Dist: pytest-github-actions-annotate-failures ; extra == "test"
|
|
24
|
+
Project-URL: Source, https://github.com/BHznJNs/liteai
|
|
25
|
+
Project-URL: Tracker, https://github.com/BHznJNs/liteai/issues
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Provides-Extra: test
|
|
28
|
+
|
|
29
|
+
# LiteAI-SDK
|
|
30
|
+
|
|
31
|
+
LiteAI-SDK is a wrapper around LiteLLM that provides a more intuitive API and an [AI SDK](https://github.com/vercel/ai)-like developer experience (DX).
|
|
32
|
+
|
|
33
|
+
## Installation
|
|
34
|
+
|
|
35
|
+
```
|
|
36
|
+
pip install liteai-sdk
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
### Develop with coding agent
|
|
40
|
+
|
|
41
|
+
The complete usage guide is available as [llms.txt](https://raw.githubusercontent.com/BHznJNs/liteai/refs/heads/main/llms.txt); give it to your coding agent so it knows how to use LiteAI-SDK.
|
|
42
|
+
|
|
43
|
+
## Examples
|
|
44
|
+
|
|
45
|
+
Below is a simple example of a single API call:
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
import os
|
|
49
|
+
from dotenv import load_dotenv
|
|
50
|
+
from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
51
|
+
|
|
52
|
+
load_dotenv()
|
|
53
|
+
|
|
54
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
55
|
+
api_key=os.getenv("API_KEY", ""),
|
|
56
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
57
|
+
|
|
58
|
+
response = llm.generate_text_sync( # sync API of generate_text
|
|
59
|
+
LlmRequestParams(
|
|
60
|
+
model="deepseek-v3.1",
|
|
61
|
+
messages=[UserMessage(content="Hello.")]))
|
|
62
|
+
print(response)
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
Below is an example that demonstrates automatic tool calling:
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import os
|
|
69
|
+
from dotenv import load_dotenv
|
|
70
|
+
from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
|
|
71
|
+
|
|
72
|
+
load_dotenv()
|
|
73
|
+
|
|
74
|
+
def example_tool():
|
|
75
|
+
"""
|
|
76
|
+
This is a test tool that is used to test the tool calling functionality.
|
|
77
|
+
"""
|
|
78
|
+
print("The example tool is called.")
|
|
79
|
+
return "Hello World"
|
|
80
|
+
|
|
81
|
+
llm = LLM(provider=LlmProviders.OPENAI,
|
|
82
|
+
api_key=os.getenv("API_KEY", ""),
|
|
83
|
+
base_url=os.getenv("BASE_URL", ""))
|
|
84
|
+
|
|
85
|
+
params = LlmRequestParams(
|
|
86
|
+
model="deepseek-v3.1",
|
|
87
|
+
tools=[example_tool],
|
|
88
|
+
execute_tools=True,
|
|
89
|
+
messages=[UserMessage(content="Please call the tool example_tool.")])
|
|
90
|
+
|
|
91
|
+
print("User: ", "Please call the tool example_tool.")
|
|
92
|
+
messages = llm.generate_text_sync(params)
|
|
93
|
+
for message in messages:
|
|
94
|
+
match message.role:
|
|
95
|
+
case "assistant":
|
|
96
|
+
print("Assistant: ", message.content)
|
|
97
|
+
case "tool":
|
|
98
|
+
print("Tool: ", message.result)
|
|
99
|
+
```
|
|
100
|
+
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
liteai_sdk/__init__.py,sha256=KTHDeLyGGtVA-H8nJhJMBr_rKFV1h5cD5voZ8oPXI00,10608
|
|
2
|
+
liteai_sdk/debug.py,sha256=T7qIy1BeeUGlF40l9JCMMVn8pvvMJAEQeG4adQbOydA,69
|
|
3
|
+
liteai_sdk/logger.py,sha256=99vJAQRKcu4CuHgZYAJ2zDQtGea6Bn3vJJrS-mtza7c,677
|
|
4
|
+
liteai_sdk/param_parser.py,sha256=xykvUesZzwZNf4-n1j4JfVk0L2y_wvnSWSsHo5vjBU8,1655
|
|
5
|
+
liteai_sdk/stream.py,sha256=T9MLmgPC8te6qvSkBOh7vkl-I4OGCKuW1kEN6RkiCe0,3176
|
|
6
|
+
liteai_sdk/tool/__init__.py,sha256=c1qJaEpoYlgOCtAjFODhrSR73ZW17OuamsO__yeYAkY,12150
|
|
7
|
+
liteai_sdk/tool/execute.py,sha256=1CfRlJZgqoev42fDH4vygXyEtCEEBPcRfbqaP77jxu4,2462
|
|
8
|
+
liteai_sdk/tool/utils.py,sha256=Djd1-EoLPfIqgPbWWvOreozQ76NHX4FZ6OXc1evKqPM,409
|
|
9
|
+
liteai_sdk/types/__init__.py,sha256=CMmweIGMgreZlbvBtRTKfvdcC7war2ApLNf-9Fz0yzc,1006
|
|
10
|
+
liteai_sdk/types/exceptions.py,sha256=hIGu06htOJxfEBAHx7KTvLQr0Y8GYnBLFJFlr_IGpDs,602
|
|
11
|
+
liteai_sdk/types/message.py,sha256=AnhJ5wKKcWuAt0lW3mPXpIyvUBy3u-iFLa1dpeUTp18,8785
|
|
12
|
+
liteai_sdk-0.3.21.dist-info/licenses/LICENSE,sha256=cTeVgQVJJcRdm1boa2P1FBnOeXfA_egV6s4PouyrCxg,1064
|
|
13
|
+
liteai_sdk-0.3.21.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
|
|
14
|
+
liteai_sdk-0.3.21.dist-info/METADATA,sha256=uUYWHL4MKkSTsqokLBEN2EcJd60HkinQthNXRRTabzU,3024
|
|
15
|
+
liteai_sdk-0.3.21.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 BHznJNs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|