aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,856 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ import httpx
6
+ import asyncio
7
+ import requests
8
+ from typing import Set
9
+ from typing import Union, Optional, Callable
10
+ from pathlib import Path
11
+
12
+
13
+ from .base import BaseLLM
14
+ from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
+ from ..utils.scripts import check_json, safe_get, async_generator_to_sync
16
+ from ..core.request import prepare_request_payload
17
+ from ..core.response import fetch_response_stream
18
+
19
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
    """
    Select attribute names from ``obj.__dict__``.

    :param obj: Object whose instance attributes are inspected.
    :param keys: Attribute names to keep. When the first entry is the string
        "not", the remaining entries are excluded from the full attribute set
        instead (unknown names are ignored in that mode). With no keys at all,
        every attribute name is returned.
    :return: Set of the selected attribute names.
    :raises ValueError: If an inclusion key is not an attribute of ``obj``.
    """
    attribute_names = obj.__dict__.keys()
    if len(keys) == 0:
        return set(attribute_names)

    head, *rest = keys
    if head == "not":
        return {name for name in attribute_names if name not in rest}

    # Every requested name must exist on the object.
    unknown = set(keys) - attribute_names
    if unknown:
        raise ValueError(
            f"Invalid keys: {unknown}",
        )
    # All requested names were validated above, so they are the result.
    return set(keys)
39
+
40
class chatgpt(BaseLLM):
    """
    Official ChatGPT API client.

    Keeps one message list per ``convo_id`` in ``self.conversation``, streams
    chat-completion responses, and, when the model requests a tool call, runs
    the tool and asks the model again with the tool's result.
    """

    def __init__(
        self,
        api_key: str = None,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-4o",
        api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        proxy: str = None,
        timeout: float = 600,
        max_tokens: int = None,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        truncate_limit: int = None,
        use_plugins: bool = True,
        print_log: bool = False,
        tools: Optional[Union[list, str, Callable]] = None,
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        :param tools: A tool or a list of tools to enable. Each tool is a
            callable (matched by ``__name__``) or a plugin name. ``None`` or an
            empty list enables no extra tools.
        :raises ValueError: If a requested tool is not a registered plugin.
        :raises Exception: If the system prompt alone exceeds ``max_tokens``.
        """
        super().__init__(api_key, engine, api_url, system_prompt, proxy, timeout, max_tokens, temperature, top_p, presence_penalty, frequency_penalty, reply_count, truncate_limit, use_plugins=use_plugins, print_log=print_log)
        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": self.system_prompt,
                },
            ],
        }
        # How often each tool was called for the current question; cleared
        # once the model answers without requesting a tool.
        self.function_calls_counter = {}
        # Calls of a single tool allowed before the model is told to stop using tools.
        self.function_call_max_loop = 3

        # Register and enable the tools passed in.
        self._register_tools(tools)

        # NOTE(review): assumes BaseLLM fills in max_tokens when None is passed;
        # the guard avoids a TypeError (int > None) if it does not.
        if self.max_tokens is not None and self.tokens_usage["default"] > self.max_tokens:
            raise Exception("System prompt is too long")

    def _register_tools(self, tools):
        """
        Build this instance's plugin state and enable the requested tools.

        :param tools: ``None``, a single tool, or a list of tools.
        :raises ValueError: If a tool does not name a registered plugin.
        """
        # Refresh the registry configuration, then deep-copy it so that enabling
        # tools below only affects this instance (in case update_tools_config
        # hands back shared module-level objects).
        plugins, function_calls, _ = update_tools_config()
        self.plugins = copy.deepcopy(plugins)
        self.function_call_list = copy.deepcopy(function_calls)

        if isinstance(tools, list):
            # Copy so later changes to the caller's list do not affect us.
            self.tools = list(tools)
        else:
            self.tools = [tools] if tools else []

        for tool in self.tools:
            tool_name = tool.__name__ if callable(tool) else str(tool)
            if tool_name in self.plugins:
                self.plugins[tool_name] = True
            else:
                raise ValueError(f"Tool {tool_name} not found in plugins")

    def add_to_conversation(
        self,
        message: Union[str, list],
        role: str,
        convo_id: str = "default",
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        pass_history: int = 9999,
        function_call_id: str = "",
    ) -> None:
        """
        Add a message to the conversation, merge same-role neighbours, and trim history.

        Without ``function_name`` the message is appended as ``{"role", "content"}``.
        With ``function_name`` an assistant ``tool_calls`` entry is appended,
        followed by the tool result (``role`` / ``tool_call_id`` / ``content``).
        Empty messages are reported and not appended.

        :param message: Text, or a list of content parts.
        :param role: Role of the (result) message.
        :param convo_id: Conversation to update; created via ``reset`` if missing.
        :param function_name: Name of the tool that produced ``message``, if any.
        :param total_tokens: Token usage to add to ``self.tokens_usage[convo_id]``.
        :param function_arguments: JSON argument string of the tool call.
        :param pass_history: Maximum number of messages to keep (at least 2,
            counting the system message).
        :param function_call_id: ID linking the tool call and its result.
        """
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id)
        conversation = self.conversation[convo_id]

        if message and function_name == "":
            conversation.append({"role": role, "content": message})
        elif message:
            conversation.append({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": function_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": function_arguments,
                        },
                    }
                ],
            })
            conversation.append({"role": role, "tool_call_id": function_call_id, "content": message})
        else:
            print('\033[31m')
            print("error: add_to_conversation message is None or empty")
            print("role", role, "function_name", function_name, "message", message)
            print('\033[0m')

        self._merge_adjacent_same_role(conversation)
        self._trim_history(conversation, pass_history)

        if total_tokens:
            self.tokens_usage[convo_id] += total_tokens

    def _merge_adjacent_same_role(self, conversation):
        """Merge, in place, neighbouring messages that share a role and both have content."""
        index = 0
        while index < len(conversation) - 1:
            current, following = conversation[index], conversation[index + 1]
            if (
                current["role"] == following["role"]
                and current.get("content")
                and following.get("content")
            ):
                current["content"] = self._merge_contents(current["content"], following["content"])
                conversation.pop(index + 1)
            else:
                # Different roles, or one side has no content (e.g. an assistant
                # tool_calls entry): keep both and move on, so the loop always ends.
                index += 1

    @staticmethod
    def _merge_contents(first, second):
        """
        Combine two message contents into one.

        Two strings are concatenated. Otherwise each side becomes a list of
        content parts (a string turns into one ``{"type": "text"}`` part, a dict
        into a one-element list) and the two lists are joined into a new list.
        """
        if isinstance(first, str) and isinstance(second, str):
            return first + second

        def as_parts(content):
            if isinstance(content, str):
                return [{"type": "text", "text": content}]
            if isinstance(content, dict):
                return [content]
            return list(content)

        return as_parts(first) + as_parts(second)

    @staticmethod
    def _trim_history(conversation, pass_history):
        """
        Drop the oldest messages (never index 0) until at most ``max(pass_history, 2)`` remain.

        A dropped user message takes the following message with it; if that one
        is an assistant ``tool_calls`` entry, the tool result after it goes too.
        """
        limit = max(pass_history, 2)
        while len(conversation) > limit:
            removed = conversation.pop(1)
            if removed.get("role") == "user" and len(conversation) > 1:
                follower = conversation.pop(1)
                if follower.get("tool_calls") and len(conversation) > 1:
                    conversation.pop(1)

    def truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation while recorded token usage exceeds ``truncate_limit``.

        NOTE(review): ``self.tokens_usage[convo_id]`` is not reduced when a
        message is removed, so once the limit is exceeded this removes every
        message except the first (system) one. Confirm whether a per-message
        token count was intended.
        """
        conversation = self.conversation[convo_id]
        while self.tokens_usage[convo_id] > self.truncate_limit and len(conversation) > 1:
            # Don't remove the first message.
            mess = conversation.pop(1)
            print("Truncate message:", mess)

    def extract_values(self, obj):
        """Recursively yield every leaf value of nested dicts and lists."""
        if isinstance(obj, dict):
            for value in obj.values():
                yield from self.extract_values(value)
        elif isinstance(obj, list):
            for item in obj:
                yield from self.extract_values(item)
        else:
            yield obj

    async def get_post_body(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        **kwargs,
    ):
        """
        Build the HTTP request for a streamed chat completion.

        :param prompt: Sent together with the system message only when
            ``pass_history`` is falsy; otherwise the stored history is sent.
        :param kwargs: Optional overrides ``api_url``, ``api_key``,
            ``max_tokens``, ``temperature`` and ``plugins``.
        :return: ``(url, headers, json_post_body, engine_type)`` from
            ``prepare_request_payload``.
        """
        self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}
        target_model = model or self.engine

        # Provider description for the core request module.
        provider = {
            "provider": "openai",
            "base_url": kwargs.get('api_url', self.api_url.chat_url),
            "api": kwargs.get('api_key', self.api_key),
            "model": [target_model],
            "tools": bool(self.use_plugins),
            "image": True,
        }

        if pass_history:
            messages = copy.deepcopy(self.conversation[convo_id])
        else:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": role, "content": prompt},
            ]

        request_data = {
            "model": target_model,
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
            "temperature": kwargs.get("temperature", self.temperature),
        }

        # Offer every enabled plugin that has a function schema as a tool.
        plugins_status = kwargs.get("plugins", self.plugins)
        if self.use_plugins and any(plugins_status.values()):
            tools_request_body = [
                {"type": "function", "function": self.function_call_list[name]}
                for name, enabled in plugins_status.items()
                if enabled and name in self.function_call_list
            ]
            if tools_request_body:
                request_data["tools"] = tools_request_body
                request_data["tool_choice"] = "auto"

        url, headers, json_post_body, engine_type = await prepare_request_payload(provider, request_data)
        return url, headers, json_post_body, engine_type

    async def _process_stream_response(
        self,
        response_gen,
        convo_id="default",
        function_name="",
        total_tokens=0,
        function_arguments="",
        function_call_id="",
        model="",
        language="English",
        system_prompt=None,
        pass_history=9999,
        is_async=False,
        **kwargs
    ):
        """
        Shared handling of a streamed response.

        Yields text chunks as they arrive. If the stream ends in a tool call,
        the tool is run and the model is asked again with the result
        (recursively), and that answer is yielded too. Otherwise the reply is
        stored in the conversation.

        :param response_gen: Response line generator (sync or async).
        :param is_async: Whether ``response_gen`` is async and follow-up calls
            use the async API.
        :param kwargs: Request overrides (``api_key``, ``api_url``, ``plugins``,
            ...); forwarded to the follow-up request after a tool call.
        """
        response_role = None
        full_response = ""
        function_full_response = ""
        function_call_name = ""
        need_function_call = False
        # Private end-of-stream marker; a content chunk can never be this object
        # (a plain "DONE" string could collide with model output).
        stream_done = object()

        def process_line(line):
            """Parse one line: return text to yield, ``stream_done``, or None."""
            nonlocal response_role, full_response, function_full_response, function_call_name, need_function_call, total_tokens, function_call_id

            # Skip blank lines and SSE comment lines.
            if not line or (isinstance(line, str) and line.startswith(':')):
                return None

            if isinstance(line, str) and line.startswith('data:'):
                # Remove the exact prefix; str.lstrip("data: ") strips a character set.
                line = line[len('data:'):].lstrip(" ")
                if line == "[DONE]":
                    return stream_done
            elif isinstance(line, (dict, list)):
                # An already-decoded, non-streamed response object.
                if isinstance(line, dict) and safe_get(line, "choices", 0, "message", "content"):
                    full_response = line["choices"][0]["message"]["content"]
                    return full_response
                return str(line)
            else:
                # Anything else is treated as a complete (non-SSE) JSON body.
                try:
                    if isinstance(line, str):
                        line = json.loads(line)
                    if safe_get(line, "choices", 0, "message", "content"):
                        full_response = line["choices"][0]["message"]["content"]
                        return full_response
                    return str(line)
                except Exception:
                    print("json.loads error:", repr(line))
                    return None

            resp = json.loads(line) if isinstance(line, str) else line
            if "error" in resp:
                raise Exception(f"{resp}")

            total_tokens = total_tokens or safe_get(resp, "usage", "total_tokens", default=0)
            delta = safe_get(resp, "choices", 0, "delta")
            if not delta:
                return None

            response_role = response_role or safe_get(delta, "role")
            if safe_get(delta, "content"):
                need_function_call = False
                content = delta["content"]
                full_response += content
                return content

            if safe_get(delta, "tool_calls"):
                # Only the first tool call of a response is tracked.
                need_function_call = True
                function_call_name = function_call_name or safe_get(delta, "tool_calls", 0, "function", "name")
                function_full_response += safe_get(delta, "tool_calls", 0, "function", "arguments", default="") or ""
                function_call_id = function_call_id or safe_get(delta, "tool_calls", 0, "id")
            return None

        async def process_async():
            async for line in response_gen:
                line = line.strip() if isinstance(line, str) else line
                result = process_line(line)
                if result is stream_done:
                    break
                if result:
                    yield result

        def process_sync():
            for line in response_gen:
                line = line.decode("utf-8") if hasattr(line, "decode") else line
                result = process_line(line)
                if result is stream_done:
                    break
                if result:
                    yield result

        # Stream the response with the sync or async reader.
        if is_async:
            async for chunk in process_async():
                yield chunk
        else:
            for chunk in process_sync():
                yield chunk

        if self.print_log:
            print("\n\rtotal_tokens", total_tokens)

        if response_role is None:
            response_role = "assistant"

        if need_function_call:
            if self.print_log:
                print("function_full_response", function_full_response)
            function_full_response = check_json(function_full_response)
            function_response = ""

            self.function_calls_counter[function_call_name] = self.function_calls_counter.get(function_call_name, 0) + 1

            has_args = safe_get(self.function_call_list, function_call_name, "parameters", "required", default=False)
            within_limit = self.function_calls_counter[function_call_name] <= self.function_call_max_loop
            if within_limit and (function_full_response != "{}" or not has_args):
                function_call_max_tokens = self.truncate_limit - 1000
                if function_call_max_tokens <= 0:
                    function_call_max_tokens = int(self.truncate_limit / 2)
                if self.print_log:
                    print("\033[32m function_call", function_call_name, "max token:", function_call_max_tokens, "\033[0m")

                async def tool_chunks():
                    # Yield the tool's intermediate chunks; capture the final
                    # result (marked "function_response:") instead of yielding it.
                    nonlocal function_response
                    async for chunk in get_tools_result_async(
                        function_call_name, function_full_response, function_call_max_tokens,
                        model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
                        self.api_url, use_plugins=False, model=model or self.engine,
                        add_message=self.add_to_conversation, convo_id=convo_id, language=language
                    ):
                        if "function_response:" in chunk:
                            function_response = chunk.replace("function_response:", "")
                        else:
                            yield chunk

                if is_async:
                    async for chunk in tool_chunks():
                        yield chunk
                else:
                    for chunk in async_generator_to_sync(tool_chunks()):
                        yield chunk
            else:
                function_response = "无法找到相关信息,停止使用 tools"

            response_role = "tool"

            # Ask the model again with the tool result. Request overrides
            # (api_url, temperature, max_tokens, ...) and the language are
            # forwarded so the follow-up request matches the original one.
            follow_up_kwargs = dict(kwargs)
            follow_up_kwargs["api_key"] = kwargs.get('api_key', self.api_key)
            follow_up_kwargs["plugins"] = kwargs.get("plugins", self.plugins)
            follow_up_args = dict(
                convo_id=convo_id,
                function_name=function_call_name,
                total_tokens=total_tokens,
                model=model or self.engine,
                function_arguments=function_full_response,
                function_call_id=function_call_id,
                system_prompt=system_prompt,
                language=language,
                **follow_up_kwargs,
            )
            if is_async:
                async for chunk in self.ask_stream_async(function_response, response_role, **follow_up_args):
                    yield chunk
            else:
                for chunk in self.ask_stream(function_response, response_role, **follow_up_args):
                    yield chunk
        else:
            # Store the reply in the conversation history.
            self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
            self.function_calls_counter = {}

            # For short-history calls whose prompt was a translation instruction,
            # drop the prompt and the reply again.
            if pass_history <= 2 and len(self.conversation[convo_id]) >= 2 \
                    and self._is_translation_prompt(convo_id):
                self.conversation[convo_id].pop(-1)
                self.conversation[convo_id].pop(-1)

    def _is_translation_prompt(self, convo_id):
        """Return True if the second-to-last message contains a known translation-engine marker."""
        markers = ("You are a translation engine", "你是一位精通简体中文的专业翻译")
        # The message may have no "content" (e.g. an assistant tool_calls entry).
        content = self.conversation[convo_id][-2].get("content") or ""
        first_text = safe_get(self.conversation, convo_id, -2, "content", 0, "text", default="") or ""
        return any(marker in content or marker in first_text for marker in markers)

    @staticmethod
    def _run_coroutine_blocking(coro):
        """
        Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` is used when no event loop runs in this thread. Inside a
        running loop neither ``asyncio.run`` nor ``run_until_complete`` works, so
        the coroutine then runs on a fresh loop in a worker thread while this
        thread waits for it.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def ask_stream(
        self,
        prompt: list,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        function_call_id: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """
        Ask a question and return a synchronous iterator of response chunks.

        The message is added to the conversation and the request is built
        immediately. The chunks themselves are presumably produced lazily by
        ``async_generator_to_sync``, so streaming errors would surface while
        iterating. On a handled connection error or timeout an empty iterator
        is returned.
        """
        # Prepare the conversation.
        self.system_prompt = system_prompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, function_call_id=function_call_id, pass_history=pass_history)

        # Build the request body (get_post_body is async).
        url, headers, json_post, engine_type = self._run_coroutine_blocking(
            self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
        )

        self.truncate_conversation(convo_id=convo_id)

        if self.print_log:
            print("api_url", kwargs.get('api_url', self.api_url.chat_url), url)
            print("api_key", kwargs.get('api_key', self.api_key))

        async def process_async():
            async_generator = fetch_response_stream(
                self.aclient,
                url,
                headers,
                json_post,
                engine_type,
                model or self.engine,
            )
            async for chunk in self._process_stream_response(
                async_generator,
                convo_id=convo_id,
                function_name=function_name,
                total_tokens=total_tokens,
                function_arguments=function_arguments,
                function_call_id=function_call_id,
                model=model,
                language=language,
                system_prompt=system_prompt,
                pass_history=pass_history,
                is_async=True,
                **kwargs
            ):
                yield chunk

        for attempt in range(3):
            if self.print_log:
                replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
                print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

            try:
                # Expose the async stream as a sync generator.
                return async_generator_to_sync(process_async())
            except ConnectionError:
                print("连接错误,请检查服务器状态或网络连接。")
                return iter(())
            except requests.exceptions.ReadTimeout:
                print("请求超时,请检查网络连接或增加超时时间。")
                return iter(())
            except Exception as e:
                print(f"发生了未预料的错误:{e}")
                if "Invalid URL" in str(e):
                    message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + str(e)
                    raise Exception(message) from e
                # Last attempt failed: propagate.
                if attempt == 2:
                    raise Exception(f"{e}") from e

    async def ask_stream_async(
        self,
        prompt: list,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        function_name: str = "",
        total_tokens: int = 0,
        function_arguments: str = "",
        function_call_id: str = "",
        language: str = "English",
        system_prompt: str = None,
        **kwargs,
    ):
        """
        Ask a question and asynchronously yield response chunks.

        :param system_prompt: New system prompt; ``None`` keeps the current one
            (a non-None default would replace a custom prompt on every call).
        :raises Exception: On an invalid API URL, when a failure happens after
            chunks were already yielded, or after the third failed attempt.
        """
        # Prepare the conversation.
        self.system_prompt = system_prompt or self.system_prompt
        if convo_id not in self.conversation or pass_history <= 2:
            self.reset(convo_id=convo_id, system_prompt=system_prompt)
        self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)

        # Build the request body.
        url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
        self.truncate_conversation(convo_id=convo_id)

        if self.print_log:
            print("api_url", kwargs.get('api_url', self.api_url.chat_url) == url)
            print("api_url", kwargs.get('api_url', self.api_url.chat_url))
            print("api_url", url)
            print("api_key", kwargs.get('api_key', self.api_key))

        for attempt in range(3):
            if self.print_log:
                replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
                print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

            emitted = False
            try:
                generator = fetch_response_stream(
                    self.aclient,
                    url,
                    headers,
                    json_post,
                    engine_type,
                    model or self.engine,
                )
                async for processed_chunk in self._process_stream_response(
                    generator,
                    convo_id=convo_id,
                    function_name=function_name,
                    total_tokens=total_tokens,
                    function_arguments=function_arguments,
                    function_call_id=function_call_id,
                    model=model,
                    language=language,
                    system_prompt=system_prompt,
                    pass_history=pass_history,
                    is_async=True,
                    **kwargs
                ):
                    emitted = True
                    yield processed_chunk

                # Success: leave the retry loop.
                break
            except Exception as e:
                print(f"发生了未预料的错误:{e}")
                import traceback
                traceback.print_exc()
                if "Invalid URL" in str(e):
                    message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + str(e)
                    raise Exception(message) from e
                # Retrying after output was yielded would repeat that output and
                # re-run any tool call (re-adding its messages), so give up then,
                # as well as after the last attempt.
                if emitted or attempt == 2:
                    raise Exception(f"{e}") from e

    async def ask_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        model: str = "",
        pass_history: int = 9999,
        **kwargs,
    ) -> str:
        """
        Non-streaming ask: collect all chunks of ``ask_stream_async`` into one string.
        """
        response = self.ask_stream_async(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            pass_history=pass_history,
            model=model or self.engine,
            **kwargs,
        )
        full_response: str = "".join([r async for r in response])
        return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
        """
        Remove the last ``n`` messages of the conversation.
        """
        for _ in range(n):
            self.conversation[convo_id].pop()

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset a conversation to just the system message and zero its token usage.

        :param system_prompt: New system prompt for the instance. ``None`` keeps
            the current ``self.system_prompt``, so implicitly creating a
            conversation (e.g. from ``add_to_conversation``) does not discard a
            custom prompt.
        """
        self.system_prompt = system_prompt or self.system_prompt
        self.conversation[convo_id] = [
            {"role": "system", "content": self.system_prompt},
        ]
        self.tokens_usage[convo_id] = 0

    def save(self, file: str, *keys: str) -> None:
        """
        Save the Chatbot configuration to a JSON file.

        :param keys: Attribute selection, as in ``get_filtered_keys_from_object``.
        The data is serialised before the file is opened, so an invalid key or
        an unserialisable attribute leaves an existing file untouched.
        """
        data = {
            key: self.__dict__[key]
            for key in get_filtered_keys_from_object(self, *keys)
        }
        # Also store the proxy as "session"; kept for compatibility.
        if "proxy" in data:
            data["session"] = data["proxy"]
        # The async HTTP client is not JSON-serialisable.
        data.pop("aclient", None)
        serialized = json.dumps(data, indent=2)
        with open(file, "w", encoding="utf-8") as f:
            f.write(serialized)

    def load(self, file: Path, *keys_: str) -> None:
        """
        Load the Chatbot configuration from a JSON file.

        When a proxy is loaded the HTTP clients are rebuilt with it. Selected
        attributes present in the file are then copied onto the instance;
        attributes missing from the file are left unchanged.
        """
        with open(file, encoding="utf-8") as f:
            loaded_config = json.load(f)
        keys = get_filtered_keys_from_object(self, *keys_)

        if (
            ("session" in keys and loaded_config.get("session"))
            or ("proxy" in keys and loaded_config.get("proxy"))
        ):
            self.proxy = loaded_config.get("session") or loaded_config.get("proxy")
            # NOTE(review): httpx 0.28 removed the `proxies` argument in favour
            # of `proxy`; confirm against the pinned httpx version.
            self.session = httpx.Client(
                follow_redirects=True,
                proxies=self.proxy,
                timeout=self.timeout,
                cookies=self.session.cookies,
                headers=self.session.headers,
            )
            self.aclient = httpx.AsyncClient(
                follow_redirects=True,
                proxies=self.proxy,
                timeout=self.timeout,
                cookies=self.session.cookies,
                headers=self.session.headers,
            )
        # Never overwrite the client objects with raw JSON values.
        keys.discard("session")
        keys.discard("aclient")
        self.__dict__.update({key: loaded_config[key] for key in keys if key in loaded_config})

    @staticmethod
    def _flatten_text_contents(json_post):
        """
        Replace list-style message contents by the text of their first part, in place.

        Messages whose content is not a non-empty list starting with a text part
        (including messages without content, e.g. assistant tool_calls) are left
        unchanged. A flattened message keeps only ``role`` and ``content``.
        """
        messages = json_post["messages"]
        for index, mess in enumerate(messages):
            content = mess.get("content")
            if isinstance(content, list) and content and isinstance(content[0], dict) and "text" in content[0]:
                messages[index] = {
                    "role": mess["role"],
                    "content": content[0]["text"],
                }
        return json_post

    def _handle_response_error_common(self, response_text, json_post):
        """
        Shared error-response handling for the sync and async paths.

        :return: ``(json_post, should_retry, error_msg)``; ``error_msg`` is set
            when the request should not be retried.
        """
        try:
            # Content moderation failure.
            if "Content did not pass the moral check" in response_text:
                return json_post, False, f"内容未通过道德检查:{response_text[:400]}"

            # Provider does not support function calling: drop the tools.
            if "function calling" in response_text:
                json_post.pop("tools", None)
                json_post.pop("tool_choice", None)
                return json_post, True, None

            # Malformed request: retry with plain-text contents.
            elif "invalid_request_error" in response_text:
                return self._flatten_text_contents(json_post), True, None

            # Role not allowed: resend the last tool message as an assistant message.
            elif "'function' is not an allowed role" in response_text:
                if json_post["messages"][-1]["role"] == "tool":
                    mess = json_post["messages"][-1]
                    replacement = {"role": "assistant"}
                    # Tool messages built by add_to_conversation carry
                    # tool_call_id, not name.
                    if "name" in mess:
                        replacement["name"] = mess["name"]
                    replacement["content"] = mess.get("content")
                    json_post["messages"][-1] = replacement
                return json_post, True, None

            # Server busy: retry with plain-text contents.
            elif "Sorry, server is busy" in response_text:
                return self._flatten_text_contents(json_post), True, None

            # Token limit exceeded: lower max_tokens to what the server allows.
            elif "is not possible because the prompts occupy" in response_text:
                max_tokens = re.findall(r"only\s(\d+)\stokens", response_text)
                if max_tokens:
                    json_post["max_tokens"] = int(max_tokens[0])
                return json_post, True, None

            # Default: retry without tools.
            else:
                json_post.pop("tools", None)
                json_post.pop("tool_choice", None)
                return json_post, True, None

        except Exception as e:
            print(f"处理响应错误时出现异常: {e}")
            return json_post, False, str(e)

    def _handle_response_error_sync(self, response, json_post):
        """Handle an API error response and adjust the request body (sync version)."""
        response_text = response.text

        # A 200 with an empty body: retry with plain-text contents.
        if response.status_code == 200 and response_text == "":
            return self._flatten_text_contents(json_post), True

        json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)

        if error_msg:
            raise Exception(f"{response.status_code} {response.reason} {error_msg}")

        return json_post, should_retry

    async def _handle_response_error(self, response, json_post):
        """Handle an API error response and adjust the request body (async version)."""
        await response.aread()
        response_text = response.text

        json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)

        if error_msg:
            raise Exception(f"{response.status_code} {response.reason_phrase} {error_msg}")

        return json_post, should_retry