beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +938 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
  72. beswarm-0.1.13.dist-info/RECORD +131 -0
  73. beswarm-0.1.12.dist-info/RECORD +0 -61
  74. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,938 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ import httpx
6
+ import asyncio
7
+ import requests
8
+ from typing import Set
9
+ from typing import Union, Optional, Callable, List, Dict, Any
10
+ from pathlib import Path
11
+
12
+
13
+ from .base import BaseLLM
14
+ from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
+ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml
16
+ from ..core.request import prepare_request_payload
17
+ from ..core.response import fetch_response_stream
18
+
19
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
    """
    Return a set of attribute names taken from ``obj.__dict__``.

    :param obj: Object whose instance attributes are inspected.
    :param keys: Optional attribute names.
        * no keys: every attribute name is returned;
        * first key ``"not"``: every attribute name except the remaining keys;
        * otherwise: exactly the listed keys, which must all exist.
    :return: Set of attribute names.
    :raises ValueError: In inclusion mode, if a listed key is not an attribute of ``obj``.
    """
    available = obj.__dict__.keys()
    if len(keys) == 0:
        return set(available)

    head, *rest = keys
    if head == "not":
        # Exclusion mode: drop the names that follow "not".
        return {name for name in available if name not in rest}

    # Inclusion mode: every requested name has to exist on the object.
    unknown = set(keys) - available
    if unknown:
        raise ValueError(
            f"Invalid keys: {unknown}",
        )
    return {name for name in keys if name in available}
39
+
40
+ class chatgpt(BaseLLM):
41
+ """
42
+ Official ChatGPT API
43
+ """
44
+
45
def __init__(
    self,
    api_key: str = None,
    engine: str = os.environ.get("GPT_ENGINE") or "gpt-4o",
    api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
    system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    proxy: str = None,
    timeout: float = 600,
    max_tokens: int = None,
    temperature: float = 0.5,
    top_p: float = 1.0,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    reply_count: int = 1,
    truncate_limit: int = None,
    use_plugins: bool = True,
    print_log: bool = False,
    tools: Optional[Union[list, str, Callable]] = None,
    function_call_max_loop: int = 3,
    cut_history_by_function_name: str = "",
) -> None:
    """
    Initialize the chatbot (API keys: https://platform.openai.com/account/api-keys).

    The connection and sampling parameters are forwarded to ``BaseLLM.__init__``.

    :param tools: A tool, or a list of tools, to enable. Each tool is a callable
        (matched by ``__name__``) or a plugin name. ``None`` (the default) or an
        empty list enables no extra tools. The default used to be a shared mutable
        ``[]``; ``None`` behaves identically.
    :param function_call_max_loop: How many times a single tool may be invoked
        before further calls to it are refused.
    :param cut_history_by_function_name: Name of a tool whose invocation triggers
        history truncation in ``add_to_conversation``.
    """
    super().__init__(api_key, engine, api_url, system_prompt, proxy, timeout, max_tokens, temperature, top_p, presence_penalty, frequency_penalty, reply_count, truncate_limit, use_plugins=use_plugins, print_log=print_log)
    self.conversation: dict[str, list[dict]] = {
        "default": [
            {
                "role": "system",
                "content": self.system_prompt,
            },
        ],
    }
    # Per-tool invocation counts; cleared once a reply finishes without a tool call.
    self.function_calls_counter = {}
    self.function_call_max_loop = function_call_max_loop
    self.cut_history_by_function_name = cut_history_by_function_name

    # Register and enable the tools passed in.
    self._register_tools(tools)
85
+
86
+
87
def _register_tools(self, tools):
    """
    Build this instance's plugin tables and enable the requested tools.

    :param tools: A tool or a list of tools. Each is a callable (matched by its
        ``__name__``) or a plugin name. Falsy values enable nothing.
    :raises ValueError: If a tool name is not a known plugin.
    """
    # The original deep-copied PLUGINS / function_call_list and then immediately
    # overwrote both with update_tools_config()'s return values. Copy those values
    # instead, so that enabling a tool below can only affect this instance even if
    # update_tools_config() hands back shared module-level tables.
    plugins, function_call_list, _ = update_tools_config()
    self.plugins = copy.deepcopy(plugins)
    self.function_call_list = copy.deepcopy(function_call_list)

    if isinstance(tools, list):
        self.tools = tools if tools else []
    else:
        self.tools = [tools] if tools else []

    for tool in self.tools:
        tool_name = tool.__name__ if callable(tool) else str(tool)
        if tool_name not in self.plugins:
            raise ValueError(f"Tool {tool_name} not found in plugins")
        self.plugins[tool_name] = True
106
+
107
def add_to_conversation(
    self,
    message: Union[str, list],
    role: str,
    convo_id: str = "default",
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    pass_history: int = 9999,
    function_call_id: str = "",
) -> None:
    """
    Append a message to a conversation, then merge and trim the history.

    * Plain messages (``function_name == ""``) are appended as ``{"role", "content"}``.
    * Tool results (``function_name != ""``) are recorded in one of two ways.
      When plugins are enabled, as an assistant ``tool_calls`` entry followed by
      a ``tool_call_id`` message. Otherwise, as an assistant message holding
      ``convert_functions_to_xml(function_arguments)`` followed by a user message
      with the result.
    * Empty messages are not appended; an error is printed instead.

    Afterwards, consecutive messages with the same role are collapsed into one.
    The history is trimmed to at most ``max(pass_history, 2)`` messages, keeping
    index 0 (the system prompt). Token counters are updated when
    ``total_tokens`` is non-zero.

    :param message: Content: text, a list of content parts, or a single part dict.
    :param role: Role of the message (the tool-result role in the tool case).
    :param convo_id: Conversation to append to; created via ``reset`` if missing.
    :param function_name: Name of the tool whose result ``message`` is, if any.
    :param total_tokens: Token count for the exchange (0 means "not reported").
    :param function_arguments: Raw arguments of the tool call.
    :param pass_history: Maximum number of messages to keep.
    :param function_call_id: ID of the tool call being answered.
    """

    def as_parts(content):
        # Normalise message content to a list of content parts.
        if isinstance(content, str):
            return [{"type": "text", "text": content}]
        if isinstance(content, dict):
            return [content]
        return list(content)

    def merge_content(first, second):
        # Two strings are concatenated; anything else becomes one new part list.
        # A new list is built so caller-owned content lists are never mutated.
        if isinstance(first, str) and isinstance(second, str):
            return first + second
        return as_parts(first) + as_parts(second)

    if convo_id not in self.conversation:
        # Pass the current prompt explicitly so a new conversation keeps this
        # instance's system prompt instead of reset()'s default argument.
        self.reset(convo_id=convo_id, system_prompt=self.system_prompt)

    if not message:
        print('\033[31m')
        print("error: add_to_conversation message is None or empty")
        print("role", role, "function_name", function_name, "message", message)
        print('\033[0m')
    elif function_name == "":
        self.conversation[convo_id].append({"role": role, "content": message})
    else:
        if function_name == self.cut_history_by_function_name:
            # Cut the history back to just before the first earlier call of this
            # tool. (Previously the tool name 'get_next_pdf' was hard-coded here.)
            history = self.conversation[convo_id]
            cut_at = next(
                (
                    index
                    for index, item in enumerate(history)
                    if safe_get(item, "tool_calls", 0, "function", "name", default="") == function_name
                ),
                None,
            )
            if cut_at is not None:
                self.conversation[convo_id] = history[:cut_at]

        if not (all(value == False for value in self.plugins.values()) or self.use_plugins == False):
            # Native tool calling: record the assistant's call, then the tool's reply.
            self.conversation[convo_id].append({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": function_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": function_arguments,
                        },
                    }
                ],
            })
            self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
        else:
            # No native tools: represent the call as XML text from the assistant
            # and feed the result back as a user message.
            self.conversation[convo_id].append({"role": "assistant", "content": convert_functions_to_xml(function_arguments)})
            self.conversation[convo_id].append({"role": "user", "content": message})

    conversation = self.conversation[convo_id]

    # Collapse consecutive same-role messages. When both have content, the
    # contents are merged; otherwise the later message is simply dropped.
    index = 0
    while index < len(conversation) - 1:
        current, following = conversation[index], conversation[index + 1]
        if current["role"] != following["role"]:
            index += 1
            continue
        if current.get("content") and following.get("content"):
            current["content"] = merge_content(current["content"], following["content"])
        conversation.pop(index + 1)

    # Trim to at most max(pass_history, 2) messages; index 0 is never removed.
    history_limit = max(pass_history, 2)
    while len(conversation) > history_limit:
        removed = conversation.pop(1)
        if removed.get("role") == "user" and len(conversation) > 1:
            # Drop the message answering this user turn as well...
            reply = conversation.pop(1)
            if reply.get("tool_calls") and len(conversation) > 1:
                # ...and the tool result that followed its tool call.
                conversation.pop(1)

    if total_tokens:
        self.current_tokens[convo_id] = total_tokens
        self.tokens_usage[convo_id] += total_tokens
210
+
211
def truncate_conversation(self, convo_id: str = "default") -> None:
    """
    Drop the oldest non-system messages while the conversation is over budget.

    While ``current_tokens[convo_id]`` exceeds ``truncate_limit`` and more than
    one message remains, the message at index 1 is removed (index 0 is kept).
    ``current_tokens`` is lowered by a rough estimate of one token per four
    characters of the removed message's JSON form, and the removed message is
    printed.
    """
    messages = self.conversation[convo_id]
    while self.current_tokens[convo_id] > self.truncate_limit and len(messages) > 1:
        dropped = messages.pop(1)
        approx_tokens = len(json.dumps(dropped, ensure_ascii=False)) / 4
        self.current_tokens[convo_id] -= approx_tokens
        print("Truncate message:", dropped)
227
+
228
async def get_post_body(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    **kwargs,
):
    """
    Build the HTTP request for a streamed chat completion.

    Refreshes the system prompt at index 0 of the conversation, assembles an
    OpenAI-style request body (``stream`` with usage reporting, optional
    ``max_tokens`` and ``tools``) and hands it to ``prepare_request_payload``.

    :param prompt: Prompt content. Only used when ``pass_history`` is falsy; the
        request then contains just the system prompt and this message.
    :param role: Role of ``prompt`` in that history-less request.
    :param convo_id: Conversation whose history is sent.
    :param model: Model override; defaults to ``self.engine``.
    :param pass_history: Falsy to send no history (see ``prompt``).
    :param kwargs: Optional overrides: ``api_url``, ``api_key``, ``temperature``,
        ``max_tokens`` and ``plugins``. ``plugins`` is a name -> enabled mapping;
        a truthy value is also stored on ``self.plugins``.
    :return: ``(url, headers, json_post_body, engine_type)`` from ``prepare_request_payload``.
    """
    self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}
    model_name = model or self.engine

    # Provider description consumed by prepare_request_payload.
    provider = {
        "provider": "openai",
        "base_url": kwargs.get('api_url', self.api_url.chat_url),
        "api": kwargs.get('api_key', self.api_key),
        "model": [model_name],
        "tools": bool(self.use_plugins),
        "image": True
    }

    if pass_history:
        messages = copy.deepcopy(self.conversation[convo_id])
    else:
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": role, "content": prompt},
        ]

    request_data = {
        "model": model_name,
        "messages": messages,
        "stream": True,
        "stream_options": {
            "include_usage": True
        },
        "temperature": kwargs.get("temperature", self.temperature)
    }

    max_tokens = kwargs.get("max_tokens", self.max_tokens)
    if max_tokens:
        request_data["max_tokens"] = max_tokens

    # A truthy ``plugins`` override also replaces the instance's plugin table.
    if kwargs.get("plugins", None):
        self.plugins = kwargs.get("plugins")

    plugins_status = kwargs.get("plugins")
    if plugins_status is None:
        # Absent or explicitly None: use the instance's table (None used to crash on .values()).
        plugins_status = self.plugins

    if not (all(value == False for value in plugins_status.values()) or self.use_plugins == False):
        tools_request_body = []
        for name, enabled in plugins_status.items():
            if not enabled:
                continue
            try:
                function_schema = self.function_call_list[name]
            except (LookupError, TypeError):
                # Enabled plugin without a registered schema: leave it out of the request.
                continue
            tools_request_body.append({"type": "function", "function": function_schema})
        if tools_request_body:
            request_data["tools"] = tools_request_body
            request_data["tool_choice"] = "auto"

    url, headers, json_post_body, engine_type = await prepare_request_payload(provider, request_data)

    return url, headers, json_post_body, engine_type
289
+
290
async def _process_stream_response(
    self,
    response_gen,
    convo_id="default",
    function_name="",
    total_tokens=0,
    function_arguments="",
    function_call_id="",
    model="",
    language="English",
    system_prompt=None,
    pass_history=9999,
    is_async=False,
    **kwargs
):
    """
    Shared handling of a streamed chat response.

    Yields the response text as it arrives. When the stream ends, one of two
    things happens.

    * Tool calls requested: the calls come from native ``tool_calls`` deltas, or
      from function markup in the text recognised by ``parse_function_xml``. Each
      call is run and its output is yielded. The combined results are then sent
      back to the model through a recursive ``ask_stream_async`` / ``ask_stream``
      call, whose chunks are yielded too.
    * Otherwise: the assistant reply is added to the conversation and the
      per-tool call counters are cleared.

    :param response_gen: Generator of response lines (sync or async).
    :param function_call_id: ID of the tool call answered in the previous turn.
        It is used only as a fallback when the stream carries no new call ID.
    :param is_async: Whether ``response_gen`` is async and the async helpers are used.
    :raises Exception: If a streamed ``data:`` event contains an ``"error"`` key.
    """
    response_role = None
    full_response = ""
    function_full_response = ""
    function_call_name = ""
    streamed_call_id = ""
    need_function_call = False
    # End-of-stream sentinel. The string "DONE" was used before, which made a
    # text chunk reading "DONE" terminate the stream.
    done = object()

    def process_line(line):
        """Parse one response line; return text to yield, ``done``, or None."""
        nonlocal response_role, full_response, function_full_response, function_call_name, need_function_call, total_tokens, streamed_call_id

        # Skip blank lines and SSE comments.
        if not line or (isinstance(line, str) and line.startswith(':')):
            return None

        if isinstance(line, str) and line.startswith('data:'):
            # Remove the literal "data:" prefix. str.lstrip("data: ") strips any
            # run of those characters, not the prefix.
            line = line[len('data:'):].strip()
            if line == "[DONE]":
                return done
        elif isinstance(line, (dict, list)):
            # An already-decoded object, e.g. a complete (non-streamed) response.
            if isinstance(line, dict) and safe_get(line, "choices", 0, "message", "content"):
                full_response = line["choices"][0]["message"]["content"]
                return full_response
            return str(line)
        else:
            # A bare line: try to read it as a complete JSON response.
            try:
                if isinstance(line, str):
                    line = json.loads(line)
                if safe_get(line, "choices", 0, "message", "content"):
                    full_response = line["choices"][0]["message"]["content"]
                    return full_response
                return str(line)
            except (ValueError, TypeError):
                print("json.loads error:", repr(line))
                return None

        resp = json.loads(line) if isinstance(line, str) else line
        if "error" in resp:
            raise Exception(f"{resp}")

        total_tokens = total_tokens or safe_get(resp, "usage", "total_tokens", default=0)
        delta = safe_get(resp, "choices", 0, "delta")
        if not delta:
            return None

        response_role = response_role or safe_get(delta, "role")
        if safe_get(delta, "content"):
            need_function_call = False
            content = delta["content"]
            full_response += content
            return content

        if safe_get(delta, "tool_calls"):
            # Tool-call name, id and arguments arrive in fragments; accumulate them.
            need_function_call = True
            function_call_name = function_call_name or safe_get(delta, "tool_calls", 0, "function", "name") or ""
            function_full_response += safe_get(delta, "tool_calls", 0, "function", "arguments", default="") or ""
            streamed_call_id = streamed_call_id or safe_get(delta, "tool_calls", 0, "id") or ""
        return None

    async def process_async():
        async for line in response_gen:
            line = line.strip() if isinstance(line, str) else line
            result = process_line(line)
            if result is done:
                break
            if result:
                yield result

    def process_sync():
        for line in response_gen:
            line = line.decode("utf-8") if hasattr(line, "decode") else line
            result = process_line(line)
            if result is done:
                break
            if result:
                yield result

    if is_async:
        async for chunk in process_async():
            yield chunk
    else:
        for chunk in process_sync():
            yield chunk

    if self.print_log:
        print("\n\rtotal_tokens", total_tokens)

    if response_role is None:
        response_role = "assistant"

    # Function calls may also be embedded in the text as markup.
    function_parameter = parse_function_xml(full_response)
    if function_parameter:
        need_function_call = True

    if need_function_call and self.use_plugins != False:
        if self.print_log:
            print("function_parameter", function_parameter)
            print("function_full_response", function_full_response)

        # Run one tool call and yield its output. Chunks containing
        # "function_response:" carry the tool result; other chunks are passed through.
        async def process_single_tool_call(tool_name, tool_args, tool_id):
            self.function_calls_counter[tool_name] = self.function_calls_counter.get(tool_name, 0) + 1

            tool_response = ""
            has_args = safe_get(self.function_call_list, tool_name, "parameters", "required", default=False)
            # Refuse once this tool's call budget is used up, or when a tool with
            # required parameters is called with empty arguments.
            if self.function_calls_counter[tool_name] <= self.function_call_max_loop and (tool_args != "{}" or not has_args):
                function_call_max_tokens = self.truncate_limit - 1000
                if function_call_max_tokens <= 0:
                    function_call_max_tokens = int(self.truncate_limit / 2)
                if self.print_log:
                    print(f"\033[32m function_call {tool_name}, max token: {function_call_max_tokens} \033[0m")

                tool_results = get_tools_result_async(
                    tool_name, tool_args, function_call_max_tokens,
                    model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
                    kwargs.get('api_url', self.api_url.chat_url), use_plugins=False, model=model or self.engine,
                    add_message=self.add_to_conversation, convo_id=convo_id, language=language
                )
                if is_async:
                    async for chunk in tool_results:
                        yield chunk
                else:
                    for chunk in async_generator_to_sync(tool_results):
                        yield chunk
            else:
                tool_response = f"无法找到相关信息,停止使用工具 {tool_name}"

            yield tool_response

        # Default: a single call with the raw streamed arguments. This is always
        # defined now; before, it was only assigned when JSON parsing raised.
        tool_calls = [{
            'function_name': function_call_name,
            'parameter': function_full_response,
            'function_call_id': streamed_call_id or function_call_id
        }]
        try:
            if function_full_response:
                # May yield several calls from the accumulated argument stream.
                function_parameter = parse_continuous_json(function_full_response, function_call_name)
        except Exception as e:
            print(f"解析JSON失败: {e}")

        if isinstance(function_parameter, list) and function_parameter:
            tool_calls = function_parameter

        all_responses = []
        for tool_info in tool_calls:
            tool_name = tool_info['function_name']
            parameter = tool_info['parameter']
            tool_args = parameter if isinstance(parameter, str) else json.dumps(parameter)
            tool_id = tool_info.get('function_call_id', tool_name + "_tool_call")

            tool_response = ""
            if is_async:
                async for chunk in process_single_tool_call(tool_name, tool_args, tool_id):
                    if isinstance(chunk, str) and "function_response:" in chunk:
                        tool_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk
            else:
                for chunk in async_generator_to_sync(process_single_tool_call(tool_name, tool_args, tool_id)):
                    if isinstance(chunk, str) and "function_response:" in chunk:
                        tool_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk
            all_responses.append(f"[{tool_name}({tool_args}) Result]:\n\n{tool_response}")

        function_response = "\n\n".join(all_responses)

        # The first call's name, arguments and id are recorded in the history.
        first_call = tool_calls[0]
        function_call_name = first_call['function_name']
        if not function_full_response:
            # Only synthesise arguments when the model streamed none. Previously,
            # operator precedence replaced a non-empty raw argument string with
            # the tool_calls list whenever the parameter was a string.
            if isinstance(first_call['parameter'], str):
                function_full_response = tool_calls
            else:
                function_full_response = json.dumps(tool_calls)
        function_call_id = first_call.get('function_call_id') or function_call_name + "_tool_call"

        response_role = "tool"

        # Send the tool results back to the model and stream its answer.
        follow_up = dict(
            convo_id=convo_id,
            function_name=function_call_name,
            total_tokens=total_tokens,
            model=model or self.engine,
            function_arguments=function_full_response,
            function_call_id=function_call_id,
            api_key=kwargs.get('api_key', self.api_key),
            api_url=kwargs.get('api_url', self.api_url.chat_url),
            plugins=kwargs.get("plugins", self.plugins),
            system_prompt=system_prompt,
        )
        if is_async:
            async for chunk in self.ask_stream_async(function_response, response_role, **follow_up):
                yield chunk
        else:
            for chunk in self.ask_stream(function_response, response_role, **follow_up):
                yield chunk
    else:
        self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
        self.function_calls_counter = {}

        # For short-history translation prompts, drop the last two messages again
        # so they do not accumulate.
        conversation = self.conversation[convo_id]
        if pass_history <= 2 and len(conversation) >= 2:
            prompt_content = conversation[-2].get("content") or ""
            first_part_text = safe_get(self.conversation, convo_id, -2, "content", 0, "text", default="") or ""
            markers = ("You are a translation engine", "你是一位精通简体中文的专业翻译")
            if any(marker in prompt_content or marker in first_part_text for marker in markers):
                conversation.pop(-1)
                conversation.pop(-1)
549
+
550
def ask_stream(
    self,
    prompt: list,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    function_call_id: str = "",
    language: str = "English",
    system_prompt: str = None,
    **kwargs,
):
    """
    Ask a question; return a synchronous generator of response chunks.

    The conversation is updated, the request body is built and the history is
    truncated immediately. The HTTP request itself is made while the returned
    generator is consumed.

    Errors raised while streaming are handled inside the generator:

    * ``httpx.RemoteProtocolError`` is retried, up to 3 attempts in total, then
      re-raised as ``Exception``.
    * Connection errors and read timeouts print a message and end the stream.
    * Anything else is printed and re-raised as ``Exception``.

    Before this change the generator was returned from inside the ``try``, so
    none of these handlers or retries could ever trigger.
    """
    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, function_call_id=function_call_id, pass_history=pass_history)

    post_body = self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running in this thread: run the coroutine directly.
        url, headers, json_post, engine_type = asyncio.run(post_body)
    else:
        # A loop is already running here, so asyncio.run() (and run_until_complete
        # on that loop) would fail. Run the coroutine on a fresh loop in a worker thread.
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            url, headers, json_post, engine_type = executor.submit(asyncio.run, post_body).result()

    self.truncate_conversation(convo_id=convo_id)

    if self.print_log:
        print("api_url", kwargs.get('api_url', self.api_url.chat_url), url)
        print("api_key", kwargs.get('api_key', self.api_key))

    async def process_async():
        async_generator = fetch_response_stream(
            self.aclient,
            url,
            headers,
            json_post,
            engine_type,
            model or self.engine,
        )
        async for chunk in self._process_stream_response(
            async_generator,
            convo_id=convo_id,
            function_name=function_name,
            total_tokens=total_tokens,
            function_arguments=function_arguments,
            function_call_id=function_call_id,
            model=model,
            language=language,
            system_prompt=system_prompt,
            pass_history=pass_history,
            is_async=True,
            **kwargs
        ):
            yield chunk

    def stream():
        last_error = None
        for _ in range(3):
            if self.print_log:
                replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
                print(json.dumps(replaced_text, indent=4, ensure_ascii=False))
            try:
                yield from async_generator_to_sync(process_async())
                return
            except ConnectionError:
                print("连接错误,请检查服务器状态或网络连接。")
                return
            except requests.exceptions.ReadTimeout:
                print("请求超时,请检查网络连接或增加超时时间。")
                return
            except httpx.RemoteProtocolError as e:
                last_error = e
                continue
            except Exception as e:
                print(f"发生了未预料的错误:{e}")
                message = str(e)
                if "Invalid URL" in message:
                    message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + message
                raise Exception(message) from e
        # Every attempt failed with a protocol error.
        raise Exception(f"{last_error}") from last_error

    return stream()
650
+
651
async def ask_stream_async(
    self,
    prompt: list,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    function_call_id: str = "",
    language: str = "English",
    system_prompt: str = None,
    **kwargs,
):
    """
    Ask a question and stream the response chunks asynchronously.

    ``httpx.RemoteProtocolError`` is retried, up to 3 attempts in total. If every
    attempt fails that way, the last error is now re-raised as ``Exception``;
    previously the generator ended silently, because the final-attempt check
    was unreachable. Any other error is printed with its traceback and re-raised
    as ``Exception``.
    """
    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)

    url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
    self.truncate_conversation(convo_id=convo_id)

    if self.print_log:
        print("api_url", kwargs.get('api_url', self.api_url.chat_url) == url)
        print("api_url", kwargs.get('api_url', self.api_url.chat_url))
        print("api_url", url)
        print("api_key", kwargs.get('api_key', self.api_key))

    last_error = None
    for _ in range(3):
        if self.print_log:
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            generator = fetch_response_stream(
                self.aclient,
                url,
                headers,
                json_post,
                engine_type,
                model or self.engine,
            )
            async for processed_chunk in self._process_stream_response(
                generator,
                convo_id=convo_id,
                function_name=function_name,
                total_tokens=total_tokens,
                function_arguments=function_arguments,
                function_call_id=function_call_id,
                model=model,
                language=language,
                system_prompt=system_prompt,
                pass_history=pass_history,
                is_async=True,
                **kwargs
            ):
                yield processed_chunk
            # Success: stop retrying.
            return
        except httpx.RemoteProtocolError as e:
            last_error = e
            continue
        except Exception as e:
            print(f"发生了未预料的错误:{e}")
            import traceback
            traceback.print_exc()
            message = str(e)
            if "Invalid URL" in message:
                message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + message
            raise Exception(message) from e

    # Every attempt failed with a protocol error: surface it instead of ending silently.
    raise Exception(f"{last_error}") from last_error
745
+
746
async def ask_async(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    **kwargs,
) -> str:
    """Ask a question and return the whole reply as a single string.

    Non-streaming counterpart of ``ask_stream_async``: the async generator it
    returns is driven to completion and every yielded chunk is concatenated.

    Args:
        prompt: Message content to send.
        role: Role attached to ``prompt``.
        convo_id: Conversation the message belongs to.
        model: Model name; an empty value falls back to ``self.engine``.
        pass_history: Forwarded unchanged to ``ask_stream_async``.
        **kwargs: Extra options forwarded unchanged to ``ask_stream_async``.

    Returns:
        The concatenation of all chunks yielded by ``ask_stream_async``.
    """
    effective_model = model if model else self.engine
    chunks = []
    async for chunk in self.ask_stream_async(
        prompt=prompt,
        role=role,
        convo_id=convo_id,
        pass_history=pass_history,
        model=effective_model,
        **kwargs,
    ):
        chunks.append(chunk)
    return "".join(chunks)
768
+
769
def rollback(self, n: int = 1, convo_id: str = "default") -> None:
    """Remove the last ``n`` messages from a conversation.

    Messages are popped one at a time from the end of
    ``self.conversation[convo_id]``. Nothing protects the leading system
    message, and popping from an already-empty list raises ``IndexError``.
    A non-positive ``n`` does nothing.
    """
    remaining = n
    while remaining > 0:
        self.conversation[convo_id].pop()
        remaining -= 1
775
+
776
def reset(self, convo_id: str = "default", system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally") -> None:
    """Restart a conversation so it holds only a system message.

    A truthy ``system_prompt`` replaces ``self.system_prompt``. That attribute
    is instance-wide, not per conversation. A falsy value keeps the current
    prompt. Both token counters for ``convo_id`` are set back to zero.
    """
    self.system_prompt = system_prompt if system_prompt else self.system_prompt
    opening_message = {"role": "system", "content": self.system_prompt}
    self.conversation[convo_id] = [opening_message]
    self.tokens_usage[convo_id] = 0
    self.current_tokens[convo_id] = 0
786
+
787
def save(self, file: str, *keys: str) -> None:
    """Save the Chatbot configuration to a JSON file.

    Args:
        file: Destination path. An existing file is overwritten.
        *keys: Attribute selection, forwarded to ``get_filtered_keys_from_object``.

    The payload is built and serialized to a string before the file is
    opened. An error, such as a KeyError or a value JSON cannot encode,
    therefore leaves any existing config file intact instead of truncated.
    """
    data = {
        key: self.__dict__[key]
        for key in get_filtered_keys_from_object(self, *keys)
    }
    # For compatibility, the proxy setting is also stored under "session".
    # When "proxy" was not selected, drop "session" rather than raising
    # KeyError. Otherwise it would hold the sync HTTP client object
    # (load() assigns an httpx.Client to self.session), which JSON cannot encode.
    if "proxy" in data:
        data["session"] = data["proxy"]
    else:
        data.pop("session", None)
    # The async HTTP client is not serializable. It may also be absent
    # when only a subset of keys was requested.
    data.pop("aclient", None)
    payload = json.dumps(data, indent=2)
    with open(file, "w", encoding="utf-8") as f:
        f.write(payload)
805
+
806
def load(self, file: Path, *keys_: str) -> None:
    """Load the Chatbot configuration from a JSON file.

    Args:
        file: Path of a JSON config file (the format written by ``save``).
        *keys_: Attribute selection, forwarded to ``get_filtered_keys_from_object``.

    If a selected "session" or "proxy" entry is truthy, the sync and async HTTP
    clients are rebuilt with that proxy. The current cookies and headers are
    carried over. The client objects themselves are never restored from the
    file. Selected keys that are missing from the file are skipped instead of
    raising ``KeyError``.
    """
    with open(file, encoding="utf-8") as f:
        loaded_config = json.load(f)
    keys = get_filtered_keys_from_object(self, *keys_)

    # .get() so a file lacking either key does not raise KeyError.
    # "session" is the legacy name save() writes alongside "proxy".
    wants_proxy = (
        "session" in keys and loaded_config.get("session")
        or "proxy" in keys and loaded_config.get("proxy")
    )
    if wants_proxy:
        self.proxy = loaded_config.get("session") or loaded_config.get("proxy")
        old_session = self.session
        self.session = httpx.Client(
            follow_redirects=True,
            proxies=self.proxy,
            timeout=self.timeout,
            cookies=old_session.cookies,
            headers=old_session.headers,
        )
        self.aclient = httpx.AsyncClient(
            follow_redirects=True,
            proxies=self.proxy,
            timeout=self.timeout,
            cookies=self.session.cookies,
            headers=self.session.headers,
        )
        # Release the replaced sync client's connections.
        # NOTE(review): the previous self.aclient is not closed here because
        # AsyncClient.aclose() must be awaited; consider closing it from async code.
        old_session.close()

    self.__dict__.update({
        key: loaded_config[key]
        for key in keys
        if key not in ("session", "aclient") and key in loaded_config
    })
841
+
842
+ def _handle_response_error_common(self, response_text, json_post):
843
+ """通用的响应错误处理逻辑,适用于同步和异步场景"""
844
+ try:
845
+ # 检查内容审核失败
846
+ if "Content did not pass the moral check" in response_text:
847
+ return json_post, False, f"内容未通过道德检查:{response_text[:400]}"
848
+
849
+ # 处理函数调用相关错误
850
+ if "function calling" in response_text:
851
+ if "tools" in json_post:
852
+ del json_post["tools"]
853
+ if "tool_choice" in json_post:
854
+ del json_post["tool_choice"]
855
+ return json_post, True, None
856
+
857
+ # 处理请求格式错误
858
+ elif "invalid_request_error" in response_text:
859
+ for index, mess in enumerate(json_post["messages"]):
860
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
861
+ json_post["messages"][index] = {
862
+ "role": mess["role"],
863
+ "content": mess["content"][0]["text"]
864
+ }
865
+ return json_post, True, None
866
+
867
+ # 处理角色不允许错误
868
+ elif "'function' is not an allowed role" in response_text:
869
+ if json_post["messages"][-1]["role"] == "tool":
870
+ mess = json_post["messages"][-1]
871
+ json_post["messages"][-1] = {
872
+ "role": "assistant",
873
+ "name": mess["name"],
874
+ "content": mess["content"]
875
+ }
876
+ return json_post, True, None
877
+
878
+ # 处理服务器繁忙错误
879
+ elif "Sorry, server is busy" in response_text:
880
+ for index, mess in enumerate(json_post["messages"]):
881
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
882
+ json_post["messages"][index] = {
883
+ "role": mess["role"],
884
+ "content": mess["content"][0]["text"]
885
+ }
886
+ return json_post, True, None
887
+
888
+ # 处理token超限错误
889
+ elif "is not possible because the prompts occupy" in response_text:
890
+ max_tokens = re.findall(r"only\s(\d+)\stokens", response_text)
891
+ if max_tokens:
892
+ json_post["max_tokens"] = int(max_tokens[0])
893
+ return json_post, True, None
894
+
895
+ # 默认移除工具相关设置
896
+ else:
897
+ if "tools" in json_post:
898
+ del json_post["tools"]
899
+ if "tool_choice" in json_post:
900
+ del json_post["tool_choice"]
901
+ return json_post, True, None
902
+
903
+ except Exception as e:
904
+ print(f"处理响应错误时出现异常: {e}")
905
+ return json_post, False, str(e)
906
+
907
+ def _handle_response_error_sync(self, response, json_post):
908
+ """处理API响应错误并相应地修改请求体(同步版本)"""
909
+ response_text = response.text
910
+
911
+ # 处理空响应
912
+ if response.status_code == 200 and response_text == "":
913
+ for index, mess in enumerate(json_post["messages"]):
914
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
915
+ json_post["messages"][index] = {
916
+ "role": mess["role"],
917
+ "content": mess["content"][0]["text"]
918
+ }
919
+ return json_post, True
920
+
921
+ json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
922
+
923
+ if error_msg:
924
+ raise Exception(f"{response.status_code} {response.reason} {error_msg}")
925
+
926
+ return json_post, should_retry
927
+
928
+ async def _handle_response_error(self, response, json_post):
929
+ """处理API响应错误并相应地修改请求体(异步版本)"""
930
+ await response.aread()
931
+ response_text = response.text
932
+
933
+ json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
934
+
935
+ if error_msg:
936
+ raise Exception(f"{response.status_code} {response.reason_phrase} {error_msg}")
937
+
938
+ return json_post, should_retry