vectorvein-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vectorvein/chat_clients/__init__.py ADDED
@@ -0,0 +1,110 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from .base_client import BaseChatClient, BaseAsyncChatClient
+
+from .yi_client import YiChatClient, AsyncYiChatClient
+from .groq_client import GroqChatClient, AsyncGroqChatClient
+from .qwen_client import QwenChatClient, AsyncQwenChatClient
+from .local_client import LocalChatClient, AsyncLocalChatClient
+from .gemini_client import GeminiChatClient, AsyncGeminiChatClient
+from .openai_client import OpenAIChatClient, AsyncOpenAIChatClient
+from .zhipuai_client import ZhiPuAIChatClient, AsyncZhiPuAIChatClient
+from .minimax_client import MiniMaxChatClient, AsyncMiniMaxChatClient
+from .mistral_client import MistralChatClient, AsyncMistralChatClient
+from .moonshot_client import MoonshotChatClient, AsyncMoonshotChatClient
+from .deepseek_client import DeepSeekChatClient, AsyncDeepSeekChatClient
+
+from ..types import defaults as defs
+from ..types.enums import BackendType, ContextLengthControlType
+from .anthropic_client import AnthropicChatClient, AsyncAnthropicChatClient
+from .utils import format_messages
+
+
+BackendMap = {
+    "sync": {
+        BackendType.Anthropic: AnthropicChatClient,
+        BackendType.DeepSeek: DeepSeekChatClient,
+        BackendType.Gemini: GeminiChatClient,
+        BackendType.Groq: GroqChatClient,
+        BackendType.Local: LocalChatClient,
+        BackendType.MiniMax: MiniMaxChatClient,
+        BackendType.Mistral: MistralChatClient,
+        BackendType.Moonshot: MoonshotChatClient,
+        BackendType.OpenAI: OpenAIChatClient,
+        BackendType.Qwen: QwenChatClient,
+        BackendType.Yi: YiChatClient,
+        BackendType.ZhiPuAI: ZhiPuAIChatClient,
+    },
+    "async": {
+        BackendType.Anthropic: AsyncAnthropicChatClient,
+        BackendType.DeepSeek: AsyncDeepSeekChatClient,
+        BackendType.Gemini: AsyncGeminiChatClient,
+        BackendType.Groq: AsyncGroqChatClient,
+        BackendType.Local: AsyncLocalChatClient,
+        BackendType.MiniMax: AsyncMiniMaxChatClient,
+        BackendType.Mistral: AsyncMistralChatClient,
+        BackendType.Moonshot: AsyncMoonshotChatClient,
+        BackendType.OpenAI: AsyncOpenAIChatClient,
+        BackendType.Qwen: AsyncQwenChatClient,
+        BackendType.Yi: AsyncYiChatClient,
+        BackendType.ZhiPuAI: AsyncZhiPuAIChatClient,
+    },
+}
+
+
+def create_chat_client(
+    backend: BackendType,
+    model: str | None = None,
+    stream: bool = True,
+    temperature: float = 0.7,
+    context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+    **kwargs,
+) -> BaseChatClient:
+    backend_key = backend.lower()
+    if backend_key not in BackendMap["sync"]:
+        raise ValueError(f"Unsupported backend: {backend}")
+
+    ClientClass = BackendMap["sync"][backend_key]
+    if model is None:
+        model = ClientClass.DEFAULT_MODEL
+    return ClientClass(
+        model=model,
+        stream=stream,
+        temperature=temperature,
+        context_length_control=context_length_control,
+        **kwargs,
+    )
+
+
+def create_async_chat_client(
+    backend: BackendType,
+    model: str | None = None,
+    stream: bool = True,
+    temperature: float = 0.7,
+    context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+    **kwargs,
+) -> BaseAsyncChatClient:
+    backend_key = backend.lower()
+    if backend_key not in BackendMap["async"]:
+        raise ValueError(f"Unsupported backend: {backend}")
+
+    ClientClass = BackendMap["async"][backend_key]
+    if model is None:
+        model = ClientClass.DEFAULT_MODEL
+    return ClientClass(
+        model=model,
+        stream=stream,
+        temperature=temperature,
+        context_length_control=context_length_control,
+        **kwargs,
+    )
+
+
+__all__ = [
+    "create_chat_client",
+    "create_async_chat_client",
+    "format_messages",
+    "BackendType",
+]
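
For reference, a minimal sketch of how the factory above is used (the prompt is illustrative, and it assumes backends and endpoints have already been configured through vectorvein's settings):

    from vectorvein.chat_clients import BackendType, create_chat_client

    client = create_chat_client(backend=BackendType.Anthropic, stream=False)
    response = client.create_completion(messages=[{"role": "user", "content": "Hello!"}])
    print(response["content"])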
vectorvein/chat_clients/anthropic_client.py ADDED
@@ -0,0 +1,450 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+import json
+import random
+
+from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex
+from anthropic._types import NotGiven, NOT_GIVEN
+from anthropic.types import (
+    TextBlock,
+    ToolUseBlock,
+    RawMessageDeltaEvent,
+    RawMessageStartEvent,
+    RawContentBlockStartEvent,
+    RawContentBlockDeltaEvent,
+)
+from google.oauth2.credentials import Credentials
+from google.auth.transport.requests import Request
+from google.auth import _helpers
+
+from ..settings import settings
+from .utils import cutoff_messages
+from ..types import defaults as defs
+from .base_client import BaseChatClient, BaseAsyncChatClient
+from ..types.enums import ContextLengthControlType, BackendType
+
+
+def refactor_tool_use_params(tools: list):
+    return [
+        {
+            "name": tool["function"]["name"],
+            "description": tool["function"]["description"],
+            "input_schema": tool["function"]["parameters"],
+        }
+        for tool in tools
+    ]
+
+
+def refactor_tool_calls(tool_calls: list):
+    return [
+        {
+            "index": index,
+            "id": tool["id"],
+            "type": "function",
+            "function": {
+                "name": tool["name"],
+                "arguments": json.dumps(tool["input"], ensure_ascii=False),
+            },
+        }
+        for index, tool in enumerate(tool_calls)
+    ]
+
+
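+# Illustrative conversion (hypothetical tool, for reference only): the helpers
+# above translate between the two tool formats. An OpenAI-style definition
+#     {"type": "function", "function": {"name": "get_weather",
+#      "description": "Look up current weather",
+#      "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}
+# becomes, via refactor_tool_use_params(), Anthropic's shape:
+#     {"name": "get_weather", "description": "Look up current weather",
+#      "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}}}
+# while refactor_tool_calls() maps Anthropic tool-use blocks back to OpenAI-style
+# "tool_calls" entries, JSON-encoding the "input" dict as "arguments".
+
+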
+def format_messages_alternate(messages: list) -> list:
+    """Ensure roles strictly alternate between "user" and "assistant".
+
+    Consecutive messages with the same role are merged into a single message
+    whose content is a list of content blocks; messages that already alternate
+    are kept as-is.
+    """
+    formatted_messages = []
+    current_role = None
+    current_content = []
+
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+
+        if role != current_role:
+            if current_content:
+                formatted_messages.append({"role": current_role, "content": current_content})
+            current_content = []
+            current_role = role
+
+        if isinstance(content, str):
+            current_content.append({"type": "text", "text": content})
+        elif isinstance(content, list):
+            current_content.extend(content)
+        else:
+            current_content.append(content)
+
+    if current_content:
+        formatted_messages.append({"role": current_role, "content": current_content})
+
+    return formatted_messages
+
+
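+# Example for the function above (hypothetical messages): two consecutive
+# "user" turns
+#     [{"role": "user", "content": "Hello, Claude"},
+#      {"role": "user", "content": "How are you?"}]
+# are merged into a single alternating-safe message:
+#     [{"role": "user", "content": [{"type": "text", "text": "Hello, Claude"},
+#                                   {"type": "text", "text": "How are you?"}]}]
+
+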
+class AnthropicChatClient(BaseChatClient):
+    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Anthropic
+
+    def __init__(
+        self,
+        model: str = defs.ANTHROPIC_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: str | NotGiven = NOT_GIVEN,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        messages = format_messages_alternate(messages)
+
+        if self.random_endpoint:
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.endpoint.is_vertex:
+            # Build Google OAuth credentials for an Anthropic-on-Vertex endpoint.
+            self.creds = Credentials(
+                token=self.endpoint.credentials.get("token"),
+                refresh_token=self.endpoint.credentials.get("refresh_token"),
+                token_uri=self.endpoint.credentials.get("token_uri"),
+                scopes=None,
+                client_id=self.endpoint.credentials.get("client_id"),
+                client_secret=self.endpoint.credentials.get("client_secret"),
+                quota_project_id=self.endpoint.credentials.get("quota_project_id"),
+                expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
+                rapt_token=self.endpoint.credentials.get("rapt_token"),
+                trust_boundary=self.endpoint.credentials.get("trust_boundary"),
+                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                account=self.endpoint.credentials.get("account", ""),
+            )
+
+            if self.creds.expired and self.creds.refresh_token:
+                self.creds.refresh(Request())
+
+            if self.endpoint.api_base is None:
+                base_url = None
+            else:
+                base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"
+
+            self._client = AnthropicVertex(
+                region=self.endpoint.region,
+                base_url=base_url,
+                project_id=self.endpoint.credentials.get("quota_project_id"),
+                access_token=self.creds.token,
+            )
+        else:
+            self._client = Anthropic(
+                api_key=self.endpoint.api_key,
+                base_url=self.endpoint.api_base,
+            )
+
+        response = self._client.messages.create(
+            model=self.model_setting.id,
+            messages=messages,
+            system=system_prompt,
+            stream=self.stream,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+            tools=refactor_tool_use_params(tools) if tools else tools,
+            tool_choice=tool_choice,
+        )
+
+        if self.stream:
+
+            def generator():
+                # Re-emit Anthropic stream events as OpenAI-style delta dicts.
+                result = {"content": ""}
+                for chunk in response:
+                    message = {"content": ""}
+                    if isinstance(chunk, RawMessageStartEvent):
+                        result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
+                        continue
+                    elif isinstance(chunk, RawContentBlockStartEvent):
+                        if chunk.content_block.type == "tool_use":
+                            result["tool_calls"] = message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": chunk.content_block.id,
+                                    "function": {
+                                        "arguments": "",
+                                        "name": chunk.content_block.name,
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        elif chunk.content_block.type == "text":
+                            message["content"] = chunk.content_block.text
+                        yield message
+                    elif isinstance(chunk, RawContentBlockDeltaEvent):
+                        if chunk.delta.type == "text_delta":
+                            message["content"] = chunk.delta.text
+                            result["content"] += chunk.delta.text
+                        elif chunk.delta.type == "input_json_delta":
+                            result["tool_calls"][0]["function"]["arguments"] += chunk.delta.partial_json
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": result["tool_calls"][0]["id"],
+                                    "function": {
+                                        "arguments": chunk.delta.partial_json,
+                                        "name": result["tool_calls"][0]["function"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        yield message
+                    elif isinstance(chunk, RawMessageDeltaEvent):
+                        result["usage"]["completion_tokens"] = chunk.usage.output_tokens
+                        result["usage"]["total_tokens"] = (
+                            result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
+                        )
+                        yield {"usage": result["usage"]}
+
+            return generator()
+        else:
+            result = {
+                "content": "",
+                "usage": {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+                },
+            }
+            tool_calls = []
+            for content_block in response.content:
+                if isinstance(content_block, TextBlock):
+                    result["content"] += content_block.text
+                elif isinstance(content_block, ToolUseBlock):
+                    tool_calls.append(content_block.model_dump())
+
+            if tool_calls:
+                result["tool_calls"] = refactor_tool_calls(tool_calls)
+
+            return result
+
+
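+# Illustrative streaming usage of the class above (hypothetical prompt; assumes
+# an Anthropic endpoint is configured in vectorvein's settings):
+#
+#     client = AnthropicChatClient(stream=True)
+#     for chunk in client.create_completion(messages=[{"role": "user", "content": "Hi"}]):
+#         if chunk.get("content"):
+#             print(chunk["content"], end="")
+#         if "usage" in chunk:
+#             print(chunk["usage"])  # final chunk carries token counts
+
+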
+class AsyncAnthropicChatClient(BaseAsyncChatClient):
+    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Anthropic
+
+    def __init__(
+        self,
+        model: str = defs.ANTHROPIC_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: str | NotGiven = NOT_GIVEN,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        messages = format_messages_alternate(messages)
+
+        if self.random_endpoint:
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.endpoint.is_vertex:
+            # Build Google OAuth credentials for an Anthropic-on-Vertex endpoint.
+            self.creds = Credentials(
+                token=self.endpoint.credentials.get("token"),
+                refresh_token=self.endpoint.credentials.get("refresh_token"),
+                token_uri=self.endpoint.credentials.get("token_uri"),
+                scopes=None,
+                client_id=self.endpoint.credentials.get("client_id"),
+                client_secret=self.endpoint.credentials.get("client_secret"),
+                quota_project_id=self.endpoint.credentials.get("quota_project_id"),
+                expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
+                rapt_token=self.endpoint.credentials.get("rapt_token"),
+                trust_boundary=self.endpoint.credentials.get("trust_boundary"),
+                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                account=self.endpoint.credentials.get("account", ""),
+            )
+
+            if self.creds.expired and self.creds.refresh_token:
+                self.creds.refresh(Request())
+
+            if self.endpoint.api_base is None:
+                base_url = None
+            else:
+                base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"
+
+            self._client = AsyncAnthropicVertex(
+                region=self.endpoint.region,
+                base_url=base_url,
+                project_id=self.endpoint.credentials.get("quota_project_id"),
+                access_token=self.creds.token,
+            )
+        else:
+            self._client = AsyncAnthropic(
+                api_key=self.endpoint.api_key,
+                base_url=self.endpoint.api_base,
+            )
+
+        response = await self._client.messages.create(
+            model=self.model_setting.id,
+            messages=messages,
+            system=system_prompt,
+            stream=self.stream,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+            tools=refactor_tool_use_params(tools) if tools else tools,
+            tool_choice=tool_choice,
+        )
+
+        if self.stream:
+
+            async def generator():
+                # Re-emit Anthropic stream events as OpenAI-style delta dicts.
+                result = {"content": ""}
+                async for chunk in response:
+                    message = {"content": ""}
+                    if isinstance(chunk, RawMessageStartEvent):
+                        result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
+                        continue
+                    elif isinstance(chunk, RawContentBlockStartEvent):
+                        if chunk.content_block.type == "tool_use":
+                            result["tool_calls"] = message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": chunk.content_block.id,
+                                    "function": {
+                                        "arguments": "",
+                                        "name": chunk.content_block.name,
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        elif chunk.content_block.type == "text":
+                            message["content"] = chunk.content_block.text
+                        yield message
+                    elif isinstance(chunk, RawContentBlockDeltaEvent):
+                        if chunk.delta.type == "text_delta":
+                            message["content"] = chunk.delta.text
+                            result["content"] += chunk.delta.text
+                        elif chunk.delta.type == "input_json_delta":
+                            result["tool_calls"][0]["function"]["arguments"] += chunk.delta.partial_json
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": result["tool_calls"][0]["id"],
+                                    "function": {
+                                        "arguments": chunk.delta.partial_json,
+                                        "name": result["tool_calls"][0]["function"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        yield message
+                    elif isinstance(chunk, RawMessageDeltaEvent):
+                        result["usage"]["completion_tokens"] = chunk.usage.output_tokens
+                        result["usage"]["total_tokens"] = (
+                            result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
+                        )
+                        yield {"usage": result["usage"]}
+
+            return generator()
+        else:
+            result = {
+                "content": "",
+                "usage": {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+                },
+            }
+            tool_calls = []
+            for content_block in response.content:
+                if isinstance(content_block, TextBlock):
+                    result["content"] += content_block.text
+                elif isinstance(content_block, ToolUseBlock):
+                    tool_calls.append(content_block.model_dump())
+
+            if tool_calls:
+                result["tool_calls"] = refactor_tool_calls(tool_calls)
+
+            return result
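
A minimal async sketch of the client above (hypothetical prompt; assumes an Anthropic endpoint is configured in vectorvein's settings). With stream=True, the awaited call returns an async generator:

    import asyncio

    from vectorvein.chat_clients import AsyncAnthropicChatClient

    async def main():
        client = AsyncAnthropicChatClient(stream=True)
        response = await client.create_completion(messages=[{"role": "user", "content": "Hi"}])
        async for chunk in response:
            if chunk.get("content"):
                print(chunk["content"], end="")

    asyncio.run(main())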
vectorvein/chat_clients/base_client.py ADDED
@@ -0,0 +1,91 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from abc import ABC, abstractmethod
+
+from ..settings import settings
+from ..types import defaults as defs
+from ..types.enums import ContextLengthControlType, BackendType
+
+
+class BaseChatClient(ABC):
+    DEFAULT_MODEL: str | None = None
+    BACKEND_NAME: BackendType | None = None
+
+    def __init__(
+        self,
+        model: str = "",
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        self.model = model or self.DEFAULT_MODEL
+        self.stream = stream
+        self.temperature = temperature
+        self.context_length_control = context_length_control
+        self.random_endpoint = random_endpoint
+        self.endpoint_id = endpoint_id
+
+        self.backend_settings = settings.get_backend(self.BACKEND_NAME)
+
+        if endpoint_id:
+            self.random_endpoint = False
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+    @abstractmethod
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: bool = True,
+        temperature: float = 0.7,
+        max_tokens: int = 2000,
+        tools: list | None = None,
+        tool_choice: str | None = None,
+    ):
+        pass
+
+
+class BaseAsyncChatClient(ABC):
+    DEFAULT_MODEL: str | None = None
+    BACKEND_NAME: BackendType | None = None
+
+    def __init__(
+        self,
+        model: str = "",
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        self.model = model or self.DEFAULT_MODEL
+        self.stream = stream
+        self.temperature = temperature
+        self.context_length_control = context_length_control
+        self.random_endpoint = random_endpoint
+        self.endpoint_id = endpoint_id
+
+        self.backend_settings = settings.get_backend(self.BACKEND_NAME)
+
+        if endpoint_id:
+            self.random_endpoint = False
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+    @abstractmethod
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: bool = True,
+        temperature: float = 0.7,
+        max_tokens: int = 2000,
+        tools: list | None = None,
+        tool_choice: str | None = None,
+    ):
+        pass
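
The two abstract classes above define the whole backend contract: a concrete client supplies DEFAULT_MODEL and BACKEND_NAME and implements create_completion. A minimal hypothetical sketch of the shape (EchoChatClient is not part of the package, and instantiation still requires a configured backend in settings):

    class EchoChatClient(BaseChatClient):
        DEFAULT_MODEL = "echo-1"  # hypothetical model name
        BACKEND_NAME = BackendType.Local

        def create_completion(self, messages, model=None, stream=True, temperature=0.7,
                              max_tokens=2000, tools=None, tool_choice=None):
            # A real client would call a provider SDK here.
            return {"content": messages[-1]["content"], "usage": {}}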
vectorvein/chat_clients/deepseek_client.py ADDED
@@ -0,0 +1,15 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from ..types.enums import BackendType
+from ..types.defaults import DEEPSEEK_DEFAULT_MODEL
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
+
+
+class DeepSeekChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = DEEPSEEK_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.DeepSeek
+
+
+class AsyncDeepSeekChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = DEEPSEEK_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.DeepSeek
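
This two-attribute pattern is all an OpenAI-compatible backend needs; the shared client classes carry the actual request logic. A hypothetical example of adding another provider (MyBackend is not a real BackendType member):

    class MyBackendChatClient(OpenAICompatibleChatClient):
        DEFAULT_MODEL = "my-model-v1"  # hypothetical
        BACKEND_NAME = BackendType.MyBackend  # would require a new enum member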