vectorvein-0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectorvein-0.1.0/PKG-INFO +16 -0
- vectorvein-0.1.0/README.md +1 -0
- vectorvein-0.1.0/pyproject.toml +29 -0
- vectorvein-0.1.0/src/vectorvein/__init__.py +0 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/__init__.py +110 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/anthropic_client.py +450 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/base_client.py +91 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/deepseek_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/gemini_client.py +317 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/groq_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/local_client.py +14 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/minimax_client.py +315 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/mistral_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/moonshot_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/openai_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/openai_compatible_client.py +291 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/qwen_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/utils.py +635 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/yi_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/chat_clients/zhipuai_client.py +15 -0
- vectorvein-0.1.0/src/vectorvein/settings/__init__.py +71 -0
- vectorvein-0.1.0/src/vectorvein/types/defaults.py +396 -0
- vectorvein-0.1.0/src/vectorvein/types/enums.py +83 -0
- vectorvein-0.1.0/src/vectorvein/types/llm_parameters.py +69 -0
- vectorvein-0.1.0/src/vectorvein/utilities/media_processing.py +70 -0
- vectorvein-0.1.0/tests/__init__.py +0 -0
- vectorvein-0.1.0/tests/cat.png +0 -0
- vectorvein-0.1.0/tests/sample_settings.py +88 -0
- vectorvein-0.1.0/tests/test_create_chat_client.py +194 -0
- vectorvein-0.1.0/tests/test_format_messages.py +41 -0
- vectorvein-0.1.0/tests/test_image_input_chat_client.py +45 -0

vectorvein-0.1.0/PKG-INFO
@@ -0,0 +1,16 @@
+Metadata-Version: 2.1
+Name: vectorvein
+Version: 0.1.0
+Summary: Default template for PDM package
+Author-Email: Anderson <andersonby@163.com>
+License: MIT
+Requires-Python: >=3.8
+Requires-Dist: openai>=1.37.1
+Requires-Dist: tiktoken>=0.7.0
+Requires-Dist: httpx>=0.27.0
+Requires-Dist: anthropic[vertex]>=0.31.2
+Requires-Dist: pydantic>=2.8.2
+Requires-Dist: Pillow>=10.4.0
+Description-Content-Type: text/markdown
+
+# vectorvein

vectorvein-0.1.0/README.md
@@ -0,0 +1 @@
+# vectorvein

vectorvein-0.1.0/pyproject.toml
@@ -0,0 +1,29 @@
+[project]
+name = "vectorvein"
+version = "0.1.0"
+description = "Default template for PDM package"
+authors = [
+    { name = "Anderson", email = "andersonby@163.com" },
+]
+dependencies = [
+    "openai>=1.37.1",
+    "tiktoken>=0.7.0",
+    "httpx>=0.27.0",
+    "anthropic[vertex]>=0.31.2",
+    "pydantic>=2.8.2",
+    "Pillow>=10.4.0",
+]
+requires-python = ">=3.8"
+readme = "README.md"
+
+[project.license]
+text = "MIT"
+
+[build-system]
+requires = [
+    "pdm-backend",
+]
+build-backend = "pdm.backend"
+
+[tool.pdm]
+distribution = true

vectorvein-0.1.0/src/vectorvein/__init__.py
File without changes

vectorvein-0.1.0/src/vectorvein/chat_clients/__init__.py
@@ -0,0 +1,110 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+from .base_client import BaseChatClient, BaseAsyncChatClient
+
+from .yi_client import YiChatClient, AsyncYiChatClient
+from .groq_client import GroqChatClient, AsyncGroqChatClient
+from .qwen_client import QwenChatClient, AsyncQwenChatClient
+from .local_client import LocalChatClient, AsyncLocalChatClient
+from .gemini_client import GeminiChatClient, AsyncGeminiChatClient
+from .openai_client import OpenAIChatClient, AsyncOpenAIChatClient
+from .zhipuai_client import ZhiPuAIChatClient, AsyncZhiPuAIChatClient
+from .minimax_client import MiniMaxChatClient, AsyncMiniMaxChatClient
+from .mistral_client import MistralChatClient, AsyncMistralChatClient
+from .moonshot_client import MoonshotChatClient, AsyncMoonshotChatClient
+from .deepseek_client import DeepSeekChatClient, AsyncDeepSeekChatClient
+
+from ..types import defaults as defs
+from ..types.enums import BackendType, ContextLengthControlType
+from .anthropic_client import AnthropicChatClient, AsyncAnthropicChatClient
+from .utils import format_messages
+
+
+BackendMap = {
+    "sync": {
+        BackendType.Anthropic: AnthropicChatClient,
+        BackendType.DeepSeek: DeepSeekChatClient,
+        BackendType.Gemini: GeminiChatClient,
+        BackendType.Groq: GroqChatClient,
+        BackendType.Local: LocalChatClient,
+        BackendType.MiniMax: MiniMaxChatClient,
+        BackendType.Mistral: MistralChatClient,
+        BackendType.Moonshot: MoonshotChatClient,
+        BackendType.OpenAI: OpenAIChatClient,
+        BackendType.Qwen: QwenChatClient,
+        BackendType.Yi: YiChatClient,
+        BackendType.ZhiPuAI: ZhiPuAIChatClient,
+    },
+    "async": {
+        BackendType.Anthropic: AsyncAnthropicChatClient,
+        BackendType.DeepSeek: AsyncDeepSeekChatClient,
+        BackendType.Gemini: AsyncGeminiChatClient,
+        BackendType.Groq: AsyncGroqChatClient,
+        BackendType.Local: AsyncLocalChatClient,
+        BackendType.MiniMax: AsyncMiniMaxChatClient,
+        BackendType.Mistral: AsyncMistralChatClient,
+        BackendType.Moonshot: AsyncMoonshotChatClient,
+        BackendType.OpenAI: AsyncOpenAIChatClient,
+        BackendType.Qwen: AsyncQwenChatClient,
+        BackendType.Yi: AsyncYiChatClient,
+        BackendType.ZhiPuAI: AsyncZhiPuAIChatClient,
+    },
+}
+
+
+def create_chat_client(
+    backend: BackendType,
+    model: str | None = None,
+    stream: bool = True,
+    temperature: float = 0.7,
+    context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+    **kwargs,
+) -> BaseChatClient:
+    if backend.lower() not in BackendMap["sync"]:
+        raise ValueError(f"Unsupported backend: {backend}")
+    else:
+        backend_key = backend.lower()
+
+    ClientClass = BackendMap["sync"][backend_key]
+    if model is None:
+        model = ClientClass.DEFAULT_MODEL
+    return BackendMap["sync"][backend_key](
+        model=model,
+        stream=stream,
+        temperature=temperature,
+        context_length_control=context_length_control,
+        **kwargs,
+    )
+
+
+def create_async_chat_client(
+    backend: BackendType,
+    model: str | None = None,
+    stream: bool = True,
+    temperature: float = 0.7,
+    context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+    **kwargs,
+) -> BaseAsyncChatClient:
+    if backend.lower() not in BackendMap["async"]:
+        raise ValueError(f"Unsupported backend: {backend}")
+    else:
+        backend_key = backend.lower()
+
+    ClientClass = BackendMap["async"][backend_key]
+    if model is None:
+        model = ClientClass.DEFAULT_MODEL
+    return BackendMap["async"][backend_key](
+        model=model,
+        stream=stream,
+        temperature=temperature,
+        context_length_control=context_length_control,
+        **kwargs,
+    )
+
+
+__all__ = [
+    "create_chat_client",
+    "create_async_chat_client",
+    "format_messages",
+    "BackendType",
+]
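
The hunk above is the package's factory API. As a hedged illustration (not part of the released files), the sketch below shows how create_chat_client might be called once backends and endpoints have been configured through vectorvein.settings (tests/sample_settings.py in the file list suggests that setup); the Anthropic backend is used here because its create_completion implementation appears in the next hunk.

# Hypothetical usage sketch, not part of vectorvein-0.1.0.
# Assumes vectorvein.settings has already been populated with Anthropic
# endpoint/model definitions before the factory is called.
from vectorvein.chat_clients import BackendType, create_chat_client

client = create_chat_client(backend=BackendType.Anthropic, stream=False)
response = client.create_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
)
print(response["content"], response["usage"])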

vectorvein-0.1.0/src/vectorvein/chat_clients/anthropic_client.py
@@ -0,0 +1,450 @@
+# @Author: Bi Ying
+# @Date: 2024-07-26 14:48:55
+import json
+import random
+
+from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex
+from anthropic._types import NotGiven, NOT_GIVEN
+from anthropic.types import (
+    TextBlock,
+    ToolUseBlock,
+    RawMessageDeltaEvent,
+    RawMessageStartEvent,
+    RawContentBlockStartEvent,
+    RawContentBlockDeltaEvent,
+)
+from google.oauth2.credentials import Credentials
+from google.auth.transport.requests import Request
+from google.auth import _helpers
+
+from ..settings import settings
+from .utils import cutoff_messages
+from ..types import defaults as defs
+from .base_client import BaseChatClient, BaseAsyncChatClient
+from ..types.enums import ContextLengthControlType, BackendType
+
+
+def refactor_tool_use_params(tools: list):
+    return [
+        {
+            "name": tool["function"]["name"],
+            "description": tool["function"]["description"],
+            "input_schema": tool["function"]["parameters"],
+        }
+        for tool in tools
+    ]
+
+
+def refactor_tool_calls(tool_calls: list):
+    return [
+        {
+            "index": index,
+            "id": tool["id"],
+            "type": "function",
+            "function": {
+                "name": tool["name"],
+                "arguments": json.dumps(tool["input"], ensure_ascii=False),
+            },
+        }
+        for index, tool in enumerate(tool_calls)
+    ]
+
+
+def format_messages_alternate(messages: list) -> list:
+    # messages: roles must alternate between "user" and "assistant", and not multiple "user" roles in a row
+    # reformat multiple "user" roles in a row into {"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}, {"type": "text", "text": "How are you?"}]}
+    # same for assistant role
+    # if not multiple "user" or "assistant" roles in a row, keep it as is
+
+    formatted_messages = []
+    current_role = None
+    current_content = []
+
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+
+        if role != current_role:
+            if current_content:
+                formatted_messages.append({"role": current_role, "content": current_content})
+                current_content = []
+            current_role = role
+
+        if isinstance(content, str):
+            current_content.append({"type": "text", "text": content})
+        elif isinstance(content, list):
+            current_content.extend(content)
+        else:
+            current_content.append(content)
+
+    if current_content:
+        formatted_messages.append({"role": current_role, "content": current_content})
+
+    return formatted_messages
+
+
+class AnthropicChatClient(BaseChatClient):
+    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Anthropic
+
+    def __init__(
+        self,
+        model: str = defs.ANTHROPIC_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: str | NotGiven = NOT_GIVEN,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        messages = format_messages_alternate(messages)
+
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.endpoint.is_vertex:
+            self.creds = Credentials(
+                token=self.endpoint.credentials.get("token"),
+                refresh_token=self.endpoint.credentials.get("refresh_token"),
+                token_uri=self.endpoint.credentials.get("token_uri"),
+                scopes=None,
+                client_id=self.endpoint.credentials.get("client_id"),
+                client_secret=self.endpoint.credentials.get("client_secret"),
+                quota_project_id=self.endpoint.credentials.get("quota_project_id"),
+                expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
+                rapt_token=self.endpoint.credentials.get("rapt_token"),
+                trust_boundary=self.endpoint.credentials.get("trust_boundary"),
+                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                account=self.endpoint.credentials.get("account", ""),
+            )
+
+            if self.creds.expired and self.creds.refresh_token:
+                self.creds.refresh(Request())
+
+            if self.endpoint.api_base is None:
+                base_url = None
+            else:
+                base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"
+
+            self._client = AnthropicVertex(
+                region=self.endpoint.region,
+                base_url=base_url,
+                project_id=self.endpoint.credentials.get("quota_project_id"),
+                access_token=self.creds.token,
+            )
+        else:
+            self._client = Anthropic(
+                api_key=self.endpoint.api_key,
+                base_url=self.endpoint.api_base,
+            )
+
+        response = self._client.messages.create(
+            model=self.model_setting.id,
+            messages=messages,
+            system=system_prompt,
+            stream=self.stream,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+            tools=refactor_tool_use_params(tools) if tools else tools,
+            tool_choice=tool_choice,
+        )
+
+        if self.stream:
+
+            def generator():
+                result = {"content": ""}
+                for chunk in response:
+                    message = {"content": ""}
+                    if isinstance(chunk, RawMessageStartEvent):
+                        result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
+                        continue
+                    elif isinstance(chunk, RawContentBlockStartEvent):
+                        if chunk.content_block.type == "tool_use":
+                            result["tool_calls"] = message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": chunk.content_block.id,
+                                    "function": {
+                                        "arguments": "",
+                                        "name": chunk.content_block.name,
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        elif chunk.content_block.type == "text":
+                            message["content"] = chunk.content_block.text
+                        yield message
+                    elif isinstance(chunk, RawContentBlockDeltaEvent):
+                        if chunk.delta.type == "text_delta":
+                            message["content"] = chunk.delta.text
+                            result["content"] += chunk.delta.text
+                        elif chunk.delta.type == "input_json_delta":
+                            result["tool_calls"][0]["function"]["arguments"] += chunk.delta.partial_json
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": result["tool_calls"][0]["id"],
+                                    "function": {
+                                        "arguments": chunk.delta.partial_json,
+                                        "name": result["tool_calls"][0]["function"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        yield message
+                    elif isinstance(chunk, RawMessageDeltaEvent):
+                        result["usage"]["completion_tokens"] = chunk.usage.output_tokens
+                        result["usage"]["total_tokens"] = (
+                            result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
+                        )
+                        yield {"usage": result["usage"]}
+
+            return generator()
+        else:
+            result = {
+                "content": "",
+                "usage": {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+                },
+            }
+            tool_calls = []
+            for content_block in response.content:
+                if isinstance(content_block, TextBlock):
+                    result["content"] += content_block.text
+                elif isinstance(content_block, ToolUseBlock):
+                    tool_calls.append(content_block.model_dump())
+
+            if tool_calls:
+                result["tool_calls"] = refactor_tool_calls(tool_calls)
+
+            return result
+
+
+class AsyncAnthropicChatClient(BaseAsyncChatClient):
+    DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+    BACKEND_NAME: BackendType = BackendType.Anthropic
+
+    def __init__(
+        self,
+        model: str = defs.ANTHROPIC_DEFAULT_MODEL,
+        stream: bool = True,
+        temperature: float = 0.7,
+        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
+        random_endpoint: bool = True,
+        endpoint_id: str = "",
+        **kwargs,
+    ):
+        super().__init__(
+            model,
+            stream,
+            temperature,
+            context_length_control,
+            random_endpoint,
+            endpoint_id,
+            **kwargs,
+        )
+
+    async def create_completion(
+        self,
+        messages: list = list,
+        model: str | None = None,
+        stream: bool | None = None,
+        temperature: float | None = None,
+        max_tokens: int = 2000,
+        tools: list | NotGiven = NOT_GIVEN,
+        tool_choice: str | NotGiven = NOT_GIVEN,
+    ):
+        if model is not None:
+            self.model = model
+        if stream is not None:
+            self.stream = stream
+        if temperature is not None:
+            self.temperature = temperature
+
+        self.model_setting = self.backend_settings.models[self.model]
+
+        if messages[0].get("role") == "system":
+            system_prompt = messages[0]["content"]
+            messages = messages[1:]
+        else:
+            system_prompt = ""
+
+        if self.context_length_control == ContextLengthControlType.Latest:
+            messages = cutoff_messages(
+                messages,
+                max_count=self.model_setting.context_length,
+                backend=self.BACKEND_NAME,
+                model=self.model_setting.id,
+            )
+
+        messages = format_messages_alternate(messages)
+
+        if self.random_endpoint:
+            self.random_endpoint = True
+            self.endpoint_id = random.choice(self.backend_settings.models[self.model].endpoints)
+            self.endpoint = settings.get_endpoint(self.endpoint_id)
+
+        if self.endpoint.is_vertex:
+            self.creds = Credentials(
+                token=self.endpoint.credentials.get("token"),
+                refresh_token=self.endpoint.credentials.get("refresh_token"),
+                token_uri=self.endpoint.credentials.get("token_uri"),
+                scopes=None,
+                client_id=self.endpoint.credentials.get("client_id"),
+                client_secret=self.endpoint.credentials.get("client_secret"),
+                quota_project_id=self.endpoint.credentials.get("quota_project_id"),
+                expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
+                rapt_token=self.endpoint.credentials.get("rapt_token"),
+                trust_boundary=self.endpoint.credentials.get("trust_boundary"),
+                universe_domain=self.endpoint.credentials.get("universe_domain"),
+                account=self.endpoint.credentials.get("account", ""),
+            )
+
+            if self.creds.expired and self.creds.refresh_token:
+                self.creds.refresh(Request())
+
+            if self.endpoint.api_base is None:
+                base_url = None
+            else:
+                base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"
+
+            self._client = AsyncAnthropicVertex(
+                region=self.endpoint.region,
+                base_url=base_url,
+                project_id=self.endpoint.credentials.get("quota_project_id"),
+                access_token=self.creds.token,
+            )
+        else:
+            self._client = AsyncAnthropic(
+                api_key=self.endpoint.api_key,
+                base_url=self.endpoint.api_base,
+            )
+        response = await self._client.messages.create(
+            model=self.model_setting.id,
+            messages=messages,
+            system=system_prompt,
+            stream=self.stream,
+            temperature=self.temperature,
+            max_tokens=max_tokens,
+            tools=refactor_tool_use_params(tools) if tools else tools,
+            tool_choice=tool_choice,
+        )
+
+        if self.stream:
+
+            async def generator():
+                result = {"content": ""}
+                async for chunk in response:
+                    message = {"content": ""}
+                    if isinstance(chunk, RawMessageStartEvent):
+                        result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
+                        continue
+                    elif isinstance(chunk, RawContentBlockStartEvent):
+                        if chunk.content_block.type == "tool_use":
+                            result["tool_calls"] = message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": chunk.content_block.id,
+                                    "function": {
+                                        "arguments": "",
+                                        "name": chunk.content_block.name,
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        elif chunk.content_block.type == "text":
+                            message["content"] = chunk.content_block.text
+                        yield message
+                    elif isinstance(chunk, RawContentBlockDeltaEvent):
+                        if chunk.delta.type == "text_delta":
+                            message["content"] = chunk.delta.text
+                            result["content"] += chunk.delta.text
+                        elif chunk.delta.type == "input_json_delta":
+                            result["tool_calls"][0]["function"]["arguments"] += chunk.delta.partial_json
+                            message["tool_calls"] = [
+                                {
+                                    "index": 0,
+                                    "id": result["tool_calls"][0]["id"],
+                                    "function": {
+                                        "arguments": chunk.delta.partial_json,
+                                        "name": result["tool_calls"][0]["function"]["name"],
+                                    },
+                                    "type": "function",
+                                }
+                            ]
+                        yield message
+                    elif isinstance(chunk, RawMessageDeltaEvent):
+                        result["usage"]["completion_tokens"] = chunk.usage.output_tokens
+                        result["usage"]["total_tokens"] = (
+                            result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
+                        )
+                        yield {"usage": result["usage"]}
+
+            return generator()
+        else:
+            result = {
+                "content": "",
+                "usage": {
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+                },
+            }
+            tool_calls = []
+            for content_block in response.content:
+                if isinstance(content_block, TextBlock):
+                    result["content"] += content_block.text
+                elif isinstance(content_block, ToolUseBlock):
+                    tool_calls.append(content_block.model_dump())
+
+            if tool_calls:
+                result["tool_calls"] = refactor_tool_calls(tool_calls)

+            return result