aient 1.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/models/chatgpt.py
ADDED
@@ -0,0 +1,856 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
import json
|
4
|
+
import copy
|
5
|
+
import httpx
|
6
|
+
import asyncio
|
7
|
+
import requests
|
8
|
+
from typing import Set
|
9
|
+
from typing import Union, Optional, Callable
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
|
13
|
+
from .base import BaseLLM
|
14
|
+
from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
|
15
|
+
from ..utils.scripts import check_json, safe_get, async_generator_to_sync
|
16
|
+
from ..core.request import prepare_request_payload
|
17
|
+
from ..core.response import fetch_response_stream
|
18
|
+
|
19
|
+
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
    """
    Select instance-attribute names of ``obj``.

    :param obj: object whose ``__dict__`` keys are inspected.
    :param keys: names to keep. With no keys, every attribute name is
        returned. If the first key is ``"not"``, the remaining keys are
        excluded instead of selected.
    :return: set of attribute names.
    :raises ValueError: if an explicitly selected key is not an attribute of ``obj``.
    """
    available = obj.__dict__.keys()
    if not keys:
        return set(available)

    first, *rest = keys
    if first == "not":
        # Exclusion mode: everything except the listed names (unknown names are ignored).
        return {name for name in available if name not in rest}

    unknown = set(keys) - available
    if unknown:
        raise ValueError(
            f"Invalid keys: {unknown}",
        )
    return {name for name in keys if name in available}
|
39
|
+
|
40
|
+
class chatgpt(BaseLLM):
    """
    Official ChatGPT API

    Keeps one message history per ``convo_id`` in ``self.conversation``,
    builds OpenAI-style chat-completion requests through
    ``prepare_request_payload`` and streams answers back through
    ``fetch_response_stream``. When the model requests a tool call, the tool
    is executed through ``get_tools_result_async`` and its result is sent
    back to the model in a follow-up request.
    """

    # Markers identifying one-shot translation prompts whose exchange is
    # removed from history after the answer (see _process_stream_response).
    _TRANSLATION_MARKERS = ("You are a translation engine", "你是一位精通简体中文的专业翻译")

    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-4o",
        api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        proxy: str = None,
        timeout: float = 600,
        max_tokens: int = None,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        truncate_limit: int = None,
        use_plugins: bool = True,
        print_log: bool = False,
        tools: Optional[Union[list, str, Callable]] = None,
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        :param tools: tool name(s) or callable(s) to enable on top of the
            plugin configuration; a single item or a list is accepted.
        :raises ValueError: if a requested tool is not a known plugin.
        :raises Exception: if the system prompt already exceeds ``max_tokens``.
        """
        super().__init__(api_key, engine, api_url, system_prompt, proxy, timeout, max_tokens, temperature, top_p, presence_penalty, frequency_penalty, reply_count, truncate_limit, use_plugins=use_plugins, print_log=print_log)
        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": self.system_prompt,
                },
            ],
        }
        # Per-tool call counts while answering the current question; cleared
        # once the model answers without requesting a tool.
        self.function_calls_counter = {}
        # Maximum number of calls to the same tool while answering one question.
        self.function_call_max_loop = 3

        # Register and enable the tools passed in.
        self._register_tools(tools)

        # NOTE(review): tokens_usage / max_tokens are presumably initialised by BaseLLM — confirm.
        if self.max_tokens is not None and self.tokens_usage["default"] > self.max_tokens:
            raise Exception("System prompt is too long")

    def _register_tools(self, tools):
        """
        Load this instance's plugin configuration and enable ``tools``.

        :param tools: list of tool names/callables, a single name/callable,
            or a falsy value for none.
        :raises ValueError: if a tool is not present in the plugin configuration.
        """
        plugins, function_calls, _ = update_tools_config()
        # Copy so that enabling tools on this instance never mutates the
        # configuration objects returned by update_tools_config().
        self.plugins = copy.deepcopy(plugins)
        self.function_call_list = copy.deepcopy(function_calls)

        if isinstance(tools, list):
            self.tools = list(tools)
        else:
            self.tools = [tools] if tools else []

        for tool in self.tools:
            tool_name = tool.__name__ if callable(tool) else str(tool)
            if tool_name not in self.plugins:
                raise ValueError(f"Tool {tool_name} not found in plugins")
            self.plugins[tool_name] = True

    def add_to_conversation(
        self,
        message: Union[str, list],
        role: str,
        convo_id: str = "default",
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        pass_history: int = 9999,
        function_call_id: str = "",
    ) -> None:
        """
        Add a message to the conversation

        Without ``function_name`` the message is appended as
        ``{"role": role, "content": message}``. With ``function_name`` an
        assistant ``tool_calls`` message is appended first, followed by the
        tool result (``role`` plus ``tool_call_id``). Adjacent same-role
        messages are then merged, the history is trimmed to ``pass_history``
        messages, and ``total_tokens`` is added to the conversation's usage.
        An empty message is reported on stdout and not appended.
        """
        if convo_id not in self.conversation:
            # Keep the configured system prompt; reset()'s default would
            # otherwise replace the instance-wide prompt with the ChatGPT one.
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)

        if not message:
            print('\033[31m')
            print("error: add_to_conversation message is None or empty")
            print("role", role, "function_name", function_name, "message", message)
            print('\033[0m')
        elif function_name == "":
            self.conversation[convo_id].append({"role": role, "content": message})
        else:
            self.conversation[convo_id].append({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": function_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": function_arguments,
                        },
                    }
                ],
            })
            self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})

        self._merge_adjacent_same_role(convo_id)
        self._trim_history(convo_id, pass_history)

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    @staticmethod
    def _as_content_parts(content):
        """Return ``content`` as a new list of content parts (str -> text part, list -> copy, other -> [content])."""
        if isinstance(content, str):
            return [{"type": "text", "text": content}]
        if isinstance(content, list):
            return list(content)
        return [content]

    def _merge_adjacent_same_role(self, convo_id):
        """
        Merge consecutive messages of ``convo_id`` that share a role and both carry content.

        Two strings are concatenated directly; any other combination is
        converted to content-part lists and concatenated. Same-role
        neighbours where either side has no content are left untouched.
        """
        messages = self.conversation[convo_id]
        index = 0
        while index < len(messages) - 1:
            current, following = messages[index], messages[index + 1]
            mergeable = (
                current["role"] == following["role"]
                and current.get("content")
                and following.get("content")
            )
            if not mergeable:
                # Also advances past same-role pairs lacking content, which
                # previously never advanced and looped forever.
                index += 1
                continue
            if isinstance(current["content"], str) and isinstance(following["content"], str):
                current["content"] += following["content"]
            else:
                # Build a fresh list so a caller-supplied content list is not mutated.
                current["content"] = self._as_content_parts(current["content"]) + self._as_content_parts(following["content"])
            messages.pop(index + 1)

    def _trim_history(self, convo_id, pass_history):
        """
        Drop the oldest messages (never index 0, the system prompt) until at most ``max(pass_history, 2)`` remain.

        A removed user message takes the next message with it, and a removed
        assistant tool call takes the following ``tool`` message with it, so
        no orphaned tool result is left at the front of the history.
        """
        messages = self.conversation[convo_id]
        limit = max(pass_history, 2)
        while len(messages) > limit:
            removed = messages.pop(1)
            if removed.get("role") == "user" and len(messages) > 1:
                removed = messages.pop(1)
            if removed.get("tool_calls") and len(messages) > 1 and messages[1].get("role") == "tool":
                messages.pop(1)

    def truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation

        Removes the oldest non-system messages while the recorded token usage
        of ``convo_id`` exceeds ``truncate_limit``.
        NOTE(review): tokens_usage is not decreased when messages are removed,
        so once the limit is exceeded everything except the system prompt is dropped.
        """
        messages = self.conversation[convo_id]
        while self.tokens_usage[convo_id] > self.truncate_limit and len(messages) > 1:
            # Index 0 is the system prompt and is always kept.
            print("Truncate message:", messages.pop(1))

    def extract_values(self, obj):
        """Yield every leaf value of a nested structure of dicts and lists, depth first."""
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            yield obj
            return
        for child in children:
            yield from self.extract_values(child)

    async def get_post_body(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        **kwargs,
    ):
        """
        Build the HTTP request for the next completion of ``convo_id``.

        Refreshes the conversation's system message, assembles a streaming
        OpenAI-style request (the whole history, or only the system prompt plus
        ``prompt`` when ``pass_history`` is falsy), attaches every enabled tool
        schema, and delegates to ``prepare_request_payload``.

        Recognised ``kwargs``: ``api_url``, ``api_key``, ``max_tokens``,
        ``temperature``, ``plugins``.

        :return: ``(url, headers, json_post_body, engine_type)``.
        """
        self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}
        engine = model or self.engine

        # Provider description consumed by the core request builder.
        provider = {
            "provider": "openai",
            "base_url": kwargs.get('api_url', self.api_url.chat_url),
            "api": kwargs.get('api_key', self.api_key),
            "model": [engine],
            "tools": bool(self.use_plugins),
            "image": True
        }

        if pass_history:
            messages = copy.deepcopy(self.conversation[convo_id])
        else:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": role, "content": prompt},
            ]

        request_data = {
            "model": engine,
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "stream": True,
            "stream_options": {
                "include_usage": True
            },
            "temperature": kwargs.get("temperature", self.temperature)
        }

        # Attach the schema of every enabled plugin that has one.
        plugins_status = kwargs.get("plugins", self.plugins)
        if self.use_plugins and any(plugins_status.values()):
            tools_request_body = [
                {"type": "function", "function": self.function_call_list[name]}
                for name, enabled in plugins_status.items()
                if enabled and name in self.function_call_list
            ]
            if tools_request_body:
                request_data["tools"] = tools_request_body
                request_data["tool_choice"] = "auto"

        url, headers, json_post_body, engine_type = await prepare_request_payload(provider, request_data)
        return url, headers, json_post_body, engine_type

    @classmethod
    def _is_translation_prompt(cls, message):
        """Return True if ``message`` (str content, or list content whose first part is text) contains a translation-prompt marker."""
        content = message.get("content") or ""
        first_text = safe_get(message, "content", 0, "text", default="") or ""
        return any(marker in content or marker in first_text for marker in cls._TRANSLATION_MARKERS)

    async def _process_stream_response(
        self,
        response_gen,
        convo_id="default",
        function_name="",
        total_tokens=0,
        function_arguments="",
        function_call_id="",
        model="",
        language="English",
        system_prompt=None,
        pass_history=9999,
        is_async=False,
        **kwargs
    ):
        """
        Shared streaming logic for ask_stream and ask_stream_async.

        Yields text chunks from ``response_gen``. If the model requested a
        tool call, runs the tool (at most ``function_call_max_loop`` times per
        tool per question), yields any non-result chunks the tool emits, and
        asks the model again with the tool result. Otherwise stores the final
        answer in the conversation.

        :param response_gen: response generator (sync or async)
        :param is_async: whether ``response_gen`` is an async generator
        """
        response_role = None
        full_response = ""
        function_full_response = ""
        function_call_name = ""
        need_function_call = False
        # Out-of-band end-of-stream marker, so model text "DONE" is not mistaken for it.
        stream_done = object()

        def process_line(line):
            """Handle one stream item; return text to yield, ``stream_done``, or None."""
            nonlocal response_role, full_response, function_full_response, function_call_name, need_function_call, total_tokens, function_call_id

            # Skip empty lines and SSE comment lines.
            if not line or (isinstance(line, str) and line.startswith(':')):
                return None

            if isinstance(line, str) and line.startswith('data:'):
                # Remove the "data:" prefix; str.lstrip("data: ") would strip a character set instead.
                line = line[len('data:'):].strip()
                if line == "[DONE]":
                    return stream_done
            elif isinstance(line, (dict, list)):
                if isinstance(line, dict) and safe_get(line, "choices", 0, "message", "content"):
                    full_response = line["choices"][0]["message"]["content"]
                    return full_response
                return str(line)
            else:
                # Not an SSE data line: treat it as a complete (non-streamed) response.
                try:
                    if isinstance(line, str):
                        line = json.loads(line)
                    if safe_get(line, "choices", 0, "message", "content"):
                        full_response = line["choices"][0]["message"]["content"]
                        return full_response
                    return str(line)
                except Exception:
                    print("json.loads error:", repr(line))
                    return None

            resp = json.loads(line) if isinstance(line, str) else line
            if "error" in resp:
                raise Exception(f"{resp}")

            total_tokens = total_tokens or safe_get(resp, "usage", "total_tokens", default=0)
            delta = safe_get(resp, "choices", 0, "delta")
            if not delta:
                return None

            response_role = response_role or safe_get(delta, "role")
            content = safe_get(delta, "content")
            if content:
                need_function_call = False
                full_response += content
                return content

            if safe_get(delta, "tool_calls"):
                need_function_call = True
                function_call_name = function_call_name or safe_get(delta, "tool_calls", 0, "function", "name")
                function_full_response += safe_get(delta, "tool_calls", 0, "function", "arguments", default="") or ""
                function_call_id = function_call_id or safe_get(delta, "tool_calls", 0, "id")
            return None

        async def iterate_async():
            async for line in response_gen:
                line = line.strip() if isinstance(line, str) else line
                result = process_line(line)
                if result is stream_done:
                    break
                if result:
                    yield result

        def iterate_sync():
            for line in response_gen:
                line = line.decode("utf-8") if hasattr(line, "decode") else line
                result = process_line(line)
                if result is stream_done:
                    break
                if result:
                    yield result

        if is_async:
            async for chunk in iterate_async():
                yield chunk
        else:
            for chunk in iterate_sync():
                yield chunk

        if self.print_log:
            print("\n\rtotal_tokens", total_tokens)

        if response_role is None:
            response_role = "assistant"

        if need_function_call:
            if self.print_log:
                print("function_full_response", function_full_response)
            function_full_response = check_json(function_full_response)
            function_response = ""

            self.function_calls_counter[function_call_name] = self.function_calls_counter.get(function_call_name, 0) + 1

            has_args = safe_get(self.function_call_list, function_call_name, "parameters", "required", default=False)
            within_budget = self.function_calls_counter[function_call_name] <= self.function_call_max_loop
            if within_budget and (function_full_response != "{}" or not has_args):
                function_call_max_tokens = self.truncate_limit - 1000
                if function_call_max_tokens <= 0:
                    function_call_max_tokens = int(self.truncate_limit / 2)
                if self.print_log:
                    print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")

                async def run_tool():
                    """Run the tool; capture its "function_response:" chunk, yield everything else."""
                    nonlocal function_response
                    async for chunk in get_tools_result_async(
                        function_call_name, function_full_response, function_call_max_tokens,
                        model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
                        self.api_url, use_plugins=False, model=model or self.engine,
                        add_message=self.add_to_conversation, convo_id=convo_id, language=language
                    ):
                        if "function_response:" in chunk:
                            function_response = chunk.replace("function_response:", "")
                        else:
                            yield chunk

                if is_async:
                    async for chunk in run_tool():
                        yield chunk
                else:
                    for chunk in async_generator_to_sync(run_tool()):
                        yield chunk
            else:
                function_response = "无法找到相关信息,停止使用 tools"

            response_role = "tool"

            # Ask the model again with the tool result.
            follow_up = {
                "convo_id": convo_id,
                "function_name": function_call_name,
                "total_tokens": total_tokens,
                "model": model or self.engine,
                "function_arguments": function_full_response,
                "function_call_id": function_call_id,
                "api_key": kwargs.get('api_key', self.api_key),
                "plugins": kwargs.get("plugins", self.plugins),
                "system_prompt": system_prompt,
                "language": language,
            }
            # Keep the follow-up on the same endpoint / sampling settings as the original request.
            for key in ("api_url", "max_tokens", "temperature"):
                if key in kwargs:
                    follow_up[key] = kwargs[key]

            if is_async:
                async for chunk in self.ask_stream_async(function_response, response_role, **follow_up):
                    yield chunk
            else:
                for chunk in self.ask_stream(function_response, response_role, **follow_up):
                    yield chunk
        else:
            # Store the final answer in the conversation history.
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}

            # Translation requests are one-shot: drop the exchange from history.
            history = self.conversation[convo_id]
            if pass_history <= 2 and len(history) >= 2 and self._is_translation_prompt(history[-2]):
                history.pop(-1)
                history.pop(-1)

    def ask_stream(
        self,
        prompt: list,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        function_call_id: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """
        Ask a question (synchronous streaming response)

        Updates the conversation and builds the request eagerly, then returns a
        synchronous generator of text chunks driven by the async pipeline.
        NOTE(review): streaming errors are raised lazily while the returned
        generator is iterated, outside the try block below, so the retry loop
        and the ConnectionError/ReadTimeout handlers do not cover them.
        """
        self.system_prompt = system_prompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, function_call_id=function_call_id, pass_history=pass_history)

        async def build_request():
            return await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)

        try:
            url, headers, json_post, engine_type = asyncio.run(build_request())
        except RuntimeError:
            # If an event loop is already active, fall back to it.
            # NOTE(review): run_until_complete also raises if that loop is currently running.
            loop = asyncio.get_event_loop()
            url, headers, json_post, engine_type = loop.run_until_complete(build_request())

        self.truncate_conversation(convo_id=convo_id)

        if self.print_log:
            print("api_url", kwargs.get('api_url', self.api_url.chat_url), url)
            print("api_key", kwargs.get('api_key', self.api_key))

        for attempt in range(3):
            if self.print_log:
                # Mask inline base64 payloads so the logged request stays readable.
                replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
                print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

            try:
                async def stream_chunks():
                    response_stream = fetch_response_stream(
                        self.aclient,
                        url,
                        headers,
                        json_post,
                        engine_type,
                        model or self.engine,
                    )
                    async for chunk in self._process_stream_response(
                        response_stream,
                        convo_id=convo_id,
                        function_name=function_name,
                        total_tokens=total_tokens,
                        function_arguments=function_arguments,
                        function_call_id=function_call_id,
                        model=model,
                        language=language,
                        system_prompt=system_prompt,
                        pass_history=pass_history,
                        is_async=True,
                        **kwargs
                    ):
                        yield chunk

                # Expose the async pipeline as a synchronous generator.
                return async_generator_to_sync(stream_chunks())
            except ConnectionError:
                print("连接错误,请检查服务器状态或网络连接。")
                # Return an empty iterable so callers can still iterate the result.
                return iter(())
            except requests.exceptions.ReadTimeout:
                print("请求超时,请检查网络连接或增加超时时间。")
                return iter(())
            except Exception as e:
                print(f"发生了未预料的错误:{e}")
                if "Invalid URL" in str(e):
                    message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + str(e)
                    raise Exception(f"{message}") from e
                # Last attempt failed: propagate.
                if attempt == 2:
                    raise Exception(f"{e}") from e

    async def ask_stream_async(
        self,
        prompt: list,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        function_call_id: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """
        Ask a question (asynchronous streaming response)

        Async generator of text chunks. A falsy ``system_prompt`` keeps the
        currently configured prompt, matching ask_stream. Failed requests are
        retried up to three times, but only while nothing has been yielded
        yet — a retry after partial output would duplicate it.
        """
        self.system_prompt = system_prompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)

        url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
        self.truncate_conversation(convo_id=convo_id)

        if self.print_log:
            print("api_url", kwargs.get('api_url', self.api_url.chat_url) == url)
            print("api_url", kwargs.get('api_url', self.api_url.chat_url))
            print("api_url", url)
            print("api_key", kwargs.get('api_key', self.api_key))

        for attempt in range(3):
            if self.print_log:
                # Mask inline base64 payloads so the logged request stays readable.
                replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
                print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

            yielded_any = False
            try:
                response_stream = fetch_response_stream(
                    self.aclient,
                    url,
                    headers,
                    json_post,
                    engine_type,
                    model or self.engine,
                )
                async for processed_chunk in self._process_stream_response(
                    response_stream,
                    convo_id=convo_id,
                    function_name=function_name,
                    total_tokens=total_tokens,
                    function_arguments=function_arguments,
                    function_call_id=function_call_id,
                    model=model,
                    language=language,
                    system_prompt=system_prompt,
                    pass_history=pass_history,
                    is_async=True,
                    **kwargs
                ):
                    yielded_any = True
                    yield processed_chunk
                # Success: stop retrying.
                break
            except Exception as e:
                print(f"发生了未预料的错误:{e}")
                import traceback
                traceback.print_exc()
                if "Invalid URL" in str(e):
                    message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + str(e)
                    raise Exception(f"{message}") from e
                # Do not retry once output was streamed, and give up after the last attempt.
                if yielded_any or attempt == 2:
                    raise Exception(f"{e}") from e

    async def ask_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        **kwargs,
    ) -> str:
        """
        Non-streaming ask: collect every chunk of ask_stream_async into one string.
        """
        response = self.ask_stream_async(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            pass_history=pass_history,
            model=model or self.engine,
            **kwargs,
        )
        full_response: str = "".join([chunk async for chunk in response])
        return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
        """
        Rollback the conversation

        Removes up to ``n`` of the most recent messages; the system prompt at
        index 0 is never removed (get_post_body overwrites index 0).
        """
        messages = self.conversation[convo_id]
        for _ in range(n):
            if len(messages) <= 1:
                break
            messages.pop()

    def reset(self, convo_id: str = "default", system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally") -> None:
        """
        Reset the conversation

        Replaces the history of ``convo_id`` with a single system message and
        zeroes its token usage. A falsy ``system_prompt`` keeps the current
        prompt; a truthy one becomes the instance-wide ``self.system_prompt``.
        """
        self.system_prompt = system_prompt or self.system_prompt
        self.conversation[convo_id] = [
            {"role": "system", "content": self.system_prompt},
        ]
        self.tokens_usage[convo_id] = 0

    def save(self, file: str, *keys: str) -> None:
        """
        Save the Chatbot configuration to a JSON file

        :param keys: attribute names to save (all when empty; a leading
            ``"not"`` excludes the names that follow).
        """
        data = {
            key: self.__dict__[key]
            for key in get_filtered_keys_from_object(self, *keys)
        }
        # saves session.proxies dict as session
        # leave this here for compatibility
        if "proxy" in data:
            data["session"] = data["proxy"]
        # The async HTTP client is not JSON-serialisable.
        data.pop("aclient", None)
        # Serialise first so a failure does not leave a truncated file behind.
        serialized = json.dumps(data, indent=2)
        with open(file, "w", encoding="utf-8") as f:
            f.write(serialized)

    def load(self, file: Path, *keys_: str) -> None:
        """
        Load the Chatbot configuration from a JSON file

        Restores the selected attributes (see ``save`` for ``keys_``); keys
        missing from the file are skipped. When a session/proxy value is
        stored, new HTTP clients are created with it.
        NOTE(review): the previous clients are not closed.
        """
        with open(file, encoding="utf-8") as f:
            loaded_config = json.load(f)
        keys = get_filtered_keys_from_object(self, *keys_)

        stored_session = loaded_config.get("session")
        stored_proxy = loaded_config.get("proxy")
        if ("session" in keys and stored_session) or ("proxy" in keys and stored_proxy):
            # Prefer the legacy "session" entry when present, as before.
            self.proxy = stored_session if "session" in loaded_config else stored_proxy
            self.session = httpx.Client(
                follow_redirects=True,
                proxies=self.proxy,
                timeout=self.timeout,
                cookies=self.session.cookies,
                headers=self.session.headers,
            )
            self.aclient = httpx.AsyncClient(
                follow_redirects=True,
                proxies=self.proxy,
                timeout=self.timeout,
                cookies=self.session.cookies,
                headers=self.session.headers,
            )
        keys.discard("session")
        keys.discard("aclient")
        self.__dict__.update({key: loaded_config[key] for key in keys if key in loaded_config})

    @staticmethod
    def _flatten_text_messages(json_post):
        """Replace list-form message contents with the text of their first part (keeping only role and content)."""
        for index, mess in enumerate(json_post["messages"]):
            content = mess.get("content")
            if isinstance(content, list) and content and "text" in content[0]:
                json_post["messages"][index] = {
                    "role": mess["role"],
                    "content": content[0]["text"]
                }

    @staticmethod
    def _drop_tools(json_post):
        """Remove tool definitions from the request body, if present."""
        json_post.pop("tools", None)
        json_post.pop("tool_choice", None)

    def _handle_response_error_common(self, response_text, json_post):
        """
        Shared error handling for the sync and async paths.

        Inspects the error body and adjusts ``json_post`` in place for a retry.

        :return: ``(json_post, should_retry, error_msg)``; ``error_msg`` is
            only set when the error cannot be recovered from.
        """
        try:
            # Content moderation failure: not retryable.
            if "Content did not pass the moral check" in response_text:
                return json_post, False, f"内容未通过道德检查:{response_text[:400]}"

            if "function calling" in response_text:
                # The backend does not support tools.
                self._drop_tools(json_post)
            elif "invalid_request_error" in response_text:
                # Malformed request: fall back to plain-text contents.
                self._flatten_text_messages(json_post)
            elif "'function' is not an allowed role" in response_text:
                last = json_post["messages"][-1]
                if last["role"] == "tool":
                    replacement = {"role": "assistant"}
                    if "name" in last:
                        replacement["name"] = last["name"]
                    replacement["content"] = last["content"]
                    json_post["messages"][-1] = replacement
            elif "Sorry, server is busy" in response_text:
                self._flatten_text_messages(json_post)
            elif "is not possible because the prompts occupy" in response_text:
                # Token limit exceeded: use the remaining budget reported by the server.
                max_tokens = re.findall(r"only\s(\d+)\stokens", response_text)
                if max_tokens:
                    json_post["max_tokens"] = int(max_tokens[0])
            else:
                # Default: retry without tools.
                self._drop_tools(json_post)
            return json_post, True, None

        except Exception as e:
            print(f"处理响应错误时出现异常: {e}")
            return json_post, False, str(e)

    def _handle_response_error_sync(self, response, json_post):
        """Handle an API error response and adjust the request body (synchronous version)."""
        response_text = response.text

        # Empty 200 response: retry with plain-text contents.
        if response.status_code == 200 and response_text == "":
            self._flatten_text_messages(json_post)
            return json_post, True

        json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
        if error_msg:
            raise Exception(f"{response.status_code} {response.reason} {error_msg}")
        return json_post, should_retry

    async def _handle_response_error(self, response, json_post):
        """Handle an API error response and adjust the request body (asynchronous version)."""
        await response.aread()
        response_text = response.text

        json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
        if error_msg:
            raise Exception(f"{response.status_code} {response.reason_phrase} {error_msg}")
        return json_post, should_retry
|