liteai-sdk 0.3.2__tar.gz

@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 BHznJNs
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: liteai_sdk
+ Version: 0.3.2
+ Summary: A wrapper of LiteLLM
+ Author-email: BHznJNs <bhznjns@outlook.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ License-File: LICENSE
+ Requires-Dist: litellm>=1.80.0
+ Requires-Dist: pydantic>=2.0.0
+ Requires-Dist: python-dotenv>=1.2.1 ; extra == "dev"
+ Requires-Dist: pytest-cov ; extra == "test"
+ Requires-Dist: pytest-mock ; extra == "test"
+ Requires-Dist: pytest-runner ; extra == "test"
+ Requires-Dist: pytest ; extra == "test"
+ Requires-Dist: pytest-github-actions-annotate-failures ; extra == "test"
+ Project-URL: Source, https://github.com/BHznJNs/liteai
+ Project-URL: Tracker, https://github.com/BHznJNs/liteai/issues
+ Provides-Extra: dev
+ Provides-Extra: test
+
+ # LiteAI-SDK
+
+ LiteAI-SDK is a wrapper around LiteLLM that provides a more intuitive API and an [AI SDK](https://github.com/vercel/ai)-like DX.
+
+ ## Installation
+
+ ```
+ pip install liteai-sdk
+ ```
+
+ ### Developing with a coding agent
+
+ The complete usage guide is available as [llms.txt](https://raw.githubusercontent.com/BHznJNs/liteai/refs/heads/main/llms.txt); give it to your coding agent to teach it how to use LiteAI-SDK.
+
+ ## Examples
+
+ Below is a simple example of a single API call:
+
+ ```python
+ import os
+ from dotenv import load_dotenv
+ from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
+
+ load_dotenv()
+
+ llm = LLM(provider=LlmProviders.OPENAI,
+           api_key=os.getenv("API_KEY", ""),
+           base_url=os.getenv("BASE_URL", ""))
+
+ response = llm.generate_text_sync(  # sync variant of generate_text
+     LlmRequestParams(
+         model="deepseek-v3.1",
+         messages=[UserMessage("Hello.")]))
+ print(response)
+ ```
+
+ Below is an example that shows automatic tool calling:
+
+ ```python
+ import os
+ from dotenv import load_dotenv
+ from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
+
+ load_dotenv()
+
+ def example_tool():
+     """
+     This is a test tool that is used to test the tool-calling functionality.
+     """
+     print("The example tool is called.")
+     return "Hello World"
+
+ llm = LLM(provider=LlmProviders.OPENAI,
+           api_key=os.getenv("API_KEY", ""),
+           base_url=os.getenv("BASE_URL", ""))
+
+ params = LlmRequestParams(
+     model="deepseek-v3.1",
+     tools=[example_tool],
+     execute_tools=True,
+     messages=[UserMessage("Please call the tool example_tool.")])
+
+ print("User: ", "Please call the tool example_tool.")
+ messages = llm.generate_text_sync(params)
+ for message in messages:
+     match message.role:
+         case "assistant":
+             print("Assistant: ", message.content)
+         case "tool":
+             print("Tool: ", message.result)
+ ```
+
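The class source further down this diff also defines streaming variants, `stream_text_sync` and `stream_text`. A minimal sketch of the sync one, continuing from the tool-calling example above (`llm` and `params` reused); this is an illustration written against the source below, not an official example:

```python
# stream_text_sync returns a generator of AssistantMessageChunk plus a queue
# that, once the stream is drained, holds the collected full messages
# (the assistant message, then any executed tool messages), terminated by None.
chunks, full_messages = llm.stream_text_sync(params)
for chunk in chunks:
    if chunk.content:
        print(chunk.content, end="", flush=True)
print()
while (message := full_messages.get()) is not None:
    print(f"[{message.role} message collected]")
```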
@@ -0,0 +1,71 @@
+ # LiteAI-SDK
+
+ LiteAI-SDK is a wrapper around LiteLLM that provides a more intuitive API and an [AI SDK](https://github.com/vercel/ai)-like DX.
+
+ ## Installation
+
+ ```
+ pip install liteai-sdk
+ ```
+
+ ### Developing with a coding agent
+
+ The complete usage guide is available as [llms.txt](https://raw.githubusercontent.com/BHznJNs/liteai/refs/heads/main/llms.txt); give it to your coding agent to teach it how to use LiteAI-SDK.
+
+ ## Examples
+
+ Below is a simple example of a single API call:
+
+ ```python
+ import os
+ from dotenv import load_dotenv
+ from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
+
+ load_dotenv()
+
+ llm = LLM(provider=LlmProviders.OPENAI,
+           api_key=os.getenv("API_KEY", ""),
+           base_url=os.getenv("BASE_URL", ""))
+
+ response = llm.generate_text_sync(  # sync variant of generate_text
+     LlmRequestParams(
+         model="deepseek-v3.1",
+         messages=[UserMessage("Hello.")]))
+ print(response)
+ ```
+
+ Below is an example that shows automatic tool calling:
+
+ ```python
+ import os
+ from dotenv import load_dotenv
+ from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage
+
+ load_dotenv()
+
+ def example_tool():
+     """
+     This is a test tool that is used to test the tool-calling functionality.
+     """
+     print("The example tool is called.")
+     return "Hello World"
+
+ llm = LLM(provider=LlmProviders.OPENAI,
+           api_key=os.getenv("API_KEY", ""),
+           base_url=os.getenv("BASE_URL", ""))
+
+ params = LlmRequestParams(
+     model="deepseek-v3.1",
+     tools=[example_tool],
+     execute_tools=True,
+     messages=[UserMessage("Please call the tool example_tool.")])
+
+ print("User: ", "Please call the tool example_tool.")
+ messages = llm.generate_text_sync(params)
+ for message in messages:
+     match message.role:
+         case "assistant":
+             print("Assistant: ", message.content)
+         case "tool":
+             print("Tool: ", message.result)
+ ```
@@ -0,0 +1,231 @@
+ [build-system]
+ requires = ["flit_core >=2,<4"]
+ build-backend = "flit_core.buildapi"
+
+ [project]
+ name = "liteai_sdk"
+ authors = [{ name = "BHznJNs", email = "bhznjns@outlook.com" }]
+ description = "A wrapper of LiteLLM"
+ readme = "README.md"
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3 :: Only",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+ ]
+ requires-python = ">=3.10"
+ version = "0.3.2"
+
+ dependencies = [
+     "litellm>=1.80.0",
+     "pydantic>=2.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "python-dotenv>=1.2.1"
+ ]
+ test = [
+     "pytest-cov",
+     "pytest-mock",
+     "pytest-runner",
+     "pytest",
+     "pytest-github-actions-annotate-failures",
+ ]
+
+ [project.urls]
+ Source = "https://github.com/BHznJNs/liteai"
+ Tracker = "https://github.com/BHznJNs/liteai/issues"
+
+ [tool.flit.module]
+ name = "liteai_sdk"
+
+ [tool.bandit]
+ exclude_dirs = ["build", "dist", "tests", "scripts"]
+ number = 4
+ recursive = true
+ targets = "src"
+
+ [tool.black]
+ line-length = 120
+ fast = true
+
+ [tool.coverage.run]
+ branch = true
+
+ [tool.coverage.report]
+ fail_under = 100
+
+ [tool.flake8]
+ max-line-length = 120
+ select = "F,E,W,B,B901,B902,B903"
+ exclude = [
+     ".eggs",
+     ".git",
+     ".tox",
+     "nssm",
+     "obj",
+     "out",
+     "packages",
+     "pywin32",
+     "tests",
+     "swagger_client",
+ ]
+ ignore = ["E722", "B001", "W503", "E203"]
+
+ [tool.pyright]
+ include = ["src"]
+ exclude = ["**/node_modules", "**/__pycache__"]
+ venv = "env37"
+
+ reportMissingImports = true
+ reportMissingTypeStubs = false
+
+ pythonVersion = "3.10"
+ pythonPlatform = "Linux"
+
+ executionEnvironments = [{ root = "src" }]
+
+ [tool.pytest.ini_options]
+ addopts = "--cov-report xml:coverage.xml --cov src --cov-fail-under 0 --cov-append -m 'not integration'"
+ pythonpath = ["src"]
+ testpaths = "tests"
+ junit_family = "xunit2"
+ markers = [
+     "integration: marks as integration test",
+     "notebooks: marks as notebook test",
+     "gpu: marks as gpu test",
+     "spark: marks tests which need Spark",
+     "slow: marks tests as slow",
+     "unit: fast offline tests",
+ ]
+
+ [tool.pylint]
+ extension-pkg-whitelist = [
+     "numpy",
+     "torch",
+     "cv2",
+     "pyodbc",
+     "pydantic",
+     "ciso8601",
+     "netcdf4",
+     "scipy",
+ ]
+ ignore = "CVS"
+ ignore-patterns = "test.*?py,conftest.py"
+ init-hook = 'import sys; sys.setrecursionlimit(8 * sys.getrecursionlimit())'
+ jobs = 0
+ limit-inference-results = 100
+ persistent = "yes"
+ suggestion-mode = "yes"
+ unsafe-load-any-extension = "no"
+
+ [tool.pylint.'MESSAGES CONTROL']
+ enable = "c-extension-no-member"
+
+ [tool.pylint.'REPORTS']
+ evaluation = "10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)"
+ output-format = "text"
+ reports = "no"
+ score = "yes"
+
+ [tool.pylint.'REFACTORING']
+ max-nested-blocks = 5
+ never-returning-functions = "sys.exit"
+
+ [tool.pylint.'BASIC']
+ argument-naming-style = "snake_case"
+ attr-naming-style = "snake_case"
+ bad-names = ["foo", "bar"]
+ class-attribute-naming-style = "any"
+ class-naming-style = "PascalCase"
+ const-naming-style = "UPPER_CASE"
+ docstring-min-length = -1
+ function-naming-style = "snake_case"
+ good-names = ["i", "j", "k", "ex", "Run", "_"]
+ include-naming-hint = "yes"
+ inlinevar-naming-style = "any"
+ method-naming-style = "snake_case"
+ module-naming-style = "any"
+ no-docstring-rgx = "^_"
+ property-classes = "abc.abstractproperty"
+ variable-naming-style = "snake_case"
+
+ [tool.pylint.'FORMAT']
+ ignore-long-lines = "^\\s*(# )?.*['\"]?<?https?://\\S+>?"
+ indent-after-paren = 4
+ indent-string = '    '
+ max-line-length = 120
+ max-module-lines = 1000
+ single-line-class-stmt = "no"
+ single-line-if-stmt = "no"
+
+ [tool.pylint.'LOGGING']
+ logging-format-style = "old"
+ logging-modules = "logging"
+
+ [tool.pylint.'MISCELLANEOUS']
+ notes = ["FIXME", "XXX", "TODO"]
+
+ [tool.pylint.'SIMILARITIES']
+ ignore-comments = "yes"
+ ignore-docstrings = "yes"
+ ignore-imports = "yes"
+ min-similarity-lines = 7
+
+ [tool.pylint.'SPELLING']
+ max-spelling-suggestions = 4
+ spelling-store-unknown-words = "no"
+
+ [tool.pylint.'STRING']
+ check-str-concat-over-line-jumps = "no"
+
+ [tool.pylint.'TYPECHECK']
+ contextmanager-decorators = "contextlib.contextmanager"
+ generated-members = "numpy.*,np.*,pyspark.sql.functions,collect_list"
+ ignore-mixin-members = "yes"
+ ignore-none = "yes"
+ ignore-on-opaque-inference = "yes"
+ ignored-classes = "optparse.Values,thread._local,_thread._local,numpy,torch,swagger_client"
+ ignored-modules = "numpy,torch,swagger_client,netCDF4,scipy"
+ missing-member-hint = "yes"
+ missing-member-hint-distance = 1
+ missing-member-max-choices = 1
+
+ [tool.pylint.'VARIABLES']
+ additional-builtins = "dbutils"
+ allow-global-unused-variables = "yes"
+ callbacks = ["cb_", "_cb"]
+ dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_"
+ ignored-argument-names = "_.*|^ignored_|^unused_"
+ init-import = "no"
+ redefining-builtins-modules = "six.moves,past.builtins,future.builtins,builtins,io"
+
+ [tool.pylint.'CLASSES']
+ defining-attr-methods = ["__init__", "__new__", "setUp", "__post_init__"]
+ exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make"]
+ valid-classmethod-first-arg = "cls"
+ valid-metaclass-classmethod-first-arg = "cls"
+
+ [tool.pylint.'DESIGN']
+ max-args = 5
+ max-attributes = 7
+ max-bool-expr = 5
+ max-branches = 12
+ max-locals = 15
+ max-parents = 7
+ max-public-methods = 20
+ max-returns = 6
+ max-statements = 50
+ min-public-methods = 2
+
+ [tool.pylint.'IMPORTS']
+ allow-wildcard-with-all = "no"
+ analyse-fallback-blocks = "no"
+ deprecated-modules = "optparse,tkinter.tix"
+
+ [tool.pylint.'EXCEPTIONS']
+ overgeneral-exceptions = ["BaseException", "Exception"]
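The `dev` and `test` extras declared under `[project.optional-dependencies]` above can be pulled in with pip's standard extras syntax:

```
pip install "liteai-sdk[dev,test]"
```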
@@ -0,0 +1,220 @@
+ import asyncio
+ import queue
+ from typing import cast
+ from collections.abc import AsyncGenerator, Generator
+ from litellm import ChatCompletionAssistantToolCall, CustomStreamWrapper, completion, acompletion
+ from litellm.utils import get_valid_models
+ from litellm.types.utils import (
+     LlmProviders,
+     ModelResponse as LiteLlmModelResponse,
+     ModelResponseStream as LiteLlmModelResponseStream,
+     Choices as LiteLlmModelResponseChoices,
+ )
+ from .stream import AssistantMessageCollector
+ from .tool import ToolFn, ToolDef, RawToolDef, prepare_tools
+ from .tool.execute import execute_tool_sync, execute_tool, parse_arguments
+ from .tool.utils import filter_executable_tools, find_tool_by_name
+ from .types import LlmRequestParams, GenerateTextResponse, StreamTextResponseSync, StreamTextResponseAsync
+ from .types.message import ChatMessage, AssistantMessageChunk, UserMessage, SystemMessage, AssistantMessage, ToolMessage
+
+ class LLM:
+     def __init__(self,
+                  provider: LlmProviders,
+                  base_url: str,
+                  api_key: str):
+         self.provider = provider
+         self.base_url = base_url
+         self.api_key = api_key
+
+     def _parse_params(self, params: LlmRequestParams, stream: bool):
+         tools = params.tools and prepare_tools(params.tools)
+         return {
+             "model": f"{self.provider.value}/{params.model}",
+             "messages": [message.to_litellm_message() for message in params.messages],
+             "base_url": self.base_url,
+             "api_key": self.api_key,
+             "tools": tools,
+             "tool_choice": params.tool_choice,
+             "stream": stream,
+             "timeout": params.timeout_sec,
+             "temperature": params.temperature,
+             "max_tokens": params.max_tokens,
+             "extra_headers": params.headers,
+             **(params.extra_args or {})
+         }
+
+     @staticmethod
+     def _should_resolve_tool_calls(
+         params: LlmRequestParams,
+         message: AssistantMessage,
+     ) -> tuple[list[ToolFn | ToolDef | RawToolDef],
+                list[ChatCompletionAssistantToolCall]] | None:
+         condition = (params.execute_tools
+                      and params.tools is not None
+                      and message.tool_calls is not None)
+         if condition:
+             assert params.tools is not None
+             assert message.tool_calls is not None
+             return params.tools, message.tool_calls
+         return None
+
+     @staticmethod
+     def _parse_tool_call(tool_call: ChatCompletionAssistantToolCall) -> tuple[str, str, str] | None:
+         id = tool_call.get("id")
+         function = tool_call.get("function")
+         if id is None or function is None:
+             return None
+         function_name = function.get("name")
+         function_arguments = function.get("arguments")
+         if function_name is None or function_arguments is None:
+             return None
+         return id, function_name, function_arguments
+
+     @staticmethod
+     async def _execute_tool_calls(
+         tools: list[ToolFn | ToolDef | RawToolDef],
+         tool_calls: list[ChatCompletionAssistantToolCall]
+     ) -> list[ToolMessage]:
+         executable_tools = filter_executable_tools(tools)
+         result = []
+         for tool_call in tool_calls:
+             if (tool_call_data := LLM._parse_tool_call(tool_call)) is None:
+                 continue
+             id, function_name, function_arguments = tool_call_data
+             if (target_tool := find_tool_by_name(cast(list, executable_tools), function_name)) is None:
+                 continue
+             parsed_arguments = parse_arguments(function_arguments)
+             ret = await execute_tool(target_tool, parsed_arguments)
+             result.append(ToolMessage(
+                 id=id,
+                 name=function_name,
+                 arguments=parsed_arguments,
+                 result=ret))
+         return result
+
+     @staticmethod
+     def _execute_tool_calls_sync(
+         tools: list[ToolFn | ToolDef | RawToolDef],
+         tool_calls: list[ChatCompletionAssistantToolCall]
+     ) -> list[ToolMessage]:
+         executable_tools = filter_executable_tools(tools)
+         result = []
+         for tool_call in tool_calls:
+             if (tool_call_data := LLM._parse_tool_call(tool_call)) is None:
+                 continue
+             id, function_name, function_arguments = tool_call_data
+             if (target_tool := find_tool_by_name(cast(list, executable_tools), function_name)) is None:
+                 continue
+             parsed_arguments = parse_arguments(function_arguments)
+             ret = execute_tool_sync(target_tool, parsed_arguments)
+             result.append(ToolMessage(
+                 id=id,
+                 name=function_name,
+                 arguments=parsed_arguments,
+                 result=ret))
+         return result
+
+     def list_models(self) -> list[str]:
+         return get_valid_models(
+             custom_llm_provider=self.provider.value,
+             check_provider_endpoint=True,
+             api_base=self.base_url,
+             api_key=self.api_key)
+
+     def generate_text_sync(self, params: LlmRequestParams) -> GenerateTextResponse:
+         response = completion(**self._parse_params(params, stream=False))
+         response = cast(LiteLlmModelResponse, response)
+         choices = cast(list[LiteLlmModelResponseChoices], response.choices)
+         message = choices[0].message
+         assistant_message = AssistantMessage.from_litellm_message(message)
+         result: GenerateTextResponse = [assistant_message]
+         if (tools_and_tool_calls := self._should_resolve_tool_calls(params, assistant_message)):
+             tools, tool_calls = tools_and_tool_calls
+             result += self._execute_tool_calls_sync(tools, tool_calls)
+         return result
+
+     async def generate_text(self, params: LlmRequestParams) -> GenerateTextResponse:
+         response = await acompletion(**self._parse_params(params, stream=False))
+         response = cast(LiteLlmModelResponse, response)
+         choices = cast(list[LiteLlmModelResponseChoices], response.choices)
+         message = choices[0].message
+         assistant_message = AssistantMessage.from_litellm_message(message)
+         result: GenerateTextResponse = [assistant_message]
+         if (tools_and_tool_calls := self._should_resolve_tool_calls(params, assistant_message)):
+             tools, tool_calls = tools_and_tool_calls
+             result += await self._execute_tool_calls(tools, tool_calls)
+         return result
+
+     def stream_text_sync(self, params: LlmRequestParams) -> StreamTextResponseSync:
+         def stream(response: CustomStreamWrapper) -> Generator[AssistantMessageChunk]:
+             for chunk in response:
+                 chunk = cast(LiteLlmModelResponseStream, chunk)
+                 yield AssistantMessageChunk.from_litellm_chunk(chunk)
+                 message_collector.collect(chunk)
+
+             message = message_collector.get_message()
+             full_message_queue.put(message)
+             if (tools_and_tool_calls := self._should_resolve_tool_calls(params, message)):
+                 tools, tool_calls = tools_and_tool_calls
+                 tool_messages = self._execute_tool_calls_sync(tools, tool_calls)
+                 for tool_message in tool_messages:
+                     full_message_queue.put(tool_message)
+             full_message_queue.put(None)
+
+         response = completion(**self._parse_params(params, stream=True))
+         message_collector = AssistantMessageCollector()
+         returned_stream = stream(cast(CustomStreamWrapper, response))
+         full_message_queue = queue.Queue[AssistantMessage | ToolMessage | None]()
+         return returned_stream, full_message_queue
+
+     async def stream_text(self, params: LlmRequestParams) -> StreamTextResponseAsync:
+         async def stream(response: CustomStreamWrapper) -> AsyncGenerator[AssistantMessageChunk]:
+             async for chunk in response:
+                 chunk = cast(LiteLlmModelResponseStream, chunk)
+                 yield AssistantMessageChunk.from_litellm_chunk(chunk)
+                 message_collector.collect(chunk)
+
+             message = message_collector.get_message()
+             await full_message_queue.put(message)
+             if (tools_and_tool_calls := self._should_resolve_tool_calls(params, message)):
+                 tools, tool_calls = tools_and_tool_calls
+                 tool_messages = await self._execute_tool_calls(tools, tool_calls)
+                 for tool_message in tool_messages:
+                     await full_message_queue.put(tool_message)
+             await full_message_queue.put(None)
+
+         response = await acompletion(**self._parse_params(params, stream=True))
+         message_collector = AssistantMessageCollector()
+         returned_stream = stream(cast(CustomStreamWrapper, response))
+         full_message_queue = asyncio.Queue[AssistantMessage | ToolMessage | None]()
+         return returned_stream, full_message_queue
+
+ __all__ = [
+     "LLM",
+     "LlmRequestParams",
+     "ToolFn",
+     "ToolDef",
+     "RawToolDef",
+
+     "ChatMessage",
+     "UserMessage",
+     "SystemMessage",
+     "AssistantMessage",
+     "ToolMessage",
+     "AssistantMessageChunk",
+
+     "GenerateTextResponse",
+     "StreamTextResponseSync",
+     "StreamTextResponseAsync",
+ ]
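A hedged sketch of driving the async half of the class above (`generate_text` mirrors `generate_text_sync` but awaits `acompletion` and the tool executions); the credentials and base URL below are placeholders:

```python
import asyncio
from liteai_sdk import LLM, LlmProviders, LlmRequestParams, UserMessage

async def main() -> None:
    llm = LLM(provider=LlmProviders.OPENAI,
              api_key="sk-placeholder",               # placeholder credential
              base_url="https://example.invalid/v1")  # placeholder endpoint
    messages = await llm.generate_text(LlmRequestParams(
        model="deepseek-v3.1",
        messages=[UserMessage("Hello.")]))
    for message in messages:
        print(message.role, getattr(message, "content", None))

asyncio.run(main())
```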
@@ -0,0 +1,79 @@
+ import dataclasses
+ from litellm import ChatCompletionAssistantToolCall
+ from litellm.types.utils import (
+     ChatCompletionDeltaToolCall,
+     ModelResponseStream as LiteLlmModelResponseStream,
+ )
+ from .types.message import AssistantMessage
+
+ @dataclasses.dataclass
+ class ToolCallTemp:
+     id: str | None = None
+     name: str = ""
+     arguments: str = ""
+
+ class ToolCallCollector:
+     def __init__(self):
+         self.tool_call_buf: list[ToolCallTemp] = []
+
+     def collect(self, tool_call_chunk: ChatCompletionDeltaToolCall):
+         # Chunks belonging to the same tool call share an index; grow the
+         # buffer until it covers this index, then merge into that slot.
+         while tool_call_chunk.index >= len(self.tool_call_buf):
+             self.tool_call_buf.append(ToolCallTemp())
+
+         temp_tool_call = self.tool_call_buf[tool_call_chunk.index]
+         if tool_call_chunk.get("id"):
+             temp_tool_call.id = tool_call_chunk.id
+         if tool_call_chunk.function.get("name"):
+             assert tool_call_chunk.function.name is not None
+             temp_tool_call.name += tool_call_chunk.function.name
+         if tool_call_chunk.function.get("arguments"):
+             assert tool_call_chunk.function.arguments is not None
+             temp_tool_call.arguments += tool_call_chunk.function.arguments
+
+     def get_tool_calls(self) -> list[ChatCompletionAssistantToolCall]:
+         return [{
+             "id": tool_call.id,
+             "function": {
+                 "name": tool_call.name,
+                 "arguments": tool_call.arguments,
+             },
+             "type": "function"
+         } for tool_call in self.tool_call_buf]
+
+ class AssistantMessageCollector:
+     def __init__(self):
+         self.message_buf = AssistantMessage(None)
+         self.tool_call_collector = ToolCallCollector()
+
+     def collect(self, chunk: LiteLlmModelResponseStream):
+         delta = chunk.choices[0].delta
+         if delta.get("content"):
+             assert delta.content is not None
+             if self.message_buf.content is None:
+                 self.message_buf.content = ""
+             self.message_buf.content += delta.content
+
+         if delta.get("reasoning_content"):
+             assert delta.reasoning_content is not None
+             if self.message_buf.reasoning_content is None:
+                 self.message_buf.reasoning_content = ""
+             self.message_buf.reasoning_content += delta.reasoning_content
+
+         if delta.get("tool_calls"):
+             assert delta.tool_calls is not None
+             for tool_call_chunk in delta.tool_calls:
+                 self.tool_call_collector.collect(tool_call_chunk)
+
+         if delta.get("images"):
+             assert delta.images is not None
+             if self.message_buf.images is None:
+                 self.message_buf.images = []
+             self.message_buf.images += delta.images
+
+         if delta.get("audio"):
+             assert delta.audio is not None
+             self.message_buf.audio = delta.audio
+
+     def get_message(self) -> AssistantMessage:
+         self.message_buf.tool_calls = self.tool_call_collector.get_tool_calls()
+         return self.message_buf
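To make the accumulation logic above concrete, a hedged illustration using hypothetical stand-in chunk objects (the real input is litellm's `ChatCompletionDeltaToolCall`, which has the same `index`/`id`/`function` shape and a dict-style `get`; the import path is inferred from the relative imports):

```python
from liteai_sdk.stream import ToolCallCollector

class _Fn:  # stand-in for the chunk's `function` field
    def __init__(self, name=None, arguments=None):
        self.name, self.arguments = name, arguments
    def get(self, key): return getattr(self, key, None)

class _Call:  # stand-in for ChatCompletionDeltaToolCall
    def __init__(self, index, id=None, name=None, arguments=None):
        self.index, self.id = index, id
        self.function = _Fn(name, arguments)
    def get(self, key): return getattr(self, key, None)

collector = ToolCallCollector()
collector.collect(_Call(0, id="call_1", name="get_weather", arguments='{"loc'))
collector.collect(_Call(0, arguments='ation": "SF"}'))  # same index: merged, not appended
print(collector.get_tool_calls())
# [{'id': 'call_1', 'function': {'name': 'get_weather',
#                                'arguments': '{"location": "SF"}'}, 'type': 'function'}]
```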
@@ -0,0 +1,307 @@
+ """
+ source: https://github.com/mozilla-ai/any-llm/blob/main/src/any_llm/tools.py
+ """
+
+ import dataclasses
+ import enum
+ import inspect
+ import types as _types
+ from collections.abc import Callable, Mapping, Sequence
+ from datetime import date, datetime, time
+ from typing import (
+     Annotated as _Annotated,
+     Literal as _Literal,
+     Union as _Union,
+     is_typeddict as _is_typeddict,
+     Any, Awaitable, get_args, get_origin, get_type_hints,
+ )
+ from pydantic import BaseModel as PydanticBaseModel
+
+ ToolFn = Callable[..., Any] | Callable[..., Awaitable[Any]]
+ RawToolDef = dict[str, Any]
+ """
+ RawToolDef example:
+ {
+     "name": "get_current_weather",
+     "description": "Get the current weather in a given location",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "location": {
+                 "type": "string",
+                 "description": "The city and state, e.g. San Francisco, CA",
+             },
+             "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+         },
+         "required": ["location"],
+     }
+ }
+ """
+
+ @dataclasses.dataclass
+ class ToolDef:
+     name: str
+     description: str
+     execute: ToolFn
+
+ def _python_type_to_json_schema(python_type: Any) -> dict[str, Any]:
+     """Convert a Python type annotation to a JSON Schema for a parameter.
+
+     Supported mappings (subset tailored for LLM tool schemas):
+     - Primitives: str/int/float/bool -> string/integer/number/boolean
+     - bytes -> string with contentEncoding base64
+     - datetime/date/time -> string with format date-time/date/time
+     - list[T] / Sequence[T] / set[T] / frozenset[T] -> array with items=schema(T)
+       - set/frozenset include uniqueItems=true
+       - list without type args defaults items to string
+     - dict[K,V] / Mapping[K,V] -> object with additionalProperties=schema(V)
+       - dict without type args defaults additionalProperties to string
+     - tuple[T1, T2, ...] -> array with prefixItems per element and min/maxItems
+     - tuple[T, ...] -> array with items=schema(T)
+     - Union[X, Y] and X | Y -> oneOf=[schema(X), schema(Y)] (without top-level type)
+     - Optional[T] (Union[T, None]) -> schema(T) (nullability not encoded)
+     - Literal[...]/Enum -> enum with appropriate type inference when uniform
+     - TypedDict -> object with properties/required per annotations
+     - dataclass/Pydantic BaseModel -> object with nested properties inferred from fields
+     """
+     origin = get_origin(python_type)
+     args = get_args(python_type)
+
+     if origin is _Annotated and len(args) >= 1:
+         python_type = args[0]
+         origin = get_origin(python_type)
+         args = get_args(python_type)
+
+     if python_type is Any:
+         return {"type": "string"}
+
+     primitive_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
+     if python_type in primitive_map:
+         return {"type": primitive_map[python_type]}
+
+     if python_type is bytes:
+         return {"type": "string", "contentEncoding": "base64"}
+     if python_type is datetime:
+         return {"type": "string", "format": "date-time"}
+     if python_type is date:
+         return {"type": "string", "format": "date"}
+     if python_type is time:
+         return {"type": "string", "format": "time"}
+
+     if python_type is list:
+         return {"type": "array", "items": {"type": "string"}}
+     if python_type is dict:
+         return {"type": "object", "additionalProperties": {"type": "string"}}
+
+     if origin is _Literal:
+         literal_values = list(args)
+         schema_lit: dict[str, Any] = {"enum": literal_values}
+         if all(isinstance(v, bool) for v in literal_values):
+             schema_lit["type"] = "boolean"
+         elif all(isinstance(v, str) for v in literal_values):
+             schema_lit["type"] = "string"
+         elif all(isinstance(v, int) and not isinstance(v, bool) for v in literal_values):
+             schema_lit["type"] = "integer"
+         elif all(isinstance(v, int | float) and not isinstance(v, bool) for v in literal_values):
+             schema_lit["type"] = "number"
+         return schema_lit
+
+     if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
+         enum_values = [e.value for e in python_type]
+         value_types = {type(v) for v in enum_values}
+         schema: dict[str, Any] = {"enum": enum_values}
+         if value_types == {str}:
+             schema["type"] = "string"
+         elif value_types == {int}:
+             schema["type"] = "integer"
+         elif value_types <= {int, float}:
+             schema["type"] = "number"
+         elif value_types == {bool}:
+             schema["type"] = "boolean"
+         return schema
+
+     if _is_typeddict(python_type):
+         annotations: dict[str, Any] = getattr(python_type, "__annotations__", {}) or {}
+         required_keys = set(getattr(python_type, "__required_keys__", set()))
+         td_properties: dict[str, Any] = {}
+         td_required: list[str] = []
+         for field_name, field_type in annotations.items():
+             td_properties[field_name] = _python_type_to_json_schema(field_type)
+             if field_name in required_keys:
+                 td_required.append(field_name)
+         schema_td: dict[str, Any] = {
+             "type": "object",
+             "properties": td_properties,
+         }
+         if td_required:
+             schema_td["required"] = td_required
+         return schema_td
+
+     if inspect.isclass(python_type) and dataclasses.is_dataclass(python_type):
+         type_hints = get_type_hints(python_type)
+         dc_properties: dict[str, Any] = {}
+         dc_required: list[str] = []
+         for field in dataclasses.fields(python_type):
+             field_type = type_hints.get(field.name, Any)
+             dc_properties[field.name] = _python_type_to_json_schema(field_type)
+             if (
+                 field.default is dataclasses.MISSING
+                 and getattr(field, "default_factory", dataclasses.MISSING) is dataclasses.MISSING
+             ):
+                 dc_required.append(field.name)
+         schema_dc: dict[str, Any] = {"type": "object", "properties": dc_properties}
+         if dc_required:
+             schema_dc["required"] = dc_required
+         return schema_dc
+
+     if inspect.isclass(python_type) and issubclass(python_type, PydanticBaseModel):
+         model_type_hints = get_type_hints(python_type)
+         pd_properties: dict[str, Any] = {}
+         pd_required: list[str] = []
+         model_fields = getattr(python_type, "model_fields", {})
+         for name, field_info in model_fields.items():
+             pd_properties[name] = _python_type_to_json_schema(model_type_hints.get(name, Any))
+             is_required = getattr(field_info, "is_required", None)
+             if callable(is_required) and is_required():
+                 pd_required.append(name)
+         schema_pd: dict[str, Any] = {"type": "object", "properties": pd_properties}
+         if pd_required:
+             schema_pd["required"] = pd_required
+         return schema_pd
+
+     if origin in (list, Sequence, set, frozenset):
+         item_type = args[0] if args else Any
+         item_schema = _python_type_to_json_schema(item_type)
+         schema_arr: dict[str, Any] = {"type": "array", "items": item_schema or {"type": "string"}}
+         if origin in (set, frozenset):
+             schema_arr["uniqueItems"] = True
+         return schema_arr
+     if origin is tuple:
+         if not args:
+             return {"type": "array", "items": {"type": "string"}}
+         if len(args) == 2 and args[1] is Ellipsis:
+             return {"type": "array", "items": _python_type_to_json_schema(args[0])}
+         prefix_items = [_python_type_to_json_schema(a) for a in args]
+         return {
+             "type": "array",
+             "prefixItems": prefix_items,
+             "minItems": len(prefix_items),
+             "maxItems": len(prefix_items),
+         }
+
+     if origin in (dict, Mapping):
+         value_type = args[1] if len(args) >= 2 else Any
+         value_schema = _python_type_to_json_schema(value_type)
+         return {"type": "object", "additionalProperties": value_schema or {"type": "string"}}
+
+     if origin in (_Union, _types.UnionType):
+         non_none_args = [a for a in args if a is not type(None)]
+         if len(non_none_args) > 1:
+             schemas = [_python_type_to_json_schema(arg) for arg in non_none_args]
+             return {"oneOf": schemas}
+         if non_none_args:
+             return _python_type_to_json_schema(non_none_args[0])
+         return {"type": "string"}
+
+     return {"type": "string"}
+
+ def _parse_callable_properties(func: ToolFn) -> tuple[dict[str, dict[str, Any]], list[str]]:
+     sig = inspect.signature(func)
+     type_hints = get_type_hints(func)
+
+     properties: dict[str, dict[str, Any]] = {}
+     required: list[str] = []
+
+     for param_name, param in sig.parameters.items():
+         if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
+             continue
+
+         annotated_type = type_hints.get(param_name, str)
+         param_schema = _python_type_to_json_schema(annotated_type)
+
+         type_name = getattr(annotated_type, "__name__", str(annotated_type))
+         properties[param_name] = {
+             **param_schema,
+             "description": f"Parameter {param_name} of type {type_name}",
+         }
+
+         if param.default == inspect.Parameter.empty:
+             required.append(param_name)
+
+     return properties, required
+
+ def generate_tool_definition_from_callable(func: ToolFn) -> dict[str, Any]:
+     """Convert a Python callable to OpenAI tools format.
+
+     Args:
+         func: A Python callable (function) to convert to a tool
+
+     Returns:
+         Dictionary in OpenAI tools format
+
+     Raises:
+         ValueError: If the function doesn't have a docstring
+
+     Example:
+         >>> def get_weather(location: str, unit: str = "celsius") -> str:
+         ...     '''Get weather information for a location.'''
+         ...     return f"Weather in {location} is sunny, 25°{unit[0].upper()}"
+         >>>
+         >>> tool = generate_tool_definition_from_callable(get_weather)
+         >>> # Returns OpenAI tools format dict
+
+     """
+     if not func.__doc__:
+         msg = f"Function {func.__name__} must have a docstring"
+         raise ValueError(msg)
+
+     properties, required = _parse_callable_properties(func)
+     return {
+         "type": "function",
+         "function": {
+             "name": func.__name__,
+             "description": func.__doc__.strip(),
+             "parameters": {"type": "object", "properties": properties, "required": required},
+         },
+     }
+
+ def generate_tool_definition_from_tool_def(tool_def: ToolDef) -> dict[str, Any]:
+     """Convert a ToolDef to OpenAI tools format.
+
+     Args:
+         tool_def: A ToolDef to convert to a tool
+
+     Returns:
+         Dictionary in OpenAI tools format
+
+     Example:
+         >>> tool_def = ToolDef(
+         ...     name="get_weather",
+         ...     description="Get weather information for a location.",
+         ...     execute=SomeFunction(),
+         ... )
+         >>> tool = generate_tool_definition_from_tool_def(tool_def)
+         >>> # Returns OpenAI tools format dict
+     """
+     properties, required = _parse_callable_properties(tool_def.execute)
+     return {
+         "type": "function",
+         "function": {
+             "name": tool_def.name,
+             "description": tool_def.description,
+             "parameters": {"type": "object", "properties": properties, "required": required},
+         },
+     }
+
+ def generate_tool_definition_from_raw_tool_def(raw_tool_def: RawToolDef) -> dict[str, Any]:
+     return {
+         "type": "function",
+         "function": raw_tool_def,
+     }
+
+ def prepare_tools(tools: list[ToolFn | ToolDef | RawToolDef]) -> list[dict]:
+     tool_defs = []
+     for tool in tools:
+         if callable(tool):
+             tool_defs.append(generate_tool_definition_from_callable(tool))
+         elif isinstance(tool, ToolDef):
+             tool_defs.append(generate_tool_definition_from_tool_def(tool))
+         else:
+             tool_defs.append(generate_tool_definition_from_raw_tool_def(tool))
+     return tool_defs
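For orientation, a sketch of what `prepare_tools` produces for a plain annotated function (output reconstructed from `_parse_callable_properties` above; the import path is inferred from the package layout):

```python
from liteai_sdk.tool import prepare_tools

def get_weather(location: str, unit: str = "celsius") -> str:
    """Get weather information for a location."""
    return f"Weather in {location} is sunny ({unit})"

print(prepare_tools([get_weather]))
# [{'type': 'function',
#   'function': {'name': 'get_weather',
#                'description': 'Get weather information for a location.',
#                'parameters': {'type': 'object',
#                               'properties': {'location': {'type': 'string', 'description': 'Parameter location of type str'},
#                                              'unit': {'type': 'string', 'description': 'Parameter unit of type str'}},
#                               'required': ['location']}}}]
```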
@@ -0,0 +1,59 @@
+ import asyncio
+ import json
+ from functools import singledispatch
+ from typing import Any, Awaitable, Callable, cast
+ from types import FunctionType
+ from . import ToolDef
+
+ async def _coroutine_wrapper(awaitable: Awaitable[Any]) -> Any:
+     return await awaitable
+
+ def _arguments_normalizer(arguments: str | dict) -> dict:
+     if isinstance(arguments, str):
+         return parse_arguments(arguments)
+     elif isinstance(arguments, dict):
+         return arguments
+     else:
+         raise ValueError(f"Invalid arguments type: {type(arguments)}")
+
+ def parse_arguments(arguments: str) -> dict:
+     args = json.loads(arguments)
+     return cast(dict, args)
+
+ @singledispatch
+ def execute_tool_sync(tool, arguments: str | dict) -> Any:
+     pass
+
+ @execute_tool_sync.register(FunctionType)
+ def _(toolfn: Callable, arguments: str | dict) -> Any:
+     arguments = _arguments_normalizer(arguments)
+     if asyncio.iscoroutinefunction(toolfn):
+         return asyncio.run(
+             _coroutine_wrapper(
+                 toolfn(**arguments)))
+     return toolfn(**arguments)
+
+ @execute_tool_sync.register(ToolDef)
+ def _(tooldef: ToolDef, arguments: str | dict):
+     arguments = _arguments_normalizer(arguments)
+     if asyncio.iscoroutinefunction(tooldef.execute):
+         return asyncio.run(
+             _coroutine_wrapper(
+                 tooldef.execute(**arguments)))
+     return tooldef.execute(**arguments)
+
+ @singledispatch
+ async def execute_tool(tool, arguments: str | dict) -> Any:
+     pass
+
+ @execute_tool.register(FunctionType)
+ async def _(toolfn: Callable, arguments: str | dict) -> Any:
+     arguments = _arguments_normalizer(arguments)
+     if asyncio.iscoroutinefunction(toolfn):
+         return await toolfn(**arguments)
+     return toolfn(**arguments)
+
+ @execute_tool.register(ToolDef)
+ async def _(tooldef: ToolDef, arguments: str | dict):
+     arguments = _arguments_normalizer(arguments)
+     if asyncio.iscoroutinefunction(tooldef.execute):
+         return await tooldef.execute(**arguments)
+     return tooldef.execute(**arguments)
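A quick, hedged demonstration of the sync dispatcher above with a plain function; both argument forms (a JSON string as the model would send it, or an already-parsed dict) are accepted. The import path is inferred from the relative imports:

```python
from liteai_sdk.tool.execute import execute_tool_sync

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

print(execute_tool_sync(add, '{"a": 2, "b": 3}'))  # JSON-string arguments -> 5
print(execute_tool_sync(add, {"a": 2, "b": 3}))    # dict arguments -> 5
```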
@@ -0,0 +1,19 @@
+ from . import ToolFn, ToolDef, RawToolDef
+
+ def filter_executable_tools(tools: list[ToolFn | ToolDef | RawToolDef]) -> list[ToolFn | ToolDef]:
+     """
+     When executing tools we only care about callables and ToolDefs;
+     raw tool definitions are usually the provider's built-in tools
+     and cannot be executed locally.
+     """
+     return [tool for tool in tools if callable(tool) or isinstance(tool, ToolDef)]
+
+ def find_tool_by_name(tools: list[ToolFn | ToolDef | RawToolDef], name: str) -> ToolFn | ToolDef | RawToolDef | None:
+     for tool in tools:
+         if callable(tool) and tool.__name__ == name:
+             return tool
+         elif isinstance(tool, ToolDef) and tool.name == name:
+             return tool
+         elif isinstance(tool, dict) and tool.get("function", {}).get("name") == name:
+             return tool
+     return None
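A hedged sketch of the helpers above on a mixed tool list (import paths inferred from the relative imports):

```python
from liteai_sdk.tool import ToolDef
from liteai_sdk.tool.utils import filter_executable_tools, find_tool_by_name

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

tools = [add, ToolDef(name="echo", description="Echo the input.", execute=lambda s: s)]
assert find_tool_by_name(tools, "add") is tools[0]   # matched by __name__
assert find_tool_by_name(tools, "echo") is tools[1]  # matched by ToolDef.name
assert filter_executable_tools(tools) == tools       # no raw definitions to drop
```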
@@ -0,0 +1,32 @@
+ import asyncio
+ import dataclasses
+ import queue
+ from typing import Any, Literal
+ from collections.abc import AsyncGenerator, Generator
+ from ..tool import ToolFn, ToolDef, RawToolDef
+ from .message import AssistantMessageChunk, ChatMessage, AssistantMessage, ToolMessage
+
+ @dataclasses.dataclass
+ class LlmRequestParams:
+     model: str
+     messages: list[ChatMessage]
+     tools: list[ToolFn | ToolDef | RawToolDef] | None = None
+     tool_choice: Literal["auto", "required", "none"] = "auto"
+     execute_tools: bool = False
+
+     timeout_sec: float | None = None
+     temperature: float | None = None
+     max_tokens: int | None = None
+     headers: dict[str, str] | None = None
+
+     extra_args: dict[str, Any] | None = None
+
+ # --- --- --- --- --- ---
+
+ GenerateTextResponse = list[AssistantMessage | ToolMessage]
+ StreamTextResponseSync = tuple[
+     Generator[AssistantMessageChunk],
+     queue.Queue[AssistantMessage | ToolMessage | None]]
+ StreamTextResponseAsync = tuple[
+     AsyncGenerator[AssistantMessageChunk],
+     asyncio.Queue[AssistantMessage | ToolMessage | None]]
@@ -0,0 +1,119 @@
+ import json
+ from abc import ABC, abstractmethod
+ from typing import Any, Literal
+ from litellm.types.utils import (
+     Message as LiteLlmMessage,
+     ModelResponseStream as LiteLlmModelResponseStream,
+     ChatCompletionAudioResponse,
+ )
+ from litellm.types.llms.openai import (
+     AllMessageValues,
+     OpenAIMessageContent,
+     ChatCompletionAssistantToolCall,
+     ImageURLListItem as ChatCompletionImageURL,
+
+     ChatCompletionUserMessage,
+     ChatCompletionAssistantMessage,
+     ChatCompletionToolMessage,
+     ChatCompletionSystemMessage,
+ )
+
+ class ChatMessage(ABC):
+     @abstractmethod
+     def to_litellm_message(self) -> AllMessageValues: ...
+
+ class UserMessage(ChatMessage):
+     role: Literal["user"] = "user"
+
+     def __init__(self, content: OpenAIMessageContent):
+         self.content = content
+
+     def to_litellm_message(self) -> ChatCompletionUserMessage:
+         return ChatCompletionUserMessage(role=self.role, content=self.content)
+
+ class AssistantMessage(ChatMessage):
+     role: Literal["assistant"] = "assistant"
+
+     def __init__(self,
+                  content: str | None,
+                  reasoning_content: str | None = None,
+                  tool_calls: list[ChatCompletionAssistantToolCall] | None = None):
+         self.content = content
+         self.reasoning_content = reasoning_content
+         self.tool_calls = tool_calls
+         self.audio: ChatCompletionAudioResponse | None = None
+         self.images: list[ChatCompletionImageURL] | None = None
+
+     @staticmethod
+     def from_litellm_message(message: LiteLlmMessage) -> "AssistantMessage":
+         tool_calls: list[ChatCompletionAssistantToolCall] | None = None
+         if message.get("tool_calls"):
+             assert message.tool_calls is not None
+             tool_calls = [{
+                 "id": tool_call.id,
+                 "function": {
+                     "name": tool_call.function.name,
+                     "arguments": tool_call.function.arguments,
+                 },
+                 "type": "function",
+             } for tool_call in message.tool_calls]
+
+         result = AssistantMessage(
+             content=message.get("content"),
+             reasoning_content=message.get("reasoning_content"),
+             tool_calls=tool_calls)
+
+         if message.get("audio"):
+             result.audio = message.audio
+         if message.get("images"):
+             result.images = message.images
+
+         return result
+
+     def to_litellm_message(self) -> ChatCompletionAssistantMessage:
+         return ChatCompletionAssistantMessage(role=self.role,
+                                               content=self.content,
+                                               reasoning_content=self.reasoning_content,
+                                               tool_calls=self.tool_calls)
+
+ class ToolMessage(ChatMessage):
+     role: Literal["tool"] = "tool"
+
+     def __init__(self, id: str, name: str, arguments: dict, result: Any):
+         self.id = id
+         self.name = name
+         self.arguments = arguments
+         self.result = result
+
+     def to_litellm_message(self) -> ChatCompletionToolMessage:
+         return ChatCompletionToolMessage(
+             role=self.role,
+             content=json.dumps(self.result),
+             tool_call_id=self.id)
+
+ class SystemMessage(ChatMessage):
+     role: Literal["system"] = "system"
+
+     def __init__(self, content: str):
+         self.content = content
+
+     def to_litellm_message(self) -> ChatCompletionSystemMessage:
+         return ChatCompletionSystemMessage(role=self.role, content=self.content)
+
+ class AssistantMessageChunk:
+     def __init__(self,
+                  content: str | None = None,
+                  reasoning_content: str | None = None,
+                  audio: ChatCompletionAudioResponse | None = None,
+                  images: list[ChatCompletionImageURL] | None = None):
+         self.content = content
+         self.reasoning_content = reasoning_content
+         self.audio = audio
+         self.images = images
+
+     @staticmethod
+     def from_litellm_chunk(chunk: LiteLlmModelResponseStream) -> "AssistantMessageChunk":
+         delta = chunk.choices[0].delta
+         temp_chunk = AssistantMessageChunk()
+         if delta.get("content"):
+             temp_chunk.content = delta.content
+         if delta.get("reasoning_content"):
+             temp_chunk.reasoning_content = delta.reasoning_content
+         if delta.get("audio"):
+             temp_chunk.audio = delta.audio
+         if delta.get("images"):
+             temp_chunk.images = delta.images
+         return temp_chunk
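Finally, a small hedged sketch of constructing messages with the classes above and converting them to litellm's wire format (both names are re-exported from the package root per its `__all__`):

```python
from liteai_sdk import SystemMessage, UserMessage

history = [
    SystemMessage("You are a helpful assistant."),
    UserMessage("Hello."),
]
print([m.to_litellm_message() for m in history])
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Hello.'}]
```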