vectorvein 0.1.23.tar.gz → 0.1.25.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {vectorvein-0.1.23 → vectorvein-0.1.25}/PKG-INFO +1 -1
- {vectorvein-0.1.23 → vectorvein-0.1.25}/pyproject.toml +1 -1
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/anthropic_client.py +175 -56
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/base_client.py +92 -15
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/gemini_client.py +84 -15
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/minimax_client.py +82 -13
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/openai_compatible_client.py +136 -36
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/utils.py +45 -17
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/types/defaults.py +57 -1
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/types/llm_parameters.py +24 -3
- {vectorvein-0.1.23 → vectorvein-0.1.25}/README.md +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/__init__.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/__init__.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/baichuan_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/deepseek_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/groq_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/local_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/mistral_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/moonshot_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/openai_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/qwen_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/yi_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/chat_clients/zhipuai_client.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/settings/__init__.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/types/enums.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/utilities/media_processing.py +0 -0
- {vectorvein-0.1.23 → vectorvein-0.1.25}/src/vectorvein/utilities/retry.py +0 -0
--- vectorvein-0.1.23/src/vectorvein/chat_clients/anthropic_client.py
+++ vectorvein-0.1.25/src/vectorvein/chat_clients/anthropic_client.py
@@ -2,11 +2,13 @@
 # @Date: 2024-07-26 14:48:55
 import json
 import random
+from functools import cached_property
+from typing import overload, Generator, AsyncGenerator, Any, Literal, Iterable

 import httpx
 from openai._types import NotGiven as OpenAINotGiven
 from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex
-from anthropic._types import …
+from anthropic._types import NOT_GIVEN
 from anthropic.types import (
     TextBlock,
     ToolUseBlock,
@@ -24,15 +26,25 @@ from ..types import defaults as defs
 from .utils import cutoff_messages, get_message_token_counts
 from .base_client import BaseChatClient, BaseAsyncChatClient
 from ..types.enums import ContextLengthControlType, BackendType
-from ..types.llm_parameters import …
+from ..types.llm_parameters import (
+    Usage,
+    NotGiven,
+    ToolParam,
+    ToolChoice,
+    AnthropicToolParam,
+    AnthropicToolChoice,
+    ChatCompletionMessage,
+    ChatCompletionToolParam,
+    ChatCompletionDeltaMessage,
+)


-def refactor_tool_use_params(tools: list…
+def refactor_tool_use_params(tools: Iterable[ChatCompletionToolParam]) -> list[AnthropicToolParam]:
     return [
         {
             "name": tool["function"]["name"],
-            "description": tool["function"]…
-            "input_schema": tool["function"]…
+            "description": tool["function"].get("description", ""),
+            "input_schema": tool["function"].get("parameters", {}),
         }
         for tool in tools
     ]
@@ -53,6 +65,17 @@ def refactor_tool_calls(tool_calls: list):
     ]


+def refactor_tool_choice(tool_choice: ToolChoice) -> AnthropicToolChoice:
+    if isinstance(tool_choice, str):
+        if tool_choice == "auto":
+            return {"type": "auto"}
+        elif tool_choice == "required":
+            return {"type": "any"}
+    elif isinstance(tool_choice, dict) and "function" in tool_choice:
+        return {"type": "tool", "name": tool_choice["function"]["name"]}
+    return {"type": "auto"}
+
+
 def format_messages_alternate(messages: list) -> list:
     # messages: roles must alternate between "user" and "assistant", and not multiple "user" roles in a row
     # reformat multiple "user" roles in a row into {"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}, {"type": "text", "text": "How are you?"}]}
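The `refactor_tool_choice` helper added above maps OpenAI-style `tool_choice` values ("auto", "required", or a function dict) onto Anthropic's tool-choice objects, falling back to "auto". A standalone sketch of the same mapping for experimentation; the type aliases here are simplified stand-ins for the package's `ToolChoice`/`AnthropicToolChoice`:

```python
from typing import Union

# Simplified stand-ins for vectorvein's ToolChoice / AnthropicToolChoice aliases.
ToolChoice = Union[str, dict]
AnthropicToolChoice = dict


def refactor_tool_choice(tool_choice: ToolChoice) -> AnthropicToolChoice:
    # "auto": the model decides; "required": Anthropic's "any" (must use some tool);
    # a specific function dict: force that one tool by name.
    if isinstance(tool_choice, str):
        if tool_choice == "auto":
            return {"type": "auto"}
        elif tool_choice == "required":
            return {"type": "any"}
    elif isinstance(tool_choice, dict) and "function" in tool_choice:
        return {"type": "tool", "name": tool_choice["function"]["name"]}
    return {"type": "auto"}


assert refactor_tool_choice("required") == {"type": "any"}
assert refactor_tool_choice({"type": "function", "function": {"name": "get_weather"}}) == {
    "type": "tool",
    "name": "get_weather",
}
```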
@@ -87,7 +110,7 @@ def format_messages_alternate(messages: list) -> list:


 class AnthropicChatClient(BaseChatClient):
-    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    DEFAULT_MODEL: str | None = defs.ANTHROPIC_DEFAULT_MODEL
     BACKEND_NAME: BackendType = BackendType.Anthropic

     def __init__(
@@ -112,7 +135,7 @@ class AnthropicChatClient(BaseChatClient):
             **kwargs,
         )

-    @…
+    @cached_property
     def raw_client(self):
         if self.random_endpoint:
             self.random_endpoint = True
@@ -120,6 +143,8 @@ class AnthropicChatClient(BaseChatClient):
             self.endpoint = settings.get_endpoint(self.endpoint_id)

         if self.endpoint.is_vertex:
+            if self.endpoint.credentials is None:
+                raise ValueError("Anthropic Vertex endpoint requires credentials")
             self.creds = Credentials(
                 token=self.endpoint.credentials.get("token"),
                 refresh_token=self.endpoint.credentials.get("refresh_token"),
@@ -131,7 +156,7 @@ class AnthropicChatClient(BaseChatClient):
                 expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
                 rapt_token=self.endpoint.credentials.get("rapt_token"),
                 trust_boundary=self.endpoint.credentials.get("trust_boundary"),
-                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                universe_domain=self.endpoint.credentials.get("universe_domain", "googleapis.com"),
                 account=self.endpoint.credentials.get("account", ""),
             )

@@ -143,10 +168,11 @@ class AnthropicChatClient(BaseChatClient):
             else:
                 base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"

+            region = NOT_GIVEN if self.endpoint.region is None else self.endpoint.region
             return AnthropicVertex(
-                region=…
+                region=region,
                 base_url=base_url,
-                project_id=self.endpoint.credentials.get("quota_project_id"),
+                project_id=self.endpoint.credentials.get("quota_project_id", NOT_GIVEN),
                 access_token=self.creds.token,
                 http_client=self.http_client,
             )
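Both Vertex constructors above now pass `NOT_GIVEN` instead of `None` for absent fields. The distinction matters: `None` is a real value a caller can send, whereas `NOT_GIVEN` tells the SDK to omit the field from the request entirely. A rough stdlib-only sketch of the sentinel pattern (the real sentinel lives in `anthropic._types` / `openai._types`):

```python
class NotGiven:
    """Sentinel distinguishing 'argument omitted' from 'argument is None'."""

    def __bool__(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()


def build_request(region: "str | None | NotGiven" = NOT_GIVEN) -> dict:
    payload = {}
    # Serialize the field only when the caller actually supplied it;
    # an explicit None would still be sent (as null).
    if not isinstance(region, NotGiven):
        payload["region"] = region
    return payload


assert build_request() == {}
assert build_request(region=None) == {"region": None}
assert build_request(region="us-east5") == {"region": "us-east5"}
```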
@@ -157,15 +183,46 @@ class AnthropicChatClient(BaseChatClient):
                 http_client=self.http_client,
             )

+    @overload
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[False] = False,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> ChatCompletionMessage:
+        pass
+
+    @overload
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[True] = True,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> Generator[ChatCompletionDeltaMessage, None, None]:
+        pass
+
     def create_completion(
         self,
-        messages: list…
+        messages: list,
         model: str | None = None,
         stream: bool | None = None,
         temperature: float | None = None,
         max_tokens: int | None = None,
-        tools: …
-        tool_choice: …
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
         **kwargs,
     ):
         if model is not None:
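The two `@overload` stubs ahead of the real `create_completion` tie the return type to the literal value of `stream`: `stream=False` yields a `ChatCompletionMessage`, `stream=True` a generator of `ChatCompletionDeltaMessage`, so callers get precise types without casts. A reduced sketch of the technique (the function and types here are illustrative, not the package's):

```python
from typing import Generator, Literal, overload


@overload
def complete(stream: Literal[False] = False) -> str: ...
@overload
def complete(stream: Literal[True] = True) -> Generator[str, None, None]: ...


def complete(stream: bool = False):
    # Single runtime implementation; the overloads above exist only for the type checker.
    if stream:
        return (chunk for chunk in ("Hel", "lo"))
    return "Hello"


reply = complete(stream=False)   # type checker infers str
chunks = complete(stream=True)   # type checker infers Generator[str, None, None]
assert reply == "".join(chunks)
```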
@@ -182,7 +239,7 @@ class AnthropicChatClient(BaseChatClient):
         self.model_setting = self.backend_settings.models[self.model]

         if messages[0].get("role") == "system":
-            system_prompt = messages[0]["content"]
+            system_prompt: str = messages[0]["content"]
             messages = messages[1:]
         else:
             system_prompt = ""
@@ -197,7 +254,10 @@ class AnthropicChatClient(BaseChatClient):

         messages = format_messages_alternate(messages)

-        tools_params = refactor_tool_use_params(tools) if tools else …
+        tools_params: list[AnthropicToolParam] | NotGiven = refactor_tool_use_params(tools) if tools else NOT_GIVEN
+        tool_choice_param = NOT_GIVEN
+        if tool_choice:
+            tool_choice_param = refactor_tool_choice(tool_choice)

         if max_tokens is None:
             max_output_tokens = self.model_setting.max_output_tokens
@@ -208,24 +268,23 @@ class AnthropicChatClient(BaseChatClient):
         else:
             max_tokens = self.model_setting.context_length - token_counts

-        response = self.raw_client.messages.create(
-            model=self.model_setting.id,
-            messages=messages,
-            system=system_prompt,
-            stream=self.stream,
-            temperature=self.temperature,
-            max_tokens=max_tokens,
-            tools=tools_params,
-            tool_choice=tool_choice,
-            **kwargs,
-        )
-
         if self.stream:
+            stream_response = self.raw_client.messages.create(
+                model=self.model_setting.id,
+                messages=messages,
+                system=system_prompt,
+                stream=True,
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                tools=tools_params,
+                tool_choice=tool_choice_param,
+                **kwargs,
+            )

             def generator():
-                result = {"content": ""}
-                for chunk in response:
-                    message = {"content": ""}
+                result = {"content": "", "usage": {}, "tool_calls": []}
+                for chunk in stream_response:
+                    message = {"content": "", "tool_calls": []}
                     if isinstance(chunk, RawMessageStartEvent):
                         result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
                         continue
@@ -268,10 +327,22 @@ class AnthropicChatClient(BaseChatClient):
                         result["usage"]["total_tokens"] = (
                             result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
                         )
-                        yield ChatCompletionDeltaMessage(usage=result["usage"])
+                        yield ChatCompletionDeltaMessage(usage=Usage(**result["usage"]))

             return generator()
         else:
+            response = self.raw_client.messages.create(
+                model=self.model_setting.id,
+                messages=messages,
+                system=system_prompt,
+                stream=False,
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                tools=tools_params,
+                tool_choice=tool_choice_param,
+                **kwargs,
+            )
+
             result = {
                 "content": "",
                 "usage": {
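With the create call split per branch, the streaming path returns a generator of `ChatCompletionDeltaMessage` (the final chunk now carrying a typed `Usage` object) while the non-streaming path returns a single message. A hedged usage sketch; the constructor arguments and delta attribute names are assumptions about configuration and types not shown in this diff:

```python
# Hypothetical consumer; endpoint/model configuration is assumed to come from
# vectorvein's settings, which this diff does not show.
from vectorvein.chat_clients.anthropic_client import AnthropicChatClient

client = AnthropicChatClient(stream=True)
for delta in client.create_completion(messages=[{"role": "user", "content": "Hi"}]):
    if delta.content:
        print(delta.content, end="", flush=True)
    if delta.usage:  # final chunk: Usage(prompt_tokens=..., completion_tokens=..., total_tokens=...)
        print("\ntotal tokens:", delta.usage.total_tokens)
```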
@@ -294,7 +365,7 @@ class AnthropicChatClient(BaseChatClient):


 class AsyncAnthropicChatClient(BaseAsyncChatClient):
-    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    DEFAULT_MODEL: str | None = defs.ANTHROPIC_DEFAULT_MODEL
     BACKEND_NAME: BackendType = BackendType.Anthropic

     def __init__(
@@ -319,7 +390,7 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
             **kwargs,
         )

-    @…
+    @cached_property
     def raw_client(self):
         if self.random_endpoint:
             self.random_endpoint = True
@@ -327,6 +398,8 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
             self.endpoint = settings.get_endpoint(self.endpoint_id)

         if self.endpoint.is_vertex:
+            if self.endpoint.credentials is None:
+                raise ValueError("Anthropic Vertex endpoint requires credentials")
             self.creds = Credentials(
                 token=self.endpoint.credentials.get("token"),
                 refresh_token=self.endpoint.credentials.get("refresh_token"),
@@ -338,7 +411,7 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
                 expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
                 rapt_token=self.endpoint.credentials.get("rapt_token"),
                 trust_boundary=self.endpoint.credentials.get("trust_boundary"),
-                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                universe_domain=self.endpoint.credentials.get("universe_domain", "googleapis.com"),
                 account=self.endpoint.credentials.get("account", ""),
             )

@@ -350,10 +423,11 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
             else:
                 base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"

+            region = NOT_GIVEN if self.endpoint.region is None else self.endpoint.region
             return AsyncAnthropicVertex(
-                region=…
+                region=region,
                 base_url=base_url,
-                project_id=self.endpoint.credentials.get("quota_project_id"),
+                project_id=self.endpoint.credentials.get("quota_project_id", NOT_GIVEN),
                 access_token=self.creds.token,
                 http_client=self.http_client,
             )
@@ -364,15 +438,46 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
                 http_client=self.http_client,
             )

+    @overload
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[False] = False,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> ChatCompletionMessage:
+        pass
+
+    @overload
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[True] = True,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> AsyncGenerator[ChatCompletionDeltaMessage, Any]:
+        pass
+
     async def create_completion(
         self,
-        messages: list…
+        messages: list,
         model: str | None = None,
         stream: bool | None = None,
         temperature: float | None = None,
         max_tokens: int | None = None,
-        tools: …
-        tool_choice: …
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
         **kwargs,
     ):
         if model is not None:
@@ -404,7 +509,10 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):

         messages = format_messages_alternate(messages)

-        tools_params = refactor_tool_use_params(tools) if tools else …
+        tools_params: list[AnthropicToolParam] | NotGiven = refactor_tool_use_params(tools) if tools else NOT_GIVEN
+        tool_choice_param = NOT_GIVEN
+        if tool_choice:
+            tool_choice_param = refactor_tool_choice(tool_choice)

         if max_tokens is None:
             max_output_tokens = self.model_setting.max_output_tokens
@@ -415,24 +523,23 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
         else:
             max_tokens = self.model_setting.context_length - token_counts

-        response = await self.raw_client.messages.create(
-            model=self.model_setting.id,
-            messages=messages,
-            system=system_prompt,
-            stream=self.stream,
-            temperature=self.temperature,
-            max_tokens=max_tokens,
-            tools=tools_params,
-            tool_choice=tool_choice,
-            **kwargs,
-        )
-
         if self.stream:
+            stream_response = await self.raw_client.messages.create(
+                model=self.model_setting.id,
+                messages=messages,
+                system=system_prompt,
+                stream=True,
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                tools=tools_params,
+                tool_choice=tool_choice_param,
+                **kwargs,
+            )

             async def generator():
-                result = {"content": ""}
-                async for chunk in response:
-                    message = {"content": ""}
+                result = {"content": "", "usage": {}, "tool_calls": []}
+                async for chunk in stream_response:
+                    message = {"content": "", "tool_calls": []}
                     if isinstance(chunk, RawMessageStartEvent):
                         result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
                         continue
@@ -475,10 +582,22 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
                         result["usage"]["total_tokens"] = (
                             result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
                         )
-                        yield ChatCompletionDeltaMessage(usage=result["usage"])
+                        yield ChatCompletionDeltaMessage(usage=Usage(**result["usage"]))

             return generator()
         else:
+            response = await self.raw_client.messages.create(
+                model=self.model_setting.id,
+                messages=messages,
+                system=system_prompt,
+                stream=False,
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                tools=tools_params,
+                tool_choice=tool_choice_param,
+                **kwargs,
+            )
+
             result = {
                 "content": "",
                 "usage": {
--- vectorvein-0.1.23/src/vectorvein/chat_clients/base_client.py
+++ vectorvein-0.1.25/src/vectorvein/chat_clients/base_client.py
@@ -1,22 +1,29 @@
 # @Author: Bi Ying
 # @Date: 2024-07-26 14:48:55
 from abc import ABC, abstractmethod
-from …
+from functools import cached_property
+from typing import Generator, AsyncGenerator, Any, overload, Literal, Iterable

 import httpx
-from openai._types import NotGiven, NOT_GIVEN
 from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
 from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex

 from ..settings import settings
 from ..types import defaults as defs
 from ..types.enums import ContextLengthControlType, BackendType
-from ..types.llm_parameters import …
+from ..types.llm_parameters import (
+    NotGiven,
+    NOT_GIVEN,
+    ToolParam,
+    ToolChoice,
+    ChatCompletionMessage,
+    ChatCompletionDeltaMessage,
+)


 class BaseChatClient(ABC):
     DEFAULT_MODEL: str | None = None
-    BACKEND_NAME: BackendType
+    BACKEND_NAME: BackendType

     def __init__(
         self,
@@ -44,9 +51,41 @@ class BaseChatClient(ABC):
         self.random_endpoint = False
         self.endpoint = settings.get_endpoint(self.endpoint_id)

-    @…
+    @cached_property
     @abstractmethod
-    def raw_client(self) -> OpenAI | AzureOpenAI | Anthropic | AnthropicVertex:
+    def raw_client(self) -> OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | httpx.Client | None:
+        pass
+
+    @overload
+    @abstractmethod
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[False] = False,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> ChatCompletionMessage:
+        pass
+
+    @overload
+    @abstractmethod
+    def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[True] = True,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> Generator[ChatCompletionDeltaMessage, Any, None]:
         pass

     @abstractmethod
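`base_client.py` now stacks `@cached_property` on top of `@abstractmethod` for `raw_client`, so each concrete client constructs its underlying SDK client once and reuses it on later access. Note that `functools.cached_property` does not propagate the abstract flag, so the marker is advisory for subclasses rather than enforced at instantiation. A minimal sketch of the pattern with a placeholder client:

```python
from abc import ABC, abstractmethod
from functools import cached_property


class BaseChatClient(ABC):
    @cached_property
    @abstractmethod
    def raw_client(self):
        """Concrete subclasses build (and cache) the underlying SDK client."""


class DummyChatClient(BaseChatClient):
    @cached_property
    def raw_client(self):
        print("building client once")
        return object()  # stands in for Anthropic(...), OpenAI(...), etc.


client = DummyChatClient()
assert client.raw_client is client.raw_client  # second access hits the instance cache
```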
@@ -57,8 +96,8 @@ class BaseChatClient(ABC):
         stream: bool = False,
         temperature: float = 0.7,
         max_tokens: int | None = None,
-        tools: …
-        tool_choice: …
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
         response_format: dict | None = None,
         **kwargs,
     ) -> ChatCompletionMessage | Generator[ChatCompletionDeltaMessage, Any, None]:
@@ -70,8 +109,9 @@ class BaseChatClient(ABC):
         model: str | None = None,
         temperature: float = 0.7,
         max_tokens: int | None = None,
-        tools: …
-        tool_choice: …
+        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
         **kwargs,
     ) -> Generator[ChatCompletionDeltaMessage, Any, None]:
         return self.create_completion(
@@ -82,13 +122,14 @@ class BaseChatClient(ABC):
             max_tokens=max_tokens,
             tools=tools,
             tool_choice=tool_choice,
+            response_format=response_format,
             **kwargs,
         )


 class BaseAsyncChatClient(ABC):
     DEFAULT_MODEL: str | None = None
-    BACKEND_NAME: BackendType
+    BACKEND_NAME: BackendType

     def __init__(
         self,
@@ -116,9 +157,43 @@ class BaseAsyncChatClient(ABC):
         self.random_endpoint = False
         self.endpoint = settings.get_endpoint(self.endpoint_id)

-    @…
+    @cached_property
     @abstractmethod
-    def raw_client(…
+    def raw_client(
+        self,
+    ) -> AsyncOpenAI | AsyncAzureOpenAI | AsyncAnthropic | AsyncAnthropicVertex | httpx.AsyncClient | None:
+        pass
+
+    @overload
+    @abstractmethod
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[False] = False,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> ChatCompletionMessage:
+        pass
+
+    @overload
+    @abstractmethod
+    async def create_completion(
+        self,
+        messages: list,
+        model: str | None = None,
+        stream: Literal[True] = True,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
+        **kwargs,
+    ) -> AsyncGenerator[ChatCompletionDeltaMessage, None]:
         pass

     @abstractmethod
@@ -130,7 +205,7 @@ class BaseAsyncChatClient(ABC):
         temperature: float = 0.7,
         max_tokens: int | None = None,
         tools: list | NotGiven = NOT_GIVEN,
-        tool_choice: …
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
         response_format: dict | None = None,
         **kwargs,
     ) -> ChatCompletionMessage | AsyncGenerator[ChatCompletionDeltaMessage, None]:
@@ -143,7 +218,8 @@ class BaseAsyncChatClient(ABC):
         temperature: float = 0.7,
         max_tokens: int | None = None,
         tools: list | NotGiven = NOT_GIVEN,
-        tool_choice: …
+        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+        response_format: dict | None = None,
         **kwargs,
     ) -> AsyncGenerator[ChatCompletionDeltaMessage, None]:
         return await self.create_completion(
@@ -154,5 +230,6 @@ class BaseAsyncChatClient(ABC):
             max_tokens=max_tokens,
             tools=tools,
             tool_choice=tool_choice,
+            response_format=response_format,
             **kwargs,
         )