aient-1.0.29-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/models/claude.py
ADDED
@@ -0,0 +1,640 @@
import os
import re
import json
import copy
import tiktoken
import requests

from .base import BaseLLM
from ..plugins import PLUGINS, get_tools_result_async, claude_tools_list
from ..utils.scripts import check_json, safe_get, async_generator_to_sync

class claudeConversation(dict):
    def Conversation(self, index):
        conversation_list = super().__getitem__(index)
        return "\n\n" + "\n\n".join([f"{item['role']}:{item['content']}" for item in conversation_list]) + "\n\nAssistant:"

class claude(BaseLLM):
    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "claude-2.1",
        api_url: str = "https://api.anthropic.com/v1/complete",
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        temperature: float = 0.5,
        top_p: float = 0.7,
        timeout: float = 20,
        use_plugins: bool = True,
        print_log: bool = False,
    ):
        super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
        # self.api_url = api_url
        self.conversation = claudeConversation()

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
    ) -> None:
        """
        Add a message to the conversation
        """

        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        self.conversation[convo_id].append({"role": role, "content": message})

        history_len = len(self.conversation[convo_id])
        history = pass_history
        if pass_history < 2:
            history = 2
        while history_len > history:
            self.conversation[convo_id].pop(1)
            history_len = history_len - 1

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset the conversation
        """
        self.conversation[convo_id] = claudeConversation()
        self.system_prompt = system_prompt or self.system_prompt

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation
        """
        while True:
            if (
                self.get_token_count(convo_id) > self.truncate_limit
                and len(self.conversation[convo_id]) > 1
            ):
                # Don't remove the first message
                self.conversation[convo_id].pop(1)
            else:
                break

    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get token count
        """
        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 5
            for key, value in message.items():
                if value:
                    num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 5  # role is always required and always 1 token
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def ask_stream(
        self,
        prompt: str,
        role: str = "Human",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        **kwargs,
    ):
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        self.add_to_conversation(prompt, role, convo_id=convo_id, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)
        # print(self.conversation[convo_id])

        url = self.api_url
        headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
        }

        json_post = {
            "model": model or self.engine,
            "prompt": self.conversation.Conversation(convo_id) if pass_history else f"\n\nHuman:{prompt}\n\nAssistant:",
            "stream": True,
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens_to_sample": model_max_tokens,
        }

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("Connection error: please check the server status or your network connection.")
            return
        except requests.exceptions.ReadTimeout:
            print("Request timed out: please check your network connection or increase the timeout.")
            return
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "Assistant"
        full_response: str = ""
        for line in response.iter_lines():
            if not line or line.decode("utf-8") == "event: completion" or line.decode("utf-8") == "event: ping" or line.decode("utf-8") == "data: {}":
                continue
            line = line.decode("utf-8")[6:]  # strip the SSE "data: " prefix
            # print(line)
            resp: dict = json.loads(line)
            content = resp.get("completion")
            if content:
                full_response += content
                yield content
        self.add_to_conversation(full_response, response_role, convo_id=convo_id, pass_history=pass_history)
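
# Note: the claude class above targets Anthropic's legacy text-completions
# endpoint (/v1/complete), while claude3 below targets the newer messages
# endpoint (/v1/messages), which adds structured content blocks and tool use.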

class claude3(BaseLLM):
    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "claude-3-5-sonnet-20241022",
        api_url: str = (os.environ.get("CLAUDE_API_URL") or "https://api.anthropic.com/v1/messages"),
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        temperature: float = 0.5,
        timeout: float = 20,
        top_p: float = 0.7,
        use_plugins: bool = True,
        print_log: bool = False,
    ):
        super().__init__(api_key, engine, api_url, system_prompt, timeout=timeout, temperature=temperature, top_p=top_p, use_plugins=use_plugins, print_log=print_log)
        self.conversation: dict[str, list[dict]] = {
            "default": [],
        }

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
        pass_history: int = 9999,
        total_tokens: int = 0,
        tools_id="",
        function_name: str = "",
        function_full_response: str = "",
    ) -> None:
        """
        Add a message to the conversation
        """

        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id)
        if role == "user" or (role == "assistant" and function_full_response == ""):
            if type(message) == list:
                self.conversation[convo_id].append({
                    "role": role,
                    "content": message
                })
            if type(message) == str:
                self.conversation[convo_id].append({
                    "role": role,
                    "content": [{
                        "type": "text",
                        "text": message
                    }]
                })
        elif role == "assistant" and function_full_response:
            print("function_full_response", function_full_response)
            function_dict = {
                "type": "tool_use",
                "id": f"{tools_id}",
                "name": f"{function_name}",
                "input": json.loads(function_full_response)
                # "input": json.dumps(function_full_response, ensure_ascii=False)
            }
            self.conversation[convo_id].append({"role": role, "content": [function_dict]})
            function_dict = {
                "type": "tool_result",
                "tool_use_id": f"{tools_id}",
                "content": f"{message}",
                # "is_error": true
            }
            self.conversation[convo_id].append({"role": "user", "content": [function_dict]})

        # Merge consecutive messages from the same role.
        conversation_len = len(self.conversation[convo_id]) - 1
        message_index = 0
        while message_index < conversation_len:
            if self.conversation[convo_id][message_index]["role"] == self.conversation[convo_id][message_index + 1]["role"]:
                self.conversation[convo_id][message_index]["content"] += self.conversation[convo_id][message_index + 1]["content"]
                self.conversation[convo_id].pop(message_index + 1)
                conversation_len = conversation_len - 1
            else:
                message_index = message_index + 1

        history_len = len(self.conversation[convo_id])
        history = pass_history
        if pass_history < 2:
            history = 2
        while history_len > history:
            mess_body = self.conversation[convo_id].pop(1)
            history_len = history_len - 1
            if mess_body.get("role") == "user":
                mess_body = self.conversation[convo_id].pop(1)
                history_len = history_len - 1
                if safe_get(mess_body, "content", 0, "type") == "tool_use":
                    self.conversation[convo_id].pop(1)
                    history_len = history_len - 1

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset the conversation
        """
        self.conversation[convo_id] = list()
        self.system_prompt = system_prompt or self.system_prompt

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation
        """
        while True:
            if (
                self.get_token_count(convo_id) > self.truncate_limit
                and len(self.conversation[convo_id]) > 1
            ):
                # Don't remove the first message
                self.conversation[convo_id].pop(1)
            else:
                break

    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get token count
        """
        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 5
            for key, value in message.items():
                if value:
                    num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 5  # role is always required and always 1 token
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        tools_id: str = "",
        total_tokens: int = 0,
        function_name: str = "",
        function_full_response: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)
        # print(self.conversation[convo_id])

        url = self.api_url.source_api_url
        now_model = model or self.engine
        headers = {
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
        }

        json_post = {
            "model": now_model,
            "messages": self.conversation[convo_id] if pass_history else [{
                "role": "user",
                "content": prompt
            }],
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
            "stream": True,
        }
        json_post["system"] = system_prompt or self.system_prompt
        plugins = kwargs.get("plugins", PLUGINS)
        if all(value == False for value in plugins.values()) == False and self.use_plugins:
            json_post.update(copy.deepcopy(claude_tools_list["base"]))
            for item in plugins.keys():
                try:
                    if plugins[item]:
                        json_post["tools"].append(claude_tools_list[item])
                except:
                    pass

        if self.print_log:
            # Mask base64 payloads (e.g. images) before logging the request body.
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("Connection error: please check the server status or your network connection.")
            return
        except requests.exceptions.ReadTimeout:
            print("Request timed out: please check your network connection or increase the timeout.")
            return
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "assistant"
        full_response: str = ""
        need_function_call: bool = False
        function_call_name: str = ""
        function_full_response: str = ""
        total_tokens = 0
        tools_id = ""
        for line in response.iter_lines():
            if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
                continue
            # print(line.decode("utf-8"))
            # if "tool_use" in line.decode("utf-8"):
            #     tool_input = json.loads(line.decode("utf-8")["content"][1]["input"])
            # else:
            #     line = line.decode("utf-8")[6:]
            line = line.decode("utf-8")
            line = line.lstrip("data: ")  # strip the SSE "data: " prefix
            # print(line)
            resp: dict = json.loads(line)
            if resp.get("error"):
                print("error:", resp["error"])
                raise BaseException(f"{resp['error']}")

            message = resp.get("message")
            if message:
                usage = message.get("usage")
                input_tokens = usage.get("input_tokens", 0)
                # output_tokens = usage.get("output_tokens", 0)
                output_tokens = 0
                total_tokens = total_tokens + input_tokens + output_tokens

            usage = resp.get("usage")
            if usage:
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)
                total_tokens = total_tokens + input_tokens + output_tokens

            # print("\n\rtotal_tokens", total_tokens)

            tool_use = resp.get("content_block")
            if tool_use and "tool_use" == tool_use['type']:
                # print("tool_use", tool_use)
                tools_id = tool_use["id"]
                need_function_call = True
                if "name" in tool_use:
                    function_call_name = tool_use["name"]
            delta = resp.get("delta")
            # print("delta", delta)
            if not delta:
                continue
            if "text" in delta:
                content = delta["text"]
                full_response += content
                yield content
            if "partial_json" in delta:
                function_call_content = delta["partial_json"]
                function_full_response += function_call_content

        # print("function_full_response", function_full_response)
        # print("function_call_name", function_call_name)
        # print("need_function_call", need_function_call)
        if self.print_log:
            print("\n\rtotal_tokens", total_tokens)
        if need_function_call:
            function_full_response = check_json(function_full_response)
            print("function_full_response", function_full_response)
            function_response = ""
            function_call_max_tokens = int(self.truncate_limit / 2)

            # function_response = yield from get_tools_result(function_call_name, function_full_response, function_call_max_tokens, self.engine, claude3, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language)

            async def run_async():
                nonlocal function_response
                async for chunk in get_tools_result_async(
                    function_call_name, function_full_response, function_call_max_tokens,
                    model or self.engine, claude3, kwargs.get('api_key', self.api_key),
                    self.api_url, use_plugins=False, model=model or self.engine,
                    add_message=self.add_to_conversation, convo_id=convo_id, language=language
                ):
                    if "function_response:" in chunk:
                        function_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk

            # Use the wrapped function: drain the async generator synchronously.
            for chunk in async_generator_to_sync(run_async()):
                yield chunk

            response_role = "assistant"
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt)
        else:
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}
            if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (type(self.conversation[convo_id][-2]["content"]) == list and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
                self.conversation[convo_id].pop(-1)
                self.conversation[convo_id].pop(-1)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        model_max_tokens: int = 4096,
        tools_id: str = "",
        total_tokens: int = 0,
        function_name: str = "",
        function_full_response: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        self.add_to_conversation(prompt, role, convo_id=convo_id, tools_id=tools_id, total_tokens=total_tokens, function_name=function_name, function_full_response=function_full_response, pass_history=pass_history)
        # self.__truncate_conversation(convo_id=convo_id)
        # print(self.conversation[convo_id])

        url = self.api_url.source_api_url
        now_model = model or self.engine
        headers = {
            "content-type": "application/json",
            "x-api-key": f"{kwargs.get('api_key', self.api_key)}",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in now_model else "tools-2024-05-16",
        }

        json_post = {
            "model": now_model,
            "messages": self.conversation[convo_id] if pass_history else [{
                "role": "user",
                "content": prompt
            }],
            "temperature": kwargs.get("temperature", self.temperature),
            "top_p": kwargs.get("top_p", self.top_p),
            "max_tokens": 8192 if "claude-3-5-sonnet" in now_model else model_max_tokens,
            "stream": True,
        }
        json_post["system"] = system_prompt or self.system_prompt
        plugins = kwargs.get("plugins", PLUGINS)
        if all(value == False for value in plugins.values()) == False and self.use_plugins:
            json_post.update(copy.deepcopy(claude_tools_list["base"]))
            for item in plugins.keys():
                try:
                    if plugins[item]:
                        json_post["tools"].append(claude_tools_list[item])
                except:
                    pass

        if self.print_log:
            # Mask base64 payloads (e.g. images) before logging the request body.
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except ConnectionError:
            print("Connection error: please check the server status or your network connection.")
            return
        except requests.exceptions.ReadTimeout:
            print("Request timed out: please check your network connection or increase the timeout.")
            return
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            return

        if response.status_code != 200:
            print(response.text)
            raise BaseException(f"{response.status_code} {response.reason} {response.text}")
        response_role: str = "assistant"
        full_response: str = ""
        need_function_call: bool = False
        function_call_name: str = ""
        function_full_response: str = ""
        total_tokens = 0
        tools_id = ""
        for line in response.iter_lines():
            if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
                continue
            # print(line.decode("utf-8"))
            # if "tool_use" in line.decode("utf-8"):
            #     tool_input = json.loads(line.decode("utf-8")["content"][1]["input"])
            # else:
            #     line = line.decode("utf-8")[6:]
            line = line.decode("utf-8")[5:]  # strip the SSE "data:" prefix
            if line.startswith(" "):
                line = line[1:]
            # print(line)
            resp: dict = json.loads(line)
            if resp.get("error"):
                print("error:", resp["error"])
                raise BaseException(f"{resp['error']}")

            message = resp.get("message")
            if message:
                usage = message.get("usage")
                input_tokens = usage.get("input_tokens", 0)
                # output_tokens = usage.get("output_tokens", 0)
                output_tokens = 0
                total_tokens = total_tokens + input_tokens + output_tokens

            usage = resp.get("usage")
            if usage:
                input_tokens = usage.get("input_tokens", 0)
                output_tokens = usage.get("output_tokens", 0)
                total_tokens = total_tokens + input_tokens + output_tokens
            if self.print_log:
                print("\n\rtotal_tokens", total_tokens)

            tool_use = resp.get("content_block")
            if tool_use and "tool_use" == tool_use['type']:
                # print("tool_use", tool_use)
                tools_id = tool_use["id"]
                need_function_call = True
                if "name" in tool_use:
                    function_call_name = tool_use["name"]
            delta = resp.get("delta")
            # print("delta", delta)
            if not delta:
                continue
            if "text" in delta:
                content = delta["text"]
                full_response += content
                yield content
            if "partial_json" in delta:
                function_call_content = delta["partial_json"]
                function_full_response += function_call_content
        # print("function_full_response", function_full_response)
        # print("function_call_name", function_call_name)
        # print("need_function_call", need_function_call)
        if need_function_call:
            function_full_response = check_json(function_full_response)
            print("function_full_response", function_full_response)
            function_response = ""
            function_call_max_tokens = int(self.truncate_limit / 2)
            async for chunk in get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, self.engine, claude3, kwargs.get('api_key', self.api_key), self.api_url, use_plugins=False, model=model or self.engine, add_message=self.add_to_conversation, convo_id=convo_id, language=language):
                if "function_response:" in chunk:
                    function_response = chunk.replace("function_response:", "")
                else:
                    yield chunk
            response_role = "assistant"
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            async for chunk in self.ask_stream_async(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, model=model or self.engine, tools_id=tools_id, function_full_response=function_full_response, api_key=kwargs.get('api_key', self.api_key), plugins=kwargs.get("plugins", PLUGINS), system_prompt=system_prompt):
                yield chunk
            # yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name, total_tokens=total_tokens, tools_id=tools_id, function_full_response=function_full_response)
        else:
            if self.conversation[convo_id][-1]["role"] == "function" and self.conversation[convo_id][-1]["name"] == "get_search_results":
                mess = self.conversation[convo_id].pop(-1)
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}
            if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 and ("You are a translation engine" in self.conversation[convo_id][-2]["content"] or (type(self.conversation[convo_id][-2]["content"]) == list and "You are a translation engine" in self.conversation[convo_id][-2]["content"][0]["text"])):
                self.conversation[convo_id].pop(-1)
                self.conversation[convo_id].pop(-1)