beswarm 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +941 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. beswarm/tools/worker.py +3 -1
  72. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/METADATA +1 -1
  73. beswarm-0.1.14.dist-info/RECORD +131 -0
  74. beswarm-0.1.12.dist-info/RECORD +0 -61
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/WHEEL +0 -0
  76. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,941 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import copy
5
+ import httpx
6
+ import asyncio
7
+ import requests
8
+ from typing import Set
9
+ from typing import Union, Optional, Callable, List, Dict, Any
10
+ from pathlib import Path
11
+
12
+
13
+ from .base import BaseLLM
14
+ from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
+ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml
16
+ from ..core.request import prepare_request_payload
17
+ from ..core.response import fetch_response_stream
18
+
19
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
    """
    Return a (possibly filtered) set of ``obj``'s instance attribute names.

    :param obj: Object whose ``__dict__`` keys are inspected.
    :param keys: Optional selector.
        - No keys: every attribute name is returned.
        - First key is ``"not"``: all names except the remaining keys are returned.
        - Otherwise: exactly the given names are returned; each must exist on ``obj``.
    :return: Set of attribute names.
    :raises ValueError: If an explicitly requested name is not an attribute of ``obj``.
    """
    available = obj.__dict__.keys()
    if not keys:
        return set(available)

    first, *rest = keys
    if first == "not":
        # Exclusion mode: everything except the listed names.
        return {name for name in available if name not in rest}

    unknown = set(keys) - available
    if unknown:
        raise ValueError(f"Invalid keys: {unknown}")
    return {name for name in keys if name in available}
39
+
40
+ class chatgpt(BaseLLM):
41
+ """
42
+ Official ChatGPT API
43
+ """
44
+
45
def __init__(
    self,
    api_key: str = None,
    engine: str = os.environ.get("GPT_ENGINE") or "gpt-4o",
    api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions"),
    system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    proxy: str = None,
    timeout: float = 600,
    max_tokens: int = None,
    temperature: float = 0.5,
    top_p: float = 1.0,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    reply_count: int = 1,
    truncate_limit: int = None,
    use_plugins: bool = True,
    print_log: bool = False,
    tools: Optional[Union[list, str, Callable]] = [],
    function_call_max_loop: int = 3,
    cut_history_by_function_name: str = "",
    cache_messages: list = None,
) -> None:
    """
    Create a chat client (API keys: https://platform.openai.com/account/api-keys).

    The ``engine`` and ``api_url`` defaults are read from the ``GPT_ENGINE`` and
    ``API_URL`` environment variables once, when the module is imported.
    Every parameter up to ``print_log`` is forwarded to ``BaseLLM.__init__``.

    :param tools: Tool (callable or name) or list of tools to enable; handed to
        ``self._register_tools``. An empty list is never stored, so the ``[]``
        default is not shared between instances.
    :param function_call_max_loop: How many times each tool may be called before
        further calls of it are refused.
    :param cut_history_by_function_name: Tool name whose results trigger history
        truncation in ``add_to_conversation``.
    :param cache_messages: If non-empty, used as-is (not copied) as the initial
        "default" conversation instead of a lone system-prompt message.
    """
    super().__init__(
        api_key, engine, api_url, system_prompt, proxy, timeout, max_tokens,
        temperature, top_p, presence_penalty, frequency_penalty, reply_count,
        truncate_limit,
        use_plugins=use_plugins,
        print_log=print_log,
    )

    default_history = [{"role": "system", "content": self.system_prompt}]
    if cache_messages:
        default_history = cache_messages
    self.conversation: dict[str, list[dict]] = {"default": default_history}

    # Per-tool call counts for the current tool-calling chain.
    self.function_calls_counter = {}
    self.function_call_max_loop = function_call_max_loop
    self.cut_history_by_function_name = cut_history_by_function_name

    # Register and enable the requested tools.
    self._register_tools(tools)
88
+
89
+
90
def _register_tools(self, tools):
    """
    Load the plugin registry into this instance and enable the requested tools.

    :param tools: A tool (callable or name), a list of them, or a falsy value for
        none. Callables are looked up by ``__name__``, anything else by ``str()``.
    :raises ValueError: If a requested tool is not present in the plugin registry.
    """
    # Refresh from the registry so that tools registered after import are visible.
    registry_plugins, registry_function_calls, _ = update_tools_config()
    # Work on private copies. NOTE(review): update_tools_config() may hand back the
    # registry's shared dicts (its body is not visible here); copying keeps the
    # ``self.plugins[...] = True`` below from leaking into other instances.
    self.plugins = copy.deepcopy(registry_plugins)
    self.function_call_list = copy.deepcopy(registry_function_calls)

    if isinstance(tools, list):
        # Copy so that later changes to the caller's list do not affect us.
        self.tools = list(tools)
    else:
        self.tools = [tools] if tools else []

    for tool in self.tools:
        tool_name = tool.__name__ if callable(tool) else str(tool)
        if tool_name not in self.plugins:
            raise ValueError(f"Tool {tool_name} not found in plugins")
        self.plugins[tool_name] = True
109
+
110
def add_to_conversation(
    self,
    message: Union[str, list],
    role: str,
    convo_id: str = "default",
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    pass_history: int = 9999,
    function_call_id: str = "",
) -> None:
    """
    Append a message to a conversation, merge same-role neighbours, trim history.

    How the message is stored:
    - Plain messages (``function_name == ""``) are stored as ``{"role", "content"}``.
    - Tool results (``function_name`` set) with plugins enabled are stored as an
      assistant ``tool_calls`` entry followed by a ``tool_call_id`` reply.
    - Tool results with plugins disabled are stored as an assistant message
      holding ``convert_functions_to_xml(function_arguments)``, followed by a
      ``user`` message with the result.

    After appending, adjacent messages with the same role that both carry
    content are merged. The oldest non-system messages are then dropped until
    at most ``max(pass_history, 2)`` remain.

    :param message: Content to add. Empty or None content is not added; an
        error is printed instead.
    :param role: Role of the message (e.g. "user", "assistant", "tool").
    :param convo_id: Conversation to update; it is reset first if unknown.
    :param function_name: Name of the tool whose result ``message`` is.
    :param total_tokens: If non-zero, stored as the conversation's current token
        count and added to its cumulative usage.
    :param function_arguments: Arguments of the tool call being recorded.
    :param pass_history: Maximum number of messages to keep (at least 2).
    :param function_call_id: Id linking the recorded tool call to its result.
    """
    if convo_id not in self.conversation:
        self.reset(convo_id=convo_id)

    if not message:
        print('\033[31m')
        print("error: add_to_conversation message is None or empty")
        print("role", role, "function_name", function_name, "message", message)
        print('\033[0m')
    elif function_name == "":
        self.conversation[convo_id].append({"role": role, "content": message})
    else:
        if function_name == self.cut_history_by_function_name:
            # Drop history from the first earlier call of this tool onward.
            cut_at = next(
                (
                    index
                    for index, item in enumerate(self.conversation[convo_id])
                    if safe_get(item, "tool_calls", 0, "function", "name", default="")
                    == self.cut_history_by_function_name
                ),
                None,
            )
            if cut_at is not None:
                self.conversation[convo_id] = self.conversation[convo_id][:cut_at]

        tools_enabled = self.use_plugins != False and any(
            value != False for value in self.plugins.values()
        )
        if tools_enabled:
            # Native tool-calling format: the call, then the matching tool reply.
            self.conversation[convo_id].append({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": function_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": function_arguments,
                        },
                    }
                ],
            })
            self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
        else:
            # Text-only format: the call rendered as XML, the result as a user turn.
            self.conversation[convo_id].append({"role": "assistant", "content": convert_functions_to_xml(function_arguments)})
            self.conversation[convo_id].append({"role": "user", "content": message})

    messages = self.conversation[convo_id]

    def as_parts(content):
        """Return ``content`` as a list of content parts."""
        if isinstance(content, list):
            return content
        if isinstance(content, dict):
            return [content]
        return [{"type": "text", "text": content}]

    # Merge runs of same-role messages that both have content.
    index = 0
    while index < len(messages) - 1:
        current, following = messages[index], messages[index + 1]
        if (
            current["role"] == following["role"]
            and current.get("content")
            and following.get("content")
        ):
            if isinstance(current["content"], str) and isinstance(following["content"], str):
                current["content"] += following["content"]
            else:
                # Mixed shapes (str/dict/list) are merged as content-part lists.
                current["content"] = as_parts(current["content"]) + as_parts(following["content"])
            messages.pop(index + 1)
        else:
            # Same-role neighbours without content (e.g. an assistant
            # ``tool_calls`` entry) cannot be merged; advancing is required
            # for the loop to terminate.
            index += 1

    # Trim the oldest messages, keeping the system prompt at index 0. A user
    # turn is dropped together with its reply (and a tool result after it).
    keep = max(pass_history, 2)
    while len(messages) > keep:
        removed = messages.pop(1)
        if removed.get("role") == "user" and len(messages) > 1:
            reply = messages.pop(1)
            if reply.get("tool_calls") and len(messages) > 1:
                messages.pop(1)

    if total_tokens:
        self.current_tokens[convo_id] = total_tokens
        self.tokens_usage[convo_id] += total_tokens
213
+
214
def truncate_conversation(self, convo_id: str = "default") -> None:
    """
    Drop the oldest messages until the token estimate fits ``truncate_limit``.

    The system prompt at index 0 is always kept. Each removed message lowers
    ``current_tokens`` by a rough estimate of a quarter of its JSON length.
    """
    history = self.conversation[convo_id]
    while self.current_tokens[convo_id] > self.truncate_limit and len(history) > 1:
        dropped = history.pop(1)
        self.current_tokens[convo_id] -= len(json.dumps(dropped, ensure_ascii=False)) / 4
        print("Truncate message:", dropped)
230
+
231
async def get_post_body(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    **kwargs,
):
    """
    Build the HTTP request for a chat-completion call.

    :param prompt: Latest prompt; only used on its own when ``pass_history`` is falsy.
    :param role: Role of ``prompt`` in that single-turn case.
    :param convo_id: Conversation whose history is sent.
    :param model: Model override; ``self.engine`` is used when empty.
    :param pass_history: Falsy means send only the system prompt plus ``prompt``;
        otherwise a deep copy of the whole conversation is sent.
    :param kwargs: Optional ``api_url``, ``api_key``, ``temperature``,
        ``max_tokens`` and ``plugins`` overrides. A truthy ``plugins`` mapping
        also replaces ``self.plugins``.
    :return: ``(url, headers, json_post_body, engine_type)`` from
        ``prepare_request_payload``.
    """
    # Always send the current system prompt as the first message.
    self.conversation[convo_id][0] = {"role": "system", "content": self.system_prompt}

    provider = {
        "provider": "openai",
        "base_url": kwargs.get('api_url', self.api_url.chat_url),
        "api": kwargs.get('api_key', self.api_key),
        "model": [model or self.engine],
        "tools": bool(self.use_plugins),
        "image": True,
    }

    if pass_history:
        messages = copy.deepcopy(self.conversation[convo_id])
    else:
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": role, "content": prompt},
        ]
    request_data = {
        "model": model or self.engine,
        "messages": messages,
        "stream": True,
        "stream_options": {
            "include_usage": True
        },
        "temperature": kwargs.get("temperature", self.temperature),
    }

    max_tokens = kwargs.get("max_tokens", self.max_tokens)
    if max_tokens:
        request_data["max_tokens"] = max_tokens

    if kwargs.get("plugins", None):
        self.plugins = kwargs.get("plugins")

    # An explicit ``plugins=None`` means "no tools" instead of failing on ``.values()``.
    plugins_status = kwargs.get("plugins", self.plugins) or {}
    if self.use_plugins != False and any(value != False for value in plugins_status.values()):
        tools_request_body = [
            {"type": "function", "function": self.function_call_list[name]}
            for name, enabled in plugins_status.items()
            # Enabled plugins without a registered schema are skipped.
            if enabled and name in self.function_call_list
        ]
        if tools_request_body:
            request_data["tools"] = tools_request_body
            request_data["tool_choice"] = "auto"

    url, headers, json_post_body, engine_type = await prepare_request_payload(provider, request_data)
    return url, headers, json_post_body, engine_type
292
+
293
async def _process_stream_response(
    self,
    response_gen,
    convo_id="default",
    function_name="",
    total_tokens=0,
    function_arguments="",
    function_call_id="",
    model="",
    language="English",
    system_prompt=None,
    pass_history=9999,
    is_async=False,
    **kwargs
):
    """
    Consume a chat-completion stream, yield its text, then finish the turn.

    Text deltas are yielded as they arrive. Once the stream ends, the reply is
    checked for tool calls: native ``tool_calls`` deltas, or XML found in the
    text by ``parse_function_xml``.

    If there are tool calls and plugins are enabled, each tool is run and its
    chunks are yielded. The combined results are then sent back through a
    recursive ``ask_stream`` / ``ask_stream_async`` call, whose chunks are
    yielded too. Otherwise the reply is appended to the conversation and the
    per-tool call counters are reset.

    :param response_gen: Stream of raw lines (str/bytes) or decoded dicts; an
        async iterator when ``is_async`` is true, a sync iterator otherwise.
    :param convo_id: Conversation being continued.
    :param total_tokens: Token count to record; the stream's ``usage`` value is
        only used when this is falsy.
    :param function_call_id: Fallback tool-call id, used when the stream does
        not provide one.
    :param model: Model override; ``self.engine`` is used when empty.
    :param language: Forwarded to the tool runner.
    :param system_prompt: Forwarded to the recursive ask call.
    :param pass_history: History budget. ``<= 2`` also triggers the
        translation-prompt cleanup at the end.
    :param is_async: Whether ``response_gen`` and the recursion are async.
    :param kwargs: May carry ``api_key``, ``api_url`` and ``plugins`` overrides.
    :raises Exception: If a stream payload contains an ``"error"`` key.
    """
    response_role = None
    full_response = ""
    function_full_response = ""
    function_call_name = ""
    stream_call_id = ""
    need_function_call = False
    # Identity sentinel: model text that is literally "DONE" must not end the stream.
    done = object()

    def process_line(line):
        """Parse one stream line; return text to emit, ``done``, or None."""
        nonlocal response_role, full_response, function_full_response
        nonlocal function_call_name, need_function_call, total_tokens, stream_call_id

        if not line or (isinstance(line, str) and line.startswith(':')):
            # Empty keep-alive lines and SSE comments carry no payload.
            return None

        if isinstance(line, str) and line.startswith('data:'):
            # Remove the exact SSE prefix; str.lstrip("data: ") would strip a
            # character set rather than the prefix.
            line = line[len('data:'):].strip()
            if line == "[DONE]":
                return done
        elif isinstance(line, (dict, list)):
            # Already-decoded (non-streamed) payload.
            if isinstance(line, dict) and safe_get(line, "choices", 0, "message", "content"):
                full_response = line["choices"][0]["message"]["content"]
                return full_response
            return str(line)
        elif isinstance(line, str):
            # A bare JSON body without SSE framing.
            try:
                line = json.loads(line)
            except ValueError:
                print("json.loads error:", repr(line))
                return None
            if safe_get(line, "choices", 0, "message", "content"):
                full_response = line["choices"][0]["message"]["content"]
                return full_response
            return str(line)

        resp = json.loads(line) if isinstance(line, str) else line
        if "error" in resp:
            raise Exception(f"{resp}")

        total_tokens = total_tokens or safe_get(resp, "usage", "total_tokens", default=0)
        delta = safe_get(resp, "choices", 0, "delta")
        if not delta:
            return None

        response_role = response_role or safe_get(delta, "role")
        if safe_get(delta, "content"):
            need_function_call = False
            content = delta["content"]
            full_response += content
            return content

        if safe_get(delta, "tool_calls"):
            need_function_call = True
            function_call_name = function_call_name or safe_get(delta, "tool_calls", 0, "function", "name")
            function_full_response += safe_get(delta, "tool_calls", 0, "function", "arguments", default="") or ""
            stream_call_id = stream_call_id or safe_get(delta, "tool_calls", 0, "id")
        return None

    async def iterate_async():
        async for raw in response_gen:
            result = process_line(raw.strip() if isinstance(raw, str) else raw)
            if result is done:
                break
            if result:
                yield result

    def iterate_sync():
        for raw in response_gen:
            result = process_line(raw.decode("utf-8") if hasattr(raw, "decode") else raw)
            if result is done:
                break
            if result:
                yield result

    if is_async:
        async for chunk in iterate_async():
            yield chunk
    else:
        for chunk in iterate_sync():
            yield chunk

    # Prefer the id the model sent in this stream over the one passed in, which
    # belongs to the previous tool call in a recursive chain.
    if stream_call_id:
        function_call_id = stream_call_id

    if self.print_log:
        print("\n\rtotal_tokens", total_tokens)

    if response_role is None:
        response_role = "assistant"

    function_parameter = parse_function_xml(full_response)
    if function_parameter:
        need_function_call = True

    if need_function_call and self.use_plugins != False:
        if self.print_log:
            print("function_parameter", function_parameter)
            print("function_full_response", function_full_response)

        api_key = kwargs.get('api_key', self.api_key)
        api_url = kwargs.get('api_url', self.api_url.chat_url)

        async def run_tool(tool_name, tool_args):
            """Run one tool call and yield its chunks.

            The chunk containing "function_response:" carries the result that
            is sent back to the model.
            """
            self.function_calls_counter[tool_name] = self.function_calls_counter.get(tool_name, 0) + 1

            requires_args = safe_get(self.function_call_list, tool_name, "parameters", "required", default=False)
            over_budget = self.function_calls_counter[tool_name] > self.function_call_max_loop
            if over_budget or (tool_args == "{}" and requires_args):
                # Report the refusal as this tool's result so the model sees it.
                yield "function_response:" + f"无法找到相关信息,停止使用工具 {tool_name}"
                return

            function_call_max_tokens = self.truncate_limit - 1000
            if function_call_max_tokens <= 0:
                function_call_max_tokens = int(self.truncate_limit / 2)
            if self.print_log:
                print(f"\033[32m function_call {tool_name}, max token: {function_call_max_tokens} \033[0m")

            async for chunk in get_tools_result_async(
                tool_name, tool_args, function_call_max_tokens,
                model or self.engine, chatgpt, api_key,
                api_url, use_plugins=False, model=model or self.engine,
                add_message=self.add_to_conversation, convo_id=convo_id, language=language
            ):
                yield chunk

        # Default: the streamed native call, kept as-is if it cannot be parsed.
        tool_calls = [{
            'function_name': function_call_name,
            'parameter': function_full_response,
            'function_call_id': function_call_id,
        }]
        try:
            if function_full_response:
                function_parameter = parse_continuous_json(function_full_response, function_call_name)
        except Exception as e:
            print(f"解析JSON失败: {e}")

        if isinstance(function_parameter, list) and function_parameter:
            tool_calls = function_parameter

        all_responses = []
        for tool_info in tool_calls:
            tool_name = tool_info['function_name']
            parameter = tool_info['parameter']
            tool_args = parameter if isinstance(parameter, str) else json.dumps(parameter)

            tool_response = ""
            if is_async:
                async for chunk in run_tool(tool_name, tool_args):
                    if isinstance(chunk, str) and "function_response:" in chunk:
                        tool_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk
            else:
                for chunk in async_generator_to_sync(run_tool(tool_name, tool_args)):
                    if isinstance(chunk, str) and "function_response:" in chunk:
                        tool_response = chunk.replace("function_response:", "")
                    else:
                        yield chunk
            all_responses.append(f"[{tool_name}({tool_args}) Result]:\n\n{tool_response}")

        function_response = "\n\n".join(all_responses)

        # The history entry records the first call's name, arguments and id.
        first_call = tool_calls[0]
        function_call_name = first_call['function_name']
        if not function_full_response:
            first_parameter = first_call['parameter']
            function_full_response = tool_calls if isinstance(first_parameter, str) else json.dumps(tool_calls)
        function_call_id = first_call.get('function_call_id', function_call_name + "_tool_call")

        response_role = "tool"

        if is_async:
            async for chunk in self.ask_stream_async(
                function_response, response_role, convo_id=convo_id,
                function_name=function_call_name, total_tokens=total_tokens,
                model=model or self.engine, function_arguments=function_full_response,
                function_call_id=function_call_id, api_key=api_key,
                api_url=api_url,
                plugins=kwargs.get("plugins", self.plugins), system_prompt=system_prompt
            ):
                yield chunk
        else:
            for chunk in self.ask_stream(
                function_response, response_role, convo_id=convo_id,
                function_name=function_call_name, total_tokens=total_tokens,
                model=model or self.engine, function_arguments=function_full_response,
                function_call_id=function_call_id, api_key=api_key,
                api_url=api_url,
                plugins=kwargs.get("plugins", self.plugins), system_prompt=system_prompt
            ):
                yield chunk
    else:
        self.add_to_conversation(full_response, response_role, convo_id=convo_id, total_tokens=total_tokens, pass_history=pass_history)
        self.function_calls_counter = {}

    # For one-shot translation prompts, drop the last exchange from history.
    history = self.conversation[convo_id]
    if pass_history <= 2 and len(history) >= 2:
        markers = ("You are a translation engine", "你是一位精通简体中文的专业翻译")
        previous_content = history[-2].get("content") or ""
        first_text = safe_get(self.conversation, convo_id, -2, "content", 0, "text", default="") or ""
        if any(marker in previous_content or marker in first_text for marker in markers):
            history.pop(-1)
            history.pop(-1)
552
+
553
def ask_stream(
    self,
    prompt: list,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    function_call_id: str = "",
    language: str = "English",
    system_prompt: str = None,
    **kwargs,
):
    """
    Ask a question and return a synchronous iterator over the streamed reply.

    The prompt is added to the conversation and the request is built eagerly.
    The HTTP stream itself runs lazily, while the caller iterates.

    :return: Iterator of reply chunks. It is empty if creating the stream hit a
        connection error or read timeout.
    :raises Exception: On unexpected errors (with an "Invalid URL" hint when
        relevant), or when every retry failed with a protocol error.
    """
    import concurrent.futures

    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, function_call_id=function_call_id, pass_history=pass_history)

    def build_request():
        return asyncio.run(self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: run the coroutine directly.
        url, headers, json_post, engine_type = build_request()
    else:
        # Inside a running loop, asyncio.run() refuses and run_until_complete()
        # fails as well, so use a private loop on a worker thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            url, headers, json_post, engine_type = pool.submit(build_request).result()

    self.truncate_conversation(convo_id=convo_id)

    if self.print_log:
        print("api_url", kwargs.get('api_url', self.api_url.chat_url), url)
        print("api_key", kwargs.get('api_key', self.api_key))

    async def stream_chunks():
        response = fetch_response_stream(
            self.aclient,
            url,
            headers,
            json_post,
            engine_type,
            model or self.engine,
        )
        async for chunk in self._process_stream_response(
            response,
            convo_id=convo_id,
            function_name=function_name,
            total_tokens=total_tokens,
            function_arguments=function_arguments,
            function_call_id=function_call_id,
            model=model,
            language=language,
            system_prompt=system_prompt,
            pass_history=pass_history,
            is_async=True,
            **kwargs
        ):
            yield chunk

    last_error = None
    for _attempt in range(3):
        if self.print_log:
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            return async_generator_to_sync(stream_chunks())
        except ConnectionError:
            print("连接错误,请检查服务器状态或网络连接。")
            return iter(())
        except requests.exceptions.ReadTimeout:
            print("请求超时,请检查网络连接或增加超时时间。")
            return iter(())
        except httpx.RemoteProtocolError as e:
            last_error = e
            continue
        except Exception as e:
            print(f"发生了未预料的错误:{e}")
            message = str(e)
            if "Invalid URL" in message:
                message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + message
            raise Exception(message) from e
    # All retries failed with a protocol error.
    raise Exception(f"{last_error}") from last_error
653
+
654
async def ask_stream_async(
    self,
    prompt: list,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    function_name: str = "",
    total_tokens: int = 0,
    function_arguments: str = "",
    function_call_id: str = "",
    language: str = "English",
    system_prompt: str = None,
    **kwargs,
):
    """
    Ask a question and asynchronously yield the streamed reply chunks.

    A request that fails with ``httpx.RemoteProtocolError`` is retried up to
    three times in total.

    :raises Exception: On unexpected errors (with an "Invalid URL" hint when
        relevant), or when every attempt failed with a protocol error.
    """
    self.system_prompt = system_prompt or self.system_prompt
    if convo_id not in self.conversation or pass_history <= 2:
        self.reset(convo_id=convo_id, system_prompt=system_prompt)
    self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)

    url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
    self.truncate_conversation(convo_id=convo_id)

    if self.print_log:
        print("api_url", kwargs.get('api_url', self.api_url.chat_url) == url)
        print("api_url", kwargs.get('api_url', self.api_url.chat_url))
        print("api_url", url)
        print("api_key", kwargs.get('api_key', self.api_key))

    last_error = None
    for _attempt in range(3):
        if self.print_log:
            replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(json_post)))
            print(json.dumps(replaced_text, indent=4, ensure_ascii=False))

        try:
            generator = fetch_response_stream(
                self.aclient,
                url,
                headers,
                json_post,
                engine_type,
                model or self.engine,
            )
            async for processed_chunk in self._process_stream_response(
                generator,
                convo_id=convo_id,
                function_name=function_name,
                total_tokens=total_tokens,
                function_arguments=function_arguments,
                function_call_id=function_call_id,
                model=model,
                language=language,
                system_prompt=system_prompt,
                pass_history=pass_history,
                is_async=True,
                **kwargs
            ):
                yield processed_chunk
            return
        except httpx.RemoteProtocolError as e:
            # NOTE(review): this retries even after chunks were already yielded,
            # so a mid-stream disconnect can repeat part of the reply.
            last_error = e
            continue
        except Exception as e:
            print(f"发生了未预料的错误:{e}")
            import traceback
            traceback.print_exc()
            message = str(e)
            if "Invalid URL" in message:
                message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + message
            raise Exception(message) from e
    # All retries failed with a protocol error: surface it instead of ending silently.
    raise Exception(f"{last_error}") from last_error
748
+
749
+ async def ask_async(
750
+ self,
751
+ prompt: str,
752
+ role: str = "user",
753
+ convo_id: str = "default",
754
+ model: str = "",
755
+ pass_history: int = 9999,
756
+ **kwargs,
757
+ ) -> str:
758
+ """
759
+ Non-streaming ask
760
+ """
761
+ response = self.ask_stream_async(
762
+ prompt=prompt,
763
+ role=role,
764
+ convo_id=convo_id,
765
+ pass_history=pass_history,
766
+ model=model or self.engine,
767
+ **kwargs,
768
+ )
769
+ full_response: str = "".join([r async for r in response])
770
+ return full_response
771
+
772
+ def rollback(self, n: int = 1, convo_id: str = "default") -> None:
773
+ """
774
+ Rollback the conversation
775
+ """
776
+ for _ in range(n):
777
+ self.conversation[convo_id].pop()
778
+
779
+ def reset(self, convo_id: str = "default", system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally") -> None:
780
+ """
781
+ Reset the conversation
782
+ """
783
+ self.system_prompt = system_prompt or self.system_prompt
784
+ self.conversation[convo_id] = [
785
+ {"role": "system", "content": self.system_prompt},
786
+ ]
787
+ self.tokens_usage[convo_id] = 0
788
+ self.current_tokens[convo_id] = 0
789
+
790
+ def save(self, file: str, *keys: str) -> None:
791
+ """
792
+ Save the Chatbot configuration to a JSON file
793
+ """
794
+ with open(file, "w", encoding="utf-8") as f:
795
+ data = {
796
+ key: self.__dict__[key]
797
+ for key in get_filtered_keys_from_object(self, *keys)
798
+ }
799
+ # saves session.proxies dict as session
800
+ # leave this here for compatibility
801
+ data["session"] = data["proxy"]
802
+ del data["aclient"]
803
+ json.dump(
804
+ data,
805
+ f,
806
+ indent=2,
807
+ )
808
+
809
+ def load(self, file: Path, *keys_: str) -> None:
810
+ """
811
+ Load the Chatbot configuration from a JSON file
812
+ """
813
+ with open(file, encoding="utf-8") as f:
814
+ # load json, if session is in keys, load proxies
815
+ loaded_config = json.load(f)
816
+ keys = get_filtered_keys_from_object(self, *keys_)
817
+
818
+ if (
819
+ "session" in keys
820
+ and loaded_config["session"]
821
+ or "proxy" in keys
822
+ and loaded_config["proxy"]
823
+ ):
824
+ self.proxy = loaded_config.get("session", loaded_config["proxy"])
825
+ self.session = httpx.Client(
826
+ follow_redirects=True,
827
+ proxies=self.proxy,
828
+ timeout=self.timeout,
829
+ cookies=self.session.cookies,
830
+ headers=self.session.headers,
831
+ )
832
+ self.aclient = httpx.AsyncClient(
833
+ follow_redirects=True,
834
+ proxies=self.proxy,
835
+ timeout=self.timeout,
836
+ cookies=self.session.cookies,
837
+ headers=self.session.headers,
838
+ )
839
+ if "session" in keys:
840
+ keys.remove("session")
841
+ if "aclient" in keys:
842
+ keys.remove("aclient")
843
+ self.__dict__.update({key: loaded_config[key] for key in keys})
844
+
845
+ def _handle_response_error_common(self, response_text, json_post):
846
+ """通用的响应错误处理逻辑,适用于同步和异步场景"""
847
+ try:
848
+ # 检查内容审核失败
849
+ if "Content did not pass the moral check" in response_text:
850
+ return json_post, False, f"内容未通过道德检查:{response_text[:400]}"
851
+
852
+ # 处理函数调用相关错误
853
+ if "function calling" in response_text:
854
+ if "tools" in json_post:
855
+ del json_post["tools"]
856
+ if "tool_choice" in json_post:
857
+ del json_post["tool_choice"]
858
+ return json_post, True, None
859
+
860
+ # 处理请求格式错误
861
+ elif "invalid_request_error" in response_text:
862
+ for index, mess in enumerate(json_post["messages"]):
863
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
864
+ json_post["messages"][index] = {
865
+ "role": mess["role"],
866
+ "content": mess["content"][0]["text"]
867
+ }
868
+ return json_post, True, None
869
+
870
+ # 处理角色不允许错误
871
+ elif "'function' is not an allowed role" in response_text:
872
+ if json_post["messages"][-1]["role"] == "tool":
873
+ mess = json_post["messages"][-1]
874
+ json_post["messages"][-1] = {
875
+ "role": "assistant",
876
+ "name": mess["name"],
877
+ "content": mess["content"]
878
+ }
879
+ return json_post, True, None
880
+
881
+ # 处理服务器繁忙错误
882
+ elif "Sorry, server is busy" in response_text:
883
+ for index, mess in enumerate(json_post["messages"]):
884
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
885
+ json_post["messages"][index] = {
886
+ "role": mess["role"],
887
+ "content": mess["content"][0]["text"]
888
+ }
889
+ return json_post, True, None
890
+
891
+ # 处理token超限错误
892
+ elif "is not possible because the prompts occupy" in response_text:
893
+ max_tokens = re.findall(r"only\s(\d+)\stokens", response_text)
894
+ if max_tokens:
895
+ json_post["max_tokens"] = int(max_tokens[0])
896
+ return json_post, True, None
897
+
898
+ # 默认移除工具相关设置
899
+ else:
900
+ if "tools" in json_post:
901
+ del json_post["tools"]
902
+ if "tool_choice" in json_post:
903
+ del json_post["tool_choice"]
904
+ return json_post, True, None
905
+
906
+ except Exception as e:
907
+ print(f"处理响应错误时出现异常: {e}")
908
+ return json_post, False, str(e)
909
+
910
+ def _handle_response_error_sync(self, response, json_post):
911
+ """处理API响应错误并相应地修改请求体(同步版本)"""
912
+ response_text = response.text
913
+
914
+ # 处理空响应
915
+ if response.status_code == 200 and response_text == "":
916
+ for index, mess in enumerate(json_post["messages"]):
917
+ if type(mess["content"]) == list and "text" in mess["content"][0]:
918
+ json_post["messages"][index] = {
919
+ "role": mess["role"],
920
+ "content": mess["content"][0]["text"]
921
+ }
922
+ return json_post, True
923
+
924
+ json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
925
+
926
+ if error_msg:
927
+ raise Exception(f"{response.status_code} {response.reason} {error_msg}")
928
+
929
+ return json_post, should_retry
930
+
931
+ async def _handle_response_error(self, response, json_post):
932
+ """处理API响应错误并相应地修改请求体(异步版本)"""
933
+ await response.aread()
934
+ response_text = response.text
935
+
936
+ json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
937
+
938
+ if error_msg:
939
+ raise Exception(f"{response.status_code} {response.reason_phrase} {error_msg}")
940
+
941
+ return json_post, should_retry