vectorvein 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectorvein/__init__.py +0 -0
- vectorvein/chat_clients/__init__.py +110 -0
- vectorvein/chat_clients/anthropic_client.py +450 -0
- vectorvein/chat_clients/base_client.py +91 -0
- vectorvein/chat_clients/deepseek_client.py +15 -0
- vectorvein/chat_clients/gemini_client.py +317 -0
- vectorvein/chat_clients/groq_client.py +15 -0
- vectorvein/chat_clients/local_client.py +14 -0
- vectorvein/chat_clients/minimax_client.py +315 -0
- vectorvein/chat_clients/mistral_client.py +15 -0
- vectorvein/chat_clients/moonshot_client.py +15 -0
- vectorvein/chat_clients/openai_client.py +15 -0
- vectorvein/chat_clients/openai_compatible_client.py +291 -0
- vectorvein/chat_clients/qwen_client.py +15 -0
- vectorvein/chat_clients/utils.py +635 -0
- vectorvein/chat_clients/yi_client.py +15 -0
- vectorvein/chat_clients/zhipuai_client.py +15 -0
- vectorvein/settings/__init__.py +71 -0
- vectorvein/types/defaults.py +396 -0
- vectorvein/types/enums.py +83 -0
- vectorvein/types/llm_parameters.py +69 -0
- vectorvein/utilities/media_processing.py +70 -0
- vectorvein-0.1.0.dist-info/METADATA +16 -0
- vectorvein-0.1.0.dist-info/RECORD +25 -0
- vectorvein-0.1.0.dist-info/WHEEL +4 -0
vectorvein/chat_clients/gemini_client.py
@@ -0,0 +1,317 @@
+# @Author: Bi Ying
+# @Date: 2024-06-17 23:47:49
+import json
+import random
+
+import httpx
+
+from ..settings import settings
+from .utils import cutoff_messages
+from ..types import defaults as defs
+from .base_client import BaseChatClient, BaseAsyncChatClient
+from ..types.enums import ContextLengthControlType, BackendType
+
+
+class GeminiChatClient(BaseChatClient):
+    DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Gemini
+
+    def __init__(
+        self,
+        model: str = defs.GEMINI_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | None = None,
+        tool_choice: str | None = None,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        if tools:
+            tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
+        else:
+            tools_params = {}
+
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+        self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        request_body = {
+            "contents": messages,
+            "safetySettings": [
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_ONLY_HIGH",
+                }
+            ],
+            "generationConfig": {
+                "temperature": self.temperature,
+                "maxOutputTokens": max_tokens,
+            },
+            **tools_params,
+        }
+        if system_prompt:
+            request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
+
+        headers = {"Content-Type": "application/json"}
+
+        params = {"key": self.endpoint.api_key}
+
+        if self.stream:
+            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
+            params["alt"] = "sse"
+
+            def generator():
+                result = {"content": ""}
+                with httpx.stream("POST", url, headers=headers, params=params, json=request_body) as response:
+                    for chunk in response.iter_lines():
+                        message = {"content": ""}
+                        if not chunk.startswith("data:"):
+                            continue
+                        data = json.loads(chunk[5:])
+                        chunk_content = data["candidates"][0]["content"]["parts"][0]
+                        if "text" in chunk_content:
+                            message["content"] = chunk_content["text"]
+                            result["content"] += message["content"]
+                        elif "functionCall" in chunk_content:
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": 0,
+                                    "function": {
+                                        "arguments": json.dumps(
+                                            chunk_content["functionCall"]["args"], ensure_ascii=False
+                                        ),
+                                        "name": chunk_content["functionCall"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+
+                        result["usage"] = message["usage"] = {
+                            "prompt_tokens": data["usageMetadata"]["promptTokenCount"],
+                            "completion_tokens": data["usageMetadata"]["candidatesTokenCount"],
+                            "total_tokens": data["usageMetadata"]["totalTokenCount"],
+                        }
+                        yield message
+
+            return generator()
+        else:
+            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
+            response = httpx.post(url, json=request_body, headers=headers, params=params, timeout=None).json()
+            result = {
+                "content": "",
+                "usage": {
+                    "prompt_tokens": response["usageMetadata"]["promptTokenCount"],
+                    "completion_tokens": response["usageMetadata"]["candidatesTokenCount"],
+                    "total_tokens": response["usageMetadata"]["totalTokenCount"],
+                },
+            }
+            tool_calls = []
+            for part in response["candidates"][0]["content"]["parts"]:
+                if "text" in part:
+                    result["content"] += part["text"]
+                elif "functionCall" in part:
+                    tool_calls.append(part["functionCall"])
+
+            if tool_calls:
+                result["tool_calls"] = tool_calls
+
+            return result
+
+
+class AsyncGeminiChatClient(BaseAsyncChatClient):
+    DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Gemini
+
+    def __init__(
+        self,
+        model: str = defs.GEMINI_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    async def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | None = None,
+        tool_choice: str | None = None,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        if tools:
+            tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
+        else:
+            tools_params = {}
+
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+        self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        request_body = {
+            "contents": messages,
+            "safetySettings": [
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_ONLY_HIGH",
+                }
+            ],
+            "generationConfig": {
+                "temperature": self.temperature,
+                "maxOutputTokens": max_tokens,
+            },
+            **tools_params,
+        }
+        if system_prompt:
+            request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
+
+        headers = {"Content-Type": "application/json"}
+
+        params = {"key": self.endpoint.api_key}
+
+        if self.stream:
+            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
+            params["alt"] = "sse"
+
+            async def generator():
+                result = {"content": ""}
+                client = httpx.AsyncClient()
+                async with client.stream("POST", url, headers=headers, params=params, json=request_body) as response:
+                    async for chunk in response.aiter_lines():
+                        message = {"content": ""}
+                        if not chunk.startswith("data:"):
+                            continue
+                        data = json.loads(chunk[5:])
+                        chunk_content = data["candidates"][0]["content"]["parts"][0]
+                        if "text" in chunk_content:
+                            message["content"] = chunk_content["text"]
+                            result["content"] += message["content"]
+                        elif "functionCall" in chunk_content:
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": 0,
+                                    "function": {
+                                        "arguments": json.dumps(
+                                            chunk_content["functionCall"]["args"], ensure_ascii=False
+                                        ),
+                                        "name": chunk_content["functionCall"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+
+                        result["usage"] = message["usage"] = {
+                            "prompt_tokens": data["usageMetadata"]["promptTokenCount"],
+                            "completion_tokens": data["usageMetadata"]["candidatesTokenCount"],
+                            "total_tokens": data["usageMetadata"]["totalTokenCount"],
+                        }
+                        yield message
+
+            return generator()
+        else:
+            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
+            async with httpx.AsyncClient(headers=headers, params=params, timeout=None) as client:
+                response = await client.post(url, json=request_body)
+                response = response.json()
+                result = {
+                    "content": "",
+                    "usage": {
+                        "prompt_tokens": response["usageMetadata"]["promptTokenCount"],
+                        "completion_tokens": response["usageMetadata"]["candidatesTokenCount"],
+                        "total_tokens": response["usageMetadata"]["totalTokenCount"],
+                    },
+                }
+                tool_calls = []
+                for part in response["candidates"][0]["content"]["parts"]:
+                    if "text" in part:
+                        result["content"] += part["text"]
+                    elif "functionCall" in part:
+                        tool_calls.append(part["functionCall"])
+
+                if tool_calls:
+                    result["tool_calls"] = tool_calls
+
+                return result
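A minimal usage sketch, not taken from the package itself: it assumes the global settings object already holds at least one Gemini endpoint (api_base and api_key) plus a model entry, and that messages follow the Gemini "contents" schema forwarded by create_completion (the leading system message and its wording are illustrative assumptions).

# Hypothetical usage sketch of GeminiChatClient -- not part of the wheel contents.
# Assumes settings are already configured with a Gemini endpoint and model entry;
# the message shapes mirror the request body built above ("contents" plus an
# optional leading system message) and are assumptions, not documented API.
from vectorvein.chat_clients.gemini_client import GeminiChatClient

client = GeminiChatClient(stream=True, temperature=0.4)
stream = client.create_completion(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "parts": [{"text": "Summarize httpx in one sentence."}]},
    ],
    max_tokens=512,
)
for message in stream:
    # each streamed chunk carries a "content" delta plus "usage" counters
    # derived from the response's usageMetadata
    print(message["content"], end="", flush=True)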
vectorvein/chat_clients/groq_client.py
@@ -0,0 +1,15 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from ..types.enums import BackendType
+from ..types.defaults import GROQ_DEFAULT_MODEL
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
+
+
+class GroqChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = GROQ_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Groq
+
+
+class AsyncGroqChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = GROQ_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Groq
vectorvein/chat_clients/local_client.py
@@ -0,0 +1,14 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from ..types.enums import BackendType
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
+
+
+class LocalChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = ""
+    BACKEND_NAME = BackendType.Local
+
+
+class AsyncLocalChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = ""
+    BACKEND_NAME = BackendType.Local
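The Groq and Local clients above are thin shims over the OpenAI-compatible base classes: each backend only pins DEFAULT_MODEL and BACKEND_NAME. As a purely hypothetical sketch of that pattern (not a file in this wheel), wiring up another OpenAI-compatible provider would look the same; ExampleChatClient and the model id are placeholders, and a real addition would also need its own BackendType member and defaults entry.

# Hypothetical sketch only -- it mirrors groq_client.py / local_client.py above and
# would live under vectorvein/chat_clients/ if it existed. BackendType.Local is
# reused here purely as a stand-in; a real provider would get its own enum member.
from ..types.enums import BackendType
from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient


class ExampleChatClient(OpenAICompatibleChatClient):
    DEFAULT_MODEL = "example-model-v1"  # placeholder model id
    BACKEND_NAME = BackendType.Local


class AsyncExampleChatClient(AsyncOpenAICompatibleChatClient):
    DEFAULT_MODEL = "example-model-v1"
    BACKEND_NAME = BackendType.Local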
vectorvein/chat_clients/minimax_client.py
@@ -0,0 +1,315 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+import json
+import random
+
+import httpx
+
+from ..settings import settings
+from .utils import cutoff_messages
+from ..types import defaults as defs
+from .base_client import BaseChatClient, BaseAsyncChatClient
+from ..types.enums import ContextLengthControlType, BackendType
+
+
+def extract_tool_calls(response):
+    try:
+        message = response["choices"][0].get("delta") or response["choices"][0].get("message", {})
+        tool_calls = message.get("tool_calls")
+        if tool_calls:
+            return {
+                "tool_calls": [
+                    {
+                        "index": index,
+                        "id": tool_call["id"],
+                        "function": tool_call["function"],
+                        "type": tool_call["type"],
+                    }
+                    for index, tool_call in enumerate(tool_calls)
+                ]
+            }
+        else:
+            return {}
+    except Exception:
+        return {}
+
+
+class MiniMaxChatClient(BaseChatClient):
+    DEFAULT_MODEL: str = defs.MINIMAX_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.MiniMax
+
+    def __init__(
+        self,
+        model: str = defs.MINIMAX_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2048,
+        tools: list | None = None,
+        tool_choice: str = "auto",
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+        self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        if tools is not None:
+            tools_params = {
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": tool["function"]["name"],
+                            "description": tool["function"].get("description", ""),
+                            "parameters": json.dumps(
+                                tool["function"].get("parameters", {})
+                            ),  # they just had to be different: parameters is a string
+                        },
+                    }
+                    for tool in tools
+                ],
+                "tool_choice": tool_choice,
+            }
+        else:
+            tools_params = {}
+
+        self.url = self.endpoint.api_base
+        self.headers = {"Authorization": f"Bearer {self.endpoint.api_key}", "Content-Type": "application/json"}
+
+        request_body = {
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": self.temperature,
+            "stream": self.stream,
+            "mask_sensitive_info": False,
+            **tools_params,
+        }
+
+        response = httpx.post(
+            url=self.url,
+            headers=self.headers,
+            json=request_body,
+            timeout=60,
+        )
+
+        if self.stream:
+
+            def generator():
+                for chunk in response.iter_lines():
+                    if chunk:
+                        chunk_data = json.loads(chunk[6:])
+                        tool_calls_params = extract_tool_calls(chunk_data)
+                        has_tool_calls = True if tool_calls_params else False
+                        if has_tool_calls:
+                            if "usage" not in chunk_data:
+                                continue
+                            else:
+                                yield {
+                                    "content": chunk_data["choices"][0]["message"].get("content"),
+                                    "role": "assistant",
+                                    **tool_calls_params,
+                                }
+                        else:
+                            if "usage" in chunk_data:
+                                continue
+                            yield {
+                                "content": chunk_data["choices"][0]["delta"]["content"],
+                                "role": "assistant",
+                            }
+
+            return generator()
+        else:
+            result = response.json()
+            tool_calls_params = extract_tool_calls(result)
+            return {
+                "content": result["choices"][0]["message"].get("content"),
+                "usage": {
+                    "prompt_tokens": 0,
+                    "completion_tokens": result["usage"]["total_tokens"],
+                    "total_tokens": result["usage"]["total_tokens"],
+                },
+                "role": "assistant",
+                **tool_calls_params,
+            }
+
+
+class AsyncMiniMaxChatClient(BaseAsyncChatClient):
+    DEFAULT_MODEL: str = defs.MINIMAX_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.MiniMax
+
+    def __init__(
+        self,
+        model: str = defs.MINIMAX_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+        self.http_client = httpx.AsyncClient()
+
+    async def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2048,
+        tools: list | None = None,
+        tool_choice: str = "auto",
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+        self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        if tools is not None:
+            tools_params = {
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": tool["function"]["name"],
+                            "description": tool["function"].get("description", ""),
+                            "parameters": json.dumps(tool["function"].get("parameters", {})),
+                        },
+                    }
+                    for tool in tools
+                ],
+                "tool_choice": tool_choice,
+            }
+        else:
+            tools_params = {}
+
+        self.url = self.endpoint.api_base
+        self.headers = {"Authorization": f"Bearer {self.endpoint.api_key}", "Content-Type": "application/json"}
+
+        request_body = {
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": self.temperature,
+            "stream": self.stream,
+            "mask_sensitive_info": False,
+            **tools_params,
+        }
+
+        if self.stream:
+
+            async def generator():
+                async with self.http_client.stream(
+                    "POST",
+                    url=self.url,
+                    headers=self.headers,
+                    json=request_body,
+                    timeout=60,
+                ) as response:
+                    has_tool_calls = False
+                    async for chunk in response.aiter_lines():
+                        if chunk:
+                            chunk_data = json.loads(chunk[6:])
+                            tool_calls_params = extract_tool_calls(chunk_data)
+                            has_tool_calls = True if tool_calls_params else False
+                            if has_tool_calls:
+                                if "usage" not in chunk_data:
+                                    continue
+                                else:
+                                    yield {
+                                        "content": chunk_data["choices"][0]["message"].get("content"),
+                                        "role": "assistant",
+                                        **tool_calls_params,
+                                    }
+                            else:
+                                if "usage" in chunk_data:
+                                    continue
+                                yield {
+                                    "content": chunk_data["choices"][0]["delta"]["content"],
+                                    "role": "assistant",
+                                }
+
+            return generator()
+        else:
+            response = await self.http_client.post(
+                url=self.url,
+                headers=self.headers,
+                json=request_body,
+                timeout=60,
+            )
+            result = response.json()
+            tool_calls_params = extract_tool_calls(result)
+            return {
+                "content": result["choices"][0]["message"].get("content"),
+                "usage": {
+                    "prompt_tokens": 0,
+                    "completion_tokens": result["usage"]["total_tokens"],
+                    "total_tokens": result["usage"]["total_tokens"],
+                },
+                "role": "assistant",
+                **tool_calls_params,
+            }
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.http_client.aclose()
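A hedged sketch of a tools-enabled MiniMax call, again assuming settings already hold a MiniMax endpoint. The tool definition uses the OpenAI-style shape that create_completion iterates over, and the client JSON-encodes "parameters" into a string before sending, as the code above shows; the weather tool and the prompt are purely illustrative.

# Hypothetical usage sketch for MiniMaxChatClient -- not shipped in the package.
# Assumes settings already contain a MiniMax endpoint; "get_weather" is an invented
# example of the OpenAI-style tool schema the client serializes.
from vectorvein.chat_clients.minimax_client import MiniMaxChatClient

tools = [
    {
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    }
]

client = MiniMaxChatClient(stream=False)
result = client.create_completion(
    messages=[{"role": "user", "content": "What's the weather in Shanghai?"}],
    tools=tools,
    tool_choice="auto",
)
if "tool_calls" in result:
    # extract_tool_calls() normalizes these into OpenAI-style tool call dicts
    for call in result["tool_calls"]:
        print(call["function"]["name"], call["function"]["arguments"])
else:
    print(result["content"])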
vectorvein/chat_clients/mistral_client.py
@@ -0,0 +1,15 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from ..types.enums import BackendType
+from ..types.defaults import MISTRAL_DEFAULT_MODEL
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
+
+
+class MistralChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = MISTRAL_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Mistral
+
+
+class AsyncMistralChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = MISTRAL_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Mistral
vectorvein/chat_clients/moonshot_client.py
@@ -0,0 +1,15 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from ..types.enums import BackendType
+from ..types.defaults import MOONSHOT_DEFAULT_MODEL
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
+
+
+class MoonshotChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = MOONSHOT_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Moonshot
+
+
+class AsyncMoonshotChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = MOONSHOT_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Moonshot
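The async clients mirror the synchronous ones: when stream=True, the awaited create_completion returns an async generator that must be consumed with `async for`. A small hedged sketch, assuming the same pre-configured settings as in the earlier examples:

# Hypothetical async usage sketch (not in the package), based on the
# AsyncMiniMaxChatClient code above; endpoint configuration is assumed to exist.
import asyncio

from vectorvein.chat_clients.minimax_client import AsyncMiniMaxChatClient


async def main():
    client = AsyncMiniMaxChatClient(stream=True)
    stream = await client.create_completion(
        messages=[{"role": "user", "content": "Say hello in five words."}],
        max_tokens=128,
    )
    async for message in stream:
        # streamed chunks carry incremental "content" deltas
        print(message["content"], end="", flush=True)


asyncio.run(main())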