aient 1.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/models/groq.py
ADDED
@@ -0,0 +1,268 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import requests
|
4
|
+
import tiktoken
|
5
|
+
|
6
|
+
from .base import BaseLLM
|
7
|
+
|
8
|
+
class groq(BaseLLM):
|
9
|
+
def __init__(
    self,
    api_key: str = None,
    engine: str = os.environ.get("GPT_ENGINE") or "llama3-70b-8192",
    api_url: str = "https://api.groq.com/openai/v1/chat/completions",
    system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    temperature: float = 0.5,
    top_p: float = 1,
    timeout: float = 20,
):
    """Configure a Groq chat client.

    All arguments are forwarded to the base-class initializer; afterwards
    the raw ``api_url`` string is stored on the instance again.

    Note that the ``engine`` default is computed once, when the function is
    defined: it reads ``GPT_ENGINE`` from the environment at that moment.
    """
    sampling_options = {
        "timeout": timeout,
        "temperature": temperature,
        "top_p": top_p,
    }
    super().__init__(api_key, engine, api_url, system_prompt, **sampling_options)
    # Re-assigned after the base initializer ran.
    # NOTE(review): presumably the base class stores a transformed value
    # and this restores the plain string - confirm against BaseLLM.
    self.api_url = api_url
|
21
|
+
|
22
|
+
def add_to_conversation(
    self,
    message: str,
    role: str,
    convo_id: str = "default",
    pass_history: int = 9999,
    total_tokens: int = 0,
) -> None:
    """Append a message to a conversation and trim old history.

    Args:
        message: Message text to store under the ``"content"`` key.
        role: Chat role stored under the ``"role"`` key.
        convo_id: Key of the conversation in ``self.conversation``.
        pass_history: Maximum number of messages to keep. Values below 2
            are treated as 2 for trimming; values of 2 or less also reset
            the conversation before the message is appended.
        total_tokens: If non-zero, added to ``self.tokens_usage[convo_id]``.
    """
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id)
    messages = self.conversation[convo_id]
    messages.append({"role": role, "content": message})

    # Drop the oldest messages after index 0 in one slice deletion instead
    # of popping them one at a time (each pop shifts the whole tail).
    keep = max(pass_history, 2)
    overflow = len(messages) - keep
    if overflow > 0:
        del messages[1:1 + overflow]

    if total_tokens:
        # .get() keeps this working when no usage has been recorded for
        # this conversation yet, whether or not the store supplies defaults.
        self.tokens_usage[convo_id] = self.tokens_usage.get(convo_id, 0) + total_tokens
|
47
|
+
|
48
|
+
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
    """Reset a conversation to just its system message.

    Args:
        convo_id: Key of the conversation in ``self.conversation``.
        system_prompt: New system prompt; when falsy the current
            ``self.system_prompt`` is kept.

    The conversation is seeded with the system message rather than left
    empty. The streaming methods assign the system message to index 0 of
    the conversation; with an empty list the first user message would land
    at index 0 and be overwritten by that assignment.
    """
    self.system_prompt = system_prompt or self.system_prompt
    self.conversation[convo_id] = [
        {"role": "system", "content": self.system_prompt},
    ]
|
54
|
+
|
55
|
+
def __truncate_conversation(self, convo_id: str = "default") -> None:
    """Shrink a conversation until it fits within ``self.truncate_limit``.

    While ``get_token_count`` exceeds the limit and more than one message
    remains, the message at index 1 is removed, so the first message is
    never dropped.
    """
    while (
        self.get_token_count(convo_id) > self.truncate_limit
        and len(self.conversation[convo_id]) > 1
    ):
        # Index 0 is always preserved.
        self.conversation[convo_id].pop(1)
|
68
|
+
|
69
|
+
def get_token_count(self, convo_id: str = "default") -> int:
    """Estimate the token count of a conversation.

    Uses the ``cl100k_base`` tiktoken encoding. Each message contributes a
    fixed 5 tokens plus the encoded length of every truthy value in it; a
    message carrying a ``"name"`` key contributes 5 more. A final 5 tokens
    are added once for the whole conversation.
    """
    encoder = tiktoken.get_encoding("cl100k_base")

    total = 0
    for entry in self.conversation[convo_id]:
        # Fixed per-message framing overhead.
        total += 5
        total += sum(len(encoder.encode(text)) for text in entry.values() if text)
        if "name" in entry:
            total += 5
    # Fixed overhead for the reply primer.
    return total + 5
|
87
|
+
|
88
|
+
def ask_stream(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    model_max_tokens: int = 1024,
    system_prompt: str = None,
    **kwargs,
):
    """Send a prompt to the chat-completions endpoint and yield the reply.

    Args:
        prompt: Message text to send.
        role: Role recorded for the prompt.
        convo_id: Conversation key in ``self.conversation``.
        model: Model name; falls back to ``self.engine`` when empty.
        pass_history: History limit handed to ``add_to_conversation``.
            When 0 only the prompt itself is sent.
        model_max_tokens: Value sent as ``max_tokens``.
        system_prompt: Replaces ``self.system_prompt`` when truthy.
        **kwargs: Optional ``GROQ_API_KEY``, ``temperature``, ``top_p`` and
            ``timeout`` overrides.

    Yields:
        Text chunks of the reply as they arrive.

    Raises:
        BaseException: When the server answers with a non-200 status.
            (Kept as ``BaseException`` so existing callers are unaffected.)

    Connection problems are reported with ``print`` and end the generator
    without yielding anything.
    """
    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id)
    self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)

    url = self.api_url
    headers = {
        "Authorization": f"Bearer {kwargs.get('GROQ_API_KEY', self.api_key)}",
        "Content-Type": "application/json",
    }

    # Refresh the system message. Replace it only when index 0 really is a
    # system message; otherwise insert it, so a user prompt sitting at
    # index 0 of a freshly reset conversation is not overwritten.
    messages = self.conversation[convo_id]
    system_message = {"role": "system", "content": self.system_prompt}
    if messages and messages[0].get("role") == "system":
        messages[0] = system_message
    else:
        messages.insert(0, system_message)

    json_post = {
        "messages": messages if pass_history else [{
            "role": "user",
            "content": prompt
        }],
        "model": model or self.engine,
        "temperature": kwargs.get("temperature", self.temperature),
        "max_tokens": model_max_tokens,
        "top_p": kwargs.get("top_p", self.top_p),
        "stop": None,
        "stream": True,
    }

    try:
        response = self.session.post(
            url,
            headers=headers,
            json=json_post,
            timeout=kwargs.get("timeout", self.timeout),
            stream=True,
        )
    except (ConnectionError, requests.exceptions.ConnectionError):
        # The builtin ConnectionError alone does not match the exception
        # raised by requests, so both are listed.
        print("连接错误，请检查服务器状态或网络连接。")
        return
    except requests.exceptions.ReadTimeout as e:
        print(f"请求超时，请检查网络连接或增加超时时间。{e}")
        return
    except Exception as e:
        print(f"发生了未预料的错误: {e}")
        return

    if response.status_code != 200:
        print(response.text)
        raise BaseException(f"{response.status_code} {response.reason} {response.text}")

    response_role: str = "assistant"
    full_response: str = ""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            # Not a server-sent event: treat the line as a complete,
            # non-streamed completion document.
            print(line)
            full_response = json.loads(line)["choices"][0]["message"]["content"]
            yield full_response
            break
        line = line[6:]  # drop the "data: " prefix
        if line == "[DONE]":
            break
        resp: dict = json.loads(line)
        choices = resp.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta")
        if not delta:
            continue
        if "role" in delta:
            response_role = delta["role"]
        content = delta.get("content")
        if content:
            full_response += content
            yield content
    self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
|
180
|
+
|
181
|
+
async def ask_stream_async(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    model_max_tokens: int = 1024,
    system_prompt: str = None,
    **kwargs,
):
    """Asynchronously send a prompt and yield the streamed reply.

    Args:
        prompt: Message text to send.
        role: Role recorded for the prompt.
        convo_id: Conversation key in ``self.conversation``.
        model: Model name; falls back to ``self.engine`` when empty.
        pass_history: History limit handed to ``add_to_conversation``.
            When 0 only the prompt itself is sent.
        model_max_tokens: Value sent as ``max_tokens``.
        system_prompt: Replaces ``self.system_prompt`` when truthy.
        **kwargs: Optional ``api_key``, ``temperature``, ``top_p`` and
            ``timeout`` overrides.

    Yields:
        Text chunks of the reply as they arrive.

    Raises:
        BaseException: When the server answers with a non-200 status. It is
            deliberately not an ``Exception`` subclass, so it passes through
            the ``except Exception`` handler below and reaches the caller.

    Any other error is printed together with its traceback and ends the
    generator; in that case the reply is not added to the conversation.
    """
    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id)
    self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)

    url = self.api_url
    headers = {
        "Authorization": f"Bearer {os.environ.get('GROQ_API_KEY', self.api_key) or kwargs.get('api_key')}",
        "Content-Type": "application/json",
    }

    # Refresh the system message. Replace it only when index 0 really is a
    # system message; otherwise insert it, so a user prompt sitting at
    # index 0 of a freshly reset conversation is not overwritten.
    messages = self.conversation[convo_id]
    system_message = {"role": "system", "content": self.system_prompt}
    if messages and messages[0].get("role") == "system":
        messages[0] = system_message
    else:
        messages.insert(0, system_message)

    json_post = {
        "messages": messages if pass_history else [{
            "role": "user",
            "content": prompt
        }],
        "model": model or self.engine,
        "temperature": kwargs.get("temperature", self.temperature),
        "max_tokens": model_max_tokens,
        "top_p": kwargs.get("top_p", self.top_p),
        "stop": None,
        "stream": True,
    }

    response_role: str = "assistant"
    full_response: str = ""
    try:
        async with self.aclient.stream(
            "post",
            url,
            headers=headers,
            json=json_post,
            timeout=kwargs.get("timeout", self.timeout),
        ) as response:
            if response.status_code != 200:
                await response.aread()
                print(response.text)
                # httpx responses expose ``reason_phrase``; ``reason`` is
                # the requests spelling. Accept either so building the
                # error message cannot itself fail with AttributeError.
                reason = getattr(response, "reason_phrase", None) or getattr(response, "reason", "")
                raise BaseException(f"{response.status_code} {reason} {response.text}")
            async for line in response.aiter_lines():
                if not line:
                    continue
                if not line.startswith("data: "):
                    # Not a server-sent event: treat the line as a
                    # complete, non-streamed completion document.
                    full_response = json.loads(line)["choices"][0]["message"]["content"]
                    yield full_response
                    break
                # Slice the prefix off; str.lstrip("data: ") would strip a
                # *set of characters*, not the literal prefix.
                line = line[6:]
                if line == "[DONE]":
                    break
                resp: dict = json.loads(line)
                choices = resp.get("choices")
                if not choices:
                    continue
                delta = choices[0].get("delta")
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                content = delta.get("content")
                if content:
                    full_response += content
                    yield content
    except Exception as e:
        print(f"发生了未预料的错误: {e}")
        import traceback
        traceback.print_exc()
        return

    self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
|
aient/models/vertex.py
ADDED
@@ -0,0 +1,420 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
import json
|
4
|
+
import requests
|
5
|
+
|
6
|
+
|
7
|
+
from .base import BaseLLM
|
8
|
+
from ..core.utils import BaseAPI
|
9
|
+
|
10
|
+
import copy
|
11
|
+
from ..plugins import PLUGINS, get_tools_result_async, function_call_list
|
12
|
+
from ..utils.scripts import safe_get
|
13
|
+
|
14
|
+
import time
|
15
|
+
import httpx
|
16
|
+
import base64
|
17
|
+
from cryptography.hazmat.primitives import hashes
|
18
|
+
from cryptography.hazmat.primitives.asymmetric import padding
|
19
|
+
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
20
|
+
|
21
|
+
def create_jwt(client_email, private_key):
    """Build a signed RS256 JWT assertion for Google's OAuth token endpoint.

    Args:
        client_email: Used as the ``iss`` claim.
        private_key: PEM-encoded private key as a string (no passphrase).

    Returns:
        The compact ``header.payload.signature`` token as a string. The
        token is valid for one hour from the time of the call.
    """
    def _b64url(raw):
        # URL-safe base64 with the '=' padding removed.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    header_json = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()

    issued_at = int(time.time())
    claims_json = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at
    }).encode()

    # The signature covers "<header>.<payload>".
    signing_input = _b64url(header_json) + b'.' + _b64url(claims_json)
    signer = load_pem_private_key(private_key.encode(), password=None)
    signature = signer.sign(signing_input, padding.PKCS1v15(), hashes.SHA256())

    return (signing_input + b'.' + _b64url(signature)).decode()
|
55
|
+
|
56
|
+
def get_access_token(client_email, private_key):
    """Exchange a self-signed JWT for an OAuth access token.

    Builds the assertion with ``create_jwt`` and posts it to
    ``https://oauth2.googleapis.com/token`` using the JWT-bearer grant.

    Returns:
        The ``access_token`` field of the JSON response.

    Raises:
        Whatever ``raise_for_status`` raises on an error status.

    This call is synchronous and performs a network request.
    """
    assertion = create_jwt(client_email, private_key)
    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion
    }

    with httpx.Client() as http:
        reply = http.post(
            "https://oauth2.googleapis.com/token",
            data=form,
            headers={'Content-Type': "application/x-www-form-urlencoded"}
        )
        reply.raise_for_status()
        return reply.json()["access_token"]
|
70
|
+
|
71
|
+
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#python
|
72
|
+
class vertex(BaseLLM):
|
73
|
+
def __init__(
    self,
    api_key: str = None,
    engine: str = os.environ.get("GPT_ENGINE") or "gemini-1.5-pro-latest",
    api_url: str = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}",
    system_prompt: str = "You are Gemini, a large language model trained by Google. Respond conversationally",
    project_id: str = os.environ.get("VERTEX_PROJECT_ID", None),
    temperature: float = 0.5,
    top_p: float = 0.7,
    timeout: float = 20,
    use_plugins: bool = True,
    print_log: bool = False,
):
    """Configure a Vertex AI Gemini client.

    The ``api_url`` template is filled in with the project id, the engine
    name and the ``streamGenerateContent`` action before being handed to
    the base-class initializer. The conversation store is then replaced by
    one holding an empty ``"default"`` conversation.

    Note: the ``VERTEX_PROJECT_ID`` environment variable, when set, takes
    precedence over the ``project_id`` argument.
    """
    resolved_project = os.environ.get("VERTEX_PROJECT_ID", project_id)
    endpoint = api_url.format(
        PROJECT_ID=resolved_project,
        MODEL_ID=engine,
        stream="streamGenerateContent",
    )
    super().__init__(
        api_key,
        engine,
        endpoint,
        system_prompt=system_prompt,
        timeout=timeout,
        temperature=temperature,
        top_p=top_p,
        use_plugins=use_plugins,
        print_log=print_log,
    )
    # No system message is kept in the history here; the request methods
    # send the system prompt as a separate field of the payload.
    self.conversation: dict[str, list[dict]] = {"default": []}
|
91
|
+
|
92
|
+
def add_to_conversation(
    self,
    message: str,
    role: str,
    convo_id: str = "default",
    pass_history: int = 9999,
    total_tokens: int = 0,
    function_arguments: str = "",
) -> None:
    """Append a turn to a conversation and trim old history.

    Args:
        message: Text, or an already-built list of parts. With
            ``function_arguments`` set, this is the function's result.
        role: Role for a plain turn (ignored for function-call turns).
        convo_id: Conversation key in ``self.conversation``.
        pass_history: Maximum number of entries to keep (at least 2);
            values of 2 or less reset the conversation first.
        total_tokens: If non-zero, added to ``self.tokens_usage[convo_id]``.
        function_arguments: When truthy, the function-call part returned
            by the model (indexed as ``["functionCall"]["name"]``). Two
            entries are stored: the model's call and a ``"function"`` turn
            carrying ``message`` as the result.
    """
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id)
    thread = self.conversation[convo_id]

    if not function_arguments:
        parts = [{"text": message}] if isinstance(message, str) else message
        thread.append({"role": role, "parts": parts})
    else:
        # First replay the model's own function-call turn ...
        thread.append({"role": "model", "parts": [function_arguments]})
        call_name = function_arguments["functionCall"]["name"]
        # ... then record the result as the function's response.
        function_response = {
            "name": call_name,
            "response": {
                "name": call_name,
                "content": {
                    "result": message,
                }
            }
        }
        thread.append({
            "role": "function",
            "parts": [{"functionResponse": function_response}]
        })

    limit = pass_history if pass_history >= 2 else 2
    remaining = len(thread)
    while remaining > limit:
        dropped = thread.pop(1)
        remaining -= 1
        if dropped.get("role") != "user":
            continue
        # A dropped user turn takes the following entry with it ...
        # NOTE(review): nesting of the next check under the user branch is
        # inferred from statement order - confirm against the original file.
        dropped = thread.pop(1)
        remaining -= 1
        if safe_get(dropped, "parts", 0, "functionCall"):
            # ... and, if that entry was a function call, one more entry.
            thread.pop(1)
            remaining -= 1

    if total_tokens:
        self.tokens_usage[convo_id] += total_tokens
|
155
|
+
|
156
|
+
def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
    """Clear a conversation, optionally replacing the system prompt.

    Args:
        convo_id: Key of the conversation in ``self.conversation``.
        system_prompt: New system prompt. When omitted or falsy, the
            prompt already configured on the instance is kept.

    The parameter previously defaulted to the stock Gemini prompt, so a
    call that passed only ``convo_id`` (as ``add_to_conversation`` does)
    replaced a custom system prompt with the stock one. The stock prompt
    is now used only when no prompt is set at all.
    """
    self.system_prompt = (
        system_prompt
        or getattr(self, "system_prompt", None)
        or "You are Gemini, a large language model trained by Google. Respond conversationally"
    )
    self.conversation[convo_id] = list()
|
162
|
+
|
163
|
+
def ask_stream(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    model_max_tokens: int = 4096,
    systemprompt: str = None,
    **kwargs,
):
    """Send a prompt to the model endpoint and yield the streamed reply.

    Args:
        prompt: Message text (or list of parts) to send.
        role: Role recorded for the prompt.
        convo_id: Conversation key in ``self.conversation``.
        model: Model name; falls back to ``self.engine`` when empty.
        pass_history: History limit handed to ``add_to_conversation``.
            When 0 only the prompt itself is sent.
        model_max_tokens: Accepted but not included in the request.
        systemprompt: Replaces ``self.system_prompt`` when truthy.
        **kwargs: Optional ``timeout`` override.

    Yields:
        Text chunks of the reply as they arrive.

    Raises:
        BaseException: When the server answers with a non-200 status.
            (Kept as ``BaseException`` so existing callers are unaffected.)

    Connection problems are reported with ``print`` and end the generator
    without yielding anything. Errors while reading the stream are printed
    and whatever text was received so far is stored as the reply.

    NOTE(review): no ``Authorization`` header is sent here, unlike
    ``ask_stream_async``; also assumes ``self.api_url`` is still a plain
    string supporting ``.format`` - confirm both.
    """
    self.system_prompt = systemprompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)

    headers = {
        "Content-Type": "application/json",
    }

    # Fallback for pass_history == 0. Content entries carry their text
    # under "parts" (the same shape add_to_conversation stores), not under
    # a "content" key.
    prompt_parts = [{"text": prompt}] if isinstance(prompt, str) else prompt
    unblocked_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    json_post = {
        "contents": self.conversation[convo_id] if pass_history else [{
            "role": "user",
            "parts": prompt_parts
        }],
        "systemInstruction": {"parts": [{"text": self.system_prompt}]},
        "safetySettings": [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in unblocked_categories
        ],
    }
    if self.print_log:
        # Mask inline base64 payloads so the log stays readable.
        replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
        print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

    url = self.api_url.format(model=model or self.engine, stream="streamGenerateContent", api_key=self.api_key)

    try:
        response = self.session.post(
            url,
            headers=headers,
            json=json_post,
            timeout=kwargs.get("timeout", self.timeout),
            stream=True,
        )
    except (ConnectionError, requests.exceptions.ConnectionError):
        # The builtin ConnectionError alone does not match the exception
        # raised by requests, so both are listed.
        print("连接错误，请检查服务器状态或网络连接。")
        return
    except requests.exceptions.ReadTimeout as e:
        print(f"请求超时，请检查网络连接或增加超时时间。{e}")
        return
    except Exception as e:
        print(f"发生了未预料的错误: {e}")
        return

    if response.status_code != 200:
        print(response.text)
        raise BaseException(f"{response.status_code} {response.reason} {response.text}")

    response_role: str = "model"
    full_response: str = ""
    try:
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if '"text": "' not in line:
                continue
            # The body is scanned line by line; take what follows
            # `"text": "` and drop the closing quote.
            fragment = line.split('"text": "')[1][:-1]
            try:
                # Decode JSON string escapes (\" \\ \n \uXXXX ...).
                content = json.loads(f'"{fragment}"')
            except ValueError:
                # Not a clean JSON string body: fall back to translating
                # only the \n escape.
                content = "\n".join(fragment.split("\\n"))
            full_response += content
            yield content
    except requests.exceptions.ChunkedEncodingError as e:
        print("Chunked Encoding Error occurred:", e)
    except Exception as e:
        print("An error occurred:", e)

    self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, pass_history=pass_history)
|
254
|
+
|
255
|
+
async def ask_stream_async(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    systemprompt: str = None,
    language: str = "English",
    function_arguments: str = "",
    total_tokens: int = 0,
    **kwargs,
):
    """Asynchronously stream a reply from Vertex AI, running tools if asked.

    Args:
        prompt: Message text (or list of parts); for a follow-up after a
            tool call, the tool's result.
        role: Role recorded for the prompt.
        convo_id: Conversation key in ``self.conversation``.
        model: Model name; falls back to ``self.engine`` when empty.
        pass_history: History limit handed to ``add_to_conversation``.
            When 0 only the prompt itself is sent.
        systemprompt: Replaces ``self.system_prompt`` when truthy.
        language: Forwarded to ``get_tools_result_async``.
        function_arguments: Function-call part from a previous model turn;
            forwarded to ``add_to_conversation``.
        total_tokens: Token count of the previous turn; forwarded to
            ``add_to_conversation``.
        **kwargs: Optional ``plugins``, ``timeout`` and ``api_key``.

    Yields:
        Text chunks of the reply, then (when the model requested a function
        call) the chunks produced by the tool and by the follow-up request.

    Raises:
        BaseException: When the server answers with a non-200 status. It is
            deliberately not an ``Exception`` subclass, so it passes through
            the ``except Exception`` handlers below and reaches the caller.

    Credentials are read from ``VERTEX_CLIENT_EMAIL`` and
    ``VERTEX_PRIVATE_KEY``.
    NOTE(review): ``get_access_token`` is a synchronous network call made
    on every request from inside this coroutine - consider caching the
    token or moving the call off the event loop.
    """
    self.system_prompt = systemprompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history)

    client_email = os.environ.get("VERTEX_CLIENT_EMAIL")
    private_key = os.environ.get("VERTEX_PRIVATE_KEY")
    access_token = get_access_token(client_email, private_key)
    headers = {
        'Authorization': f"Bearer {access_token}",
        "Content-Type": "application/json",
    }

    # Fallback for pass_history == 0. Content entries carry their text
    # under "parts" (the same shape add_to_conversation stores), not under
    # a "content" key.
    prompt_parts = [{"text": prompt}] if isinstance(prompt, str) else prompt
    # No safety settings are sent here, unlike the synchronous ask_stream.
    json_post = {
        "contents": self.conversation[convo_id] if pass_history else [{
            "role": "user",
            "parts": prompt_parts
        }],
        "system_instruction": {"parts": [{"text": self.system_prompt}]},
        "generationConfig": {
            "temperature": self.temperature,
            "max_output_tokens": 8192,
            "top_k": 40,
            "top_p": 0.95
        },
    }

    plugins = kwargs.get("plugins", PLUGINS)
    if self.use_plugins and any(plugins.values()):
        # Declare every enabled plugin that has a known declaration;
        # enabled plugins without one are skipped.
        declarations = [
            function_call_list[name]
            for name, enabled in plugins.items()
            if enabled and name in function_call_list
        ]
        json_post["tools"] = [{"function_declarations": declarations}]
        json_post["tool_config"] = {
            "function_calling_config": {
                "mode": "AUTO",
            },
        }

    if self.print_log:
        # Mask inline base64 payloads so the log stays readable.
        replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
        print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

    # ``model`` is empty by default, so fall back to the configured engine;
    # formatting the URL with "" would address a model with no id.
    url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(PROJECT_ID=os.environ.get("VERTEX_PROJECT_ID"), MODEL_ID=model or self.engine, stream="streamGenerateContent")
    self.api_url = BaseAPI(url)
    url = self.api_url.source_api_url

    response_role: str = "model"
    full_response: str = ""
    # The function-call JSON is rebuilt from raw response lines; the opening
    # brace supplies the object that wraps the "functionCall" key.
    function_full_response: str = "{"
    need_function_call = False
    receiving_function_call = False
    total_tokens = 0
    try:
        async with self.aclient.stream(
            "post",
            url,
            headers=headers,
            json=json_post,
            timeout=kwargs.get("timeout", self.timeout),
        ) as response:
            if response.status_code != 200:
                error_content = await response.aread()
                error_message = error_content.decode('utf-8')
                raise BaseException(f"{response.status_code}: {error_message}")
            try:
                async for line in response.aiter_lines():
                    if not line:
                        continue

                    if '"text": "' in line:
                        # Take what follows `"text": "` and drop the
                        # closing quote.
                        fragment = line.split('"text": "')[1][:-1]
                        try:
                            # Decode JSON string escapes (\" \\ \n \uXXXX).
                            content = json.loads(f'"{fragment}"')
                        except ValueError:
                            # Not a clean JSON string body: fall back to
                            # translating only the \n escape.
                            content = "\n".join(fragment.split("\\n"))
                        full_response += content
                        yield content

                    # A regex tolerates a trailing comma after the number,
                    # which a bare int() on the rest of the line does not.
                    token_match = re.search(r'"totalTokenCount":\s*(\d+)', line)
                    if token_match:
                        total_tokens = int(token_match.group(1))

                    if '"functionCall": {' in line or receiving_function_call:
                        receiving_function_call = True
                        need_function_call = True
                        if ']' in line:
                            # A line containing ']' ends the capture.
                            # NOTE(review): a ']' inside the call's own
                            # arguments would end it early - confirm.
                            receiving_function_call = False
                            continue

                        function_full_response += line

            except requests.exceptions.ChunkedEncodingError as e:
                print("Chunked Encoding Error occurred:", e)
            except Exception as e:
                print("An error occurred:", e)

    except Exception as e:
        print(f"发生了未预料的错误: {e}")
        return

    # A non-200 status has already raised above, so no status check is
    # needed once the stream has been consumed.
    if self.print_log:
        print("\n\rtotal_tokens", total_tokens)
    if need_function_call:
        function_call = json.loads(function_full_response)
        print(json.dumps(function_call, indent=4, ensure_ascii=False))
        function_call_name = function_call["functionCall"]["name"]
        function_full_response = json.dumps(function_call["functionCall"]["args"])
        function_call_max_tokens = 32000
        print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")
        # Pre-set so the follow-up request below still has a value when the
        # tool never emits a "function_response:" chunk.
        function_response = ""
        async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, model or self.engine, vertex, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
            if "function_response:" in chunk:
                function_response = chunk.replace("function_response:", "")
            else:
                yield chunk
        response_role = "model"
        # Feed the tool result back to the model and relay its answer.
        async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, function_arguments=function_call, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS)):
            yield chunk
    else:
        self.add_to_conversation([{"text": full_response}], response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
|