aient 1.0.38__tar.gz → 1.0.39__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {aient-1.0.38/src/aient.egg-info → aient-1.0.39}/PKG-INFO +1 -1
  2. {aient-1.0.38 → aient-1.0.39}/setup.py +1 -1
  3. {aient-1.0.38 → aient-1.0.39}/src/aient/models/chatgpt.py +28 -15
  4. aient-1.0.39/src/aient/utils/scripts.py +551 -0
  5. {aient-1.0.38 → aient-1.0.39/src/aient.egg-info}/PKG-INFO +1 -1
  6. aient-1.0.38/src/aient/utils/scripts.py +0 -235
  7. {aient-1.0.38 → aient-1.0.39}/LICENSE +0 -0
  8. {aient-1.0.38 → aient-1.0.39}/MANIFEST.in +0 -0
  9. {aient-1.0.38 → aient-1.0.39}/README.md +0 -0
  10. {aient-1.0.38 → aient-1.0.39}/setup.cfg +0 -0
  11. {aient-1.0.38 → aient-1.0.39}/src/aient/__init__.py +0 -0
  12. {aient-1.0.38 → aient-1.0.39}/src/aient/core/.git +0 -0
  13. {aient-1.0.38 → aient-1.0.39}/src/aient/core/__init__.py +0 -0
  14. {aient-1.0.38 → aient-1.0.39}/src/aient/core/log_config.py +0 -0
  15. {aient-1.0.38 → aient-1.0.39}/src/aient/core/models.py +0 -0
  16. {aient-1.0.38 → aient-1.0.39}/src/aient/core/request.py +0 -0
  17. {aient-1.0.38 → aient-1.0.39}/src/aient/core/response.py +0 -0
  18. {aient-1.0.38 → aient-1.0.39}/src/aient/core/test/test_base_api.py +0 -0
  19. {aient-1.0.38 → aient-1.0.39}/src/aient/core/test/test_image.py +0 -0
  20. {aient-1.0.38 → aient-1.0.39}/src/aient/core/test/test_payload.py +0 -0
  21. {aient-1.0.38 → aient-1.0.39}/src/aient/core/utils.py +0 -0
  22. {aient-1.0.38 → aient-1.0.39}/src/aient/models/__init__.py +0 -0
  23. {aient-1.0.38 → aient-1.0.39}/src/aient/models/audio.py +0 -0
  24. {aient-1.0.38 → aient-1.0.39}/src/aient/models/base.py +0 -0
  25. {aient-1.0.38 → aient-1.0.39}/src/aient/models/claude.py +0 -0
  26. {aient-1.0.38 → aient-1.0.39}/src/aient/models/duckduckgo.py +0 -0
  27. {aient-1.0.38 → aient-1.0.39}/src/aient/models/gemini.py +0 -0
  28. {aient-1.0.38 → aient-1.0.39}/src/aient/models/groq.py +0 -0
  29. {aient-1.0.38 → aient-1.0.39}/src/aient/models/vertex.py +0 -0
  30. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/__init__.py +0 -0
  31. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/arXiv.py +0 -0
  32. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/config.py +0 -0
  33. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/image.py +0 -0
  34. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/registry.py +0 -0
  35. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/run_python.py +0 -0
  36. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/today.py +0 -0
  37. {aient-1.0.38 → aient-1.0.39}/src/aient/plugins/websearch.py +0 -0
  38. {aient-1.0.38 → aient-1.0.39}/src/aient/utils/__init__.py +0 -0
  39. {aient-1.0.38 → aient-1.0.39}/src/aient/utils/prompt.py +0 -0
  40. {aient-1.0.38 → aient-1.0.39}/src/aient.egg-info/SOURCES.txt +0 -0
  41. {aient-1.0.38 → aient-1.0.39}/src/aient.egg-info/dependency_links.txt +0 -0
  42. {aient-1.0.38 → aient-1.0.39}/src/aient.egg-info/requires.txt +0 -0
  43. {aient-1.0.38 → aient-1.0.39}/src/aient.egg-info/top_level.txt +0 -0
  44. {aient-1.0.38 → aient-1.0.39}/test/test.py +0 -0
  45. {aient-1.0.38 → aient-1.0.39}/test/test_API.py +0 -0
  46. {aient-1.0.38 → aient-1.0.39}/test/test_Deepbricks.py +0 -0
  47. {aient-1.0.38 → aient-1.0.39}/test/test_Web_crawler.py +0 -0
  48. {aient-1.0.38 → aient-1.0.39}/test/test_aiwaves.py +0 -0
  49. {aient-1.0.38 → aient-1.0.39}/test/test_aiwaves_arxiv.py +0 -0
  50. {aient-1.0.38 → aient-1.0.39}/test/test_ask_gemini.py +0 -0
  51. {aient-1.0.38 → aient-1.0.39}/test/test_class.py +0 -0
  52. {aient-1.0.38 → aient-1.0.39}/test/test_claude.py +0 -0
  53. {aient-1.0.38 → aient-1.0.39}/test/test_claude_zh_char.py +0 -0
  54. {aient-1.0.38 → aient-1.0.39}/test/test_ddg_search.py +0 -0
  55. {aient-1.0.38 → aient-1.0.39}/test/test_download_pdf.py +0 -0
  56. {aient-1.0.38 → aient-1.0.39}/test/test_gemini.py +0 -0
  57. {aient-1.0.38 → aient-1.0.39}/test/test_get_token_dict.py +0 -0
  58. {aient-1.0.38 → aient-1.0.39}/test/test_google_search.py +0 -0
  59. {aient-1.0.38 → aient-1.0.39}/test/test_jieba.py +0 -0
  60. {aient-1.0.38 → aient-1.0.39}/test/test_json.py +0 -0
  61. {aient-1.0.38 → aient-1.0.39}/test/test_langchain_search_old.py +0 -0
  62. {aient-1.0.38 → aient-1.0.39}/test/test_logging.py +0 -0
  63. {aient-1.0.38 → aient-1.0.39}/test/test_ollama.py +0 -0
  64. {aient-1.0.38 → aient-1.0.39}/test/test_plugin.py +0 -0
  65. {aient-1.0.38 → aient-1.0.39}/test/test_py_run.py +0 -0
  66. {aient-1.0.38 → aient-1.0.39}/test/test_requests.py +0 -0
  67. {aient-1.0.38 → aient-1.0.39}/test/test_search.py +0 -0
  68. {aient-1.0.38 → aient-1.0.39}/test/test_tikitoken.py +0 -0
  69. {aient-1.0.38 → aient-1.0.39}/test/test_token.py +0 -0
  70. {aient-1.0.38 → aient-1.0.39}/test/test_url.py +0 -0
  71. {aient-1.0.38 → aient-1.0.39}/test/test_whisper.py +0 -0
  72. {aient-1.0.38 → aient-1.0.39}/test/test_wildcard.py +0 -0
  73. {aient-1.0.38 → aient-1.0.39}/test/test_yjh.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.0.38
3
+ Version: 1.0.39
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Description-Content-Type: text/markdown
6
6
  License-File: LICENSE
@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
4
4
 
5
5
  setup(
6
6
  name="aient",
7
- version="1.0.38",
7
+ version="1.0.39",
8
8
  description="Aient: The Awakening of Agent.",
9
9
  long_description=Path.open(Path("README.md"), encoding="utf-8").read(),
10
10
  long_description_content_type="text/markdown",
@@ -12,7 +12,7 @@ from pathlib import Path
12
12
 
13
13
  from .base import BaseLLM
14
14
  from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
15
- from ..utils.scripts import check_json, safe_get, async_generator_to_sync
15
+ from ..utils.scripts import check_json, safe_get, async_generator_to_sync, parse_function_xml
16
16
  from ..core.request import prepare_request_payload
17
17
  from ..core.response import fetch_response_stream
18
18
 
@@ -131,20 +131,26 @@ class chatgpt(BaseLLM):
131
131
  matching_message = next(filter(lambda x: safe_get(x, "tool_calls", 0, "function", "name", default="") == 'get_next_pdf', self.conversation[convo_id]), None)
132
132
  if matching_message is not None:
133
133
  self.conversation[convo_id] = self.conversation[convo_id][:self.conversation[convo_id].index(matching_message)]
134
- self.conversation[convo_id].append({
135
- "role": "assistant",
136
- "tool_calls": [
137
- {
138
- "id": function_call_id,
139
- "type": "function",
140
- "function": {
141
- "name": function_name,
142
- "arguments": function_arguments,
143
- },
144
- }
145
- ],
146
- })
147
- self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
134
+
135
+ if not (all(value == False for value in self.plugins.values()) or self.use_plugins == False):
136
+ self.conversation[convo_id].append({
137
+ "role": "assistant",
138
+ "tool_calls": [
139
+ {
140
+ "id": function_call_id,
141
+ "type": "function",
142
+ "function": {
143
+ "name": function_name,
144
+ "arguments": function_arguments,
145
+ },
146
+ }
147
+ ],
148
+ })
149
+ self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
150
+ else:
151
+ self.conversation[convo_id].append({"role": "assistant", "content": "I will use tool: " + function_name + ". tool arguments:" + function_arguments + ". I will get the tool call result in the next user response."})
152
+ self.conversation[convo_id].append({"role": "user", "content": f"[{function_name} Result]\n\n" + message})
153
+
148
154
  else:
149
155
  print('\033[31m')
150
156
  print("error: add_to_conversation message is None or empty")
@@ -390,6 +396,13 @@ class chatgpt(BaseLLM):
390
396
  if response_role is None:
391
397
  response_role = "assistant"
392
398
 
399
+ function_parameter = parse_function_xml(full_response)
400
+ if function_parameter['function_name']:
401
+ need_function_call = True
402
+ function_call_name = function_parameter['function_name']
403
+ function_full_response = json.dumps(function_parameter['parameter'])
404
+ function_call_id = function_parameter['function_name'] + "_tool_call"
405
+
393
406
  # 处理函数调用
394
407
  if need_function_call:
395
408
  if self.print_log:
@@ -0,0 +1,551 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import tiktoken
5
+ import requests
6
+ import urllib.parse
7
+
8
+ from ..core.utils import get_image_message
9
+
10
def get_encode_text(text, model_name):
    """Tokenize ``text`` and return the encoder together with the token ids.

    Args:
        text: The string to tokenize.
        model_name: Accepted for interface compatibility only. It is not
            consulted: every caller gets the tokenizer that tiktoken selects
            for "gpt-3.5-turbo", whatever model name is passed in.

    Returns:
        A tuple ``(encoding, encode_text)`` where ``encoding`` is the tiktoken
        encoding object and ``encode_text`` is the list of token ids.
    """
    # One fixed tokenizer is used for all models; the caller's model_name is
    # intentionally ignored rather than silently reassigned.
    tokenizer_model = "gpt-3.5-turbo"
    encoding = tiktoken.encoding_for_model(tokenizer_model)
    # An empty ``disallowed_special`` lets text that looks like a special token
    # be encoded instead of rejected (see tiktoken's ``Encoding.encode``).
    encode_text = encoding.encode(text, disallowed_special=())
    return encoding, encode_text
16
+
17
def get_text_token_len(text, model_name):
    """Return the number of tokens ``get_encode_text`` produces for ``text``.

    ``model_name`` is passed straight through to ``get_encode_text``.
    """
    # Index 1 of the returned pair is the token id list; the encoder is unused.
    return len(get_encode_text(text, model_name)[1])
20
+
21
def cut_message(message: str, max_tokens: int, model_name: str):
    """Truncate ``message`` to at most ``max_tokens`` tokens.

    Args:
        message: Text to truncate; non-string values are converted with ``str``.
        max_tokens: Maximum number of tokens to keep.
        model_name: Forwarded to ``get_encode_text``.

    Returns:
        A tuple ``(message, token_count)`` with the possibly shortened text and
        the number of tokens it encodes to.
    """
    if not isinstance(message, str):
        message = str(message)
    encoding, encode_text = get_encode_text(message, model_name)
    if len(encode_text) > max_tokens:
        message = encoding.decode(encode_text[:max_tokens])
        # Count again after truncation: the decoded text may not tokenize back
        # to exactly the ids that were kept.
        encode_text = encoding.encode(message, disallowed_special=())
    # When nothing was cut, the first tokenization already holds the count.
    return message, len(encode_text)
30
+
31
+ import imghdr
32
def encode_image(image_path):
    """Read an image file and return it as a base64 ``data:`` URL.

    The format is determined from the file's leading bytes (not its name), so
    the function does not depend on the ``imghdr`` module.

    Args:
        image_path: Path of the image file to read.

    Returns:
        ``"data:image/png;base64,..."`` or ``"data:image/jpeg;base64,..."``.

    Raises:
        ValueError: If the content is neither PNG nor JPEG.
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        file_content = image_file.read()

    file_type = _detect_image_type(file_content)
    mime_types = {"png": "image/png", "jpeg": "image/jpeg"}
    if file_type not in mime_types:
        raise ValueError(f"不支持的图片格式: {file_type}")

    base64_encoded = base64.b64encode(file_content).decode('utf-8')
    return f"data:{mime_types[file_type]};base64,{base64_encoded}"


def _detect_image_type(header):
    """Return a short format name for image bytes, or None if unrecognised."""
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    # Every JPEG stream starts with the SOI marker followed by another marker.
    if header.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    # The formats below are recognised only so the error message can name them.
    if header.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "webp"
    if header.startswith(b"BM"):
        return "bmp"
    return None
44
+
45
def get_doc_from_url(url, timeout=30):
    """Download ``url`` into the current working directory.

    Args:
        url: Address of the document. The file name is taken from the last
            ``/``-separated segment of the URL after percent-decoding.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        The name of the file that was written (relative to the working
        directory).
    """
    decoded_name = urllib.parse.unquote(url.split("/")[-1])
    # Percent-decoding can reintroduce path separators ("..%2F..%2Fx"); keep
    # only the final component so the download cannot land outside the
    # working directory.
    filename = os.path.basename(decoded_name.replace("\\", "/"))
    # The context manager releases the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024):
                f.write(chunk)
    return filename
52
+
53
def get_encode_image(image_url):
    """Download an image and return it as a base64 ``data:`` URL.

    The image is fetched with ``get_doc_from_url`` into the working directory,
    encoded with ``encode_image`` and then deleted.

    Args:
        image_url: Address of the image to download.

    Returns:
        Whatever ``encode_image`` returns for the downloaded file.
    """
    filename = get_doc_from_url(image_url)
    image_path = os.path.join(os.getcwd(), filename)
    try:
        return encode_image(image_path)
    finally:
        # Remove the download even when encoding fails (e.g. unsupported
        # format), so rejected files do not pile up in the working directory.
        os.remove(image_path)
59
+
60
+ from io import BytesIO
61
def get_audio_message(file_bytes):
    """Transcribe raw audio bytes with ``config.whisperBot``.

    Args:
        file_bytes: Content of the audio file.

    Returns:
        The value returned by ``config.whisperBot.generate``. Any exception
        raised along the way is caught and reported as an error string instead
        of propagating.
    """
    try:
        # Wrap the bytes in an in-memory stream and hand it to the transcriber.
        stream = BytesIO(file_bytes)
        import config
        result = config.whisperBot.generate(stream)
    except Exception as err:
        return "处理音频文件时出错: " + str(err)
    return result
75
+
76
async def Document_extract(docurl, docpath=None, engine_type = None):
    """Turn a document, image or audio file into prompt content.

    When both ``docurl`` and ``docpath`` are given (and ``docpath`` is not
    "paper.pdf"), the file is first downloaded from ``docurl`` into the working
    directory. The file name's ending then selects the handling:

    * ``pdf``: text extracted with pdfminer
    * ``txt`` / ``.md`` / ``.py`` / ``yml``: file read as text
    * ``jpg`` / ``png`` / ``jpeg``: ``get_image_message(docurl, engine_type)``
    * ``wav`` / ``mp3``: transcription via ``get_audio_message``

    Endings are compared case-insensitively. The file at ``docpath`` is
    deleted afterwards, whether or not processing succeeded.

    Args:
        docurl: Address of the document, or None.
        docpath: Local path / file name of the document, or None.
        engine_type: Passed through to ``get_image_message``.

    Returns:
        The prompt built for the file, or None when there is no file name or
        the file type is not handled.
    """
    filename = docpath
    text = None
    prompt = None
    if docpath and docurl and "paper.pdf" != docpath:
        filename = get_doc_from_url(docurl)
        docpath = os.getcwd() + "/" + filename

    # Without a file name there is nothing to classify or clean up.
    if not filename:
        return None

    ending = filename.lower()
    try:
        if ending.endswith("pdf"):
            from pdfminer.high_level import extract_text
            text = extract_text(docpath)
        if ending.endswith(("txt", ".md", ".py", "yml")):
            with open(docpath, 'r') as f:
                text = f.read()
        if text:
            prompt = (
                "Here is the document, inside <document></document> XML tags:"
                "<document>"
                "{}"
                "</document>"
            ).format(text)
        if ending.endswith(("jpg", "png", "jpeg")):
            prompt = await get_image_message(docurl, engine_type)
        if ending.endswith(("wav", "mp3")):
            with open(docpath, "rb") as file:
                file_bytes = file.read()
            prompt = get_audio_message(file_bytes)
            prompt = (
                "Here is the text content after voice-to-text conversion, inside <voice-to-text></voice-to-text> XML tags:"
                "<voice-to-text>"
                "{}"
                "</voice-to-text>"
            ).format(prompt)
    finally:
        # Clean up the working copy even when extraction raised.
        if docpath and os.path.exists(docpath):
            os.remove(docpath)
    return prompt
111
+
112
def split_json_strings(input_string):
    """Split a string of concatenated JSON objects into one string per object.

    Braces are counted while scanning. Each time a ``}`` brings the count back
    to zero, the text accumulated since the previous split is tried with
    ``json.loads``; if it parses, it becomes one result and accumulation
    restarts after it. Text that fails to parse (for example because the
    brace was inside a string literal) keeps accumulating.

    Args:
        input_string: Text that may contain one or more JSON objects.

    Returns:
        A list of the JSON object strings found, in order. If none is found
        the list contains ``input_string`` itself as its only element.
    """
    json_strings = []
    start = 0
    brace_count = 0

    for position, char in enumerate(input_string):
        if char == '{':
            brace_count += 1
        elif char == '}':
            brace_count -= 1
            # Only a closing brace can complete an object, so this is the
            # only point at which a parse attempt is worthwhile.
            if brace_count == 0:
                candidate = input_string[start:position + 1]
                try:
                    json.loads(candidate)
                except json.JSONDecodeError:
                    # Not a complete object yet; keep accumulating.
                    continue
                json_strings.append(candidate)
                start = position + 1

    if not json_strings:
        json_strings.append(input_string)
    return json_strings
139
+
140
def check_json(json_data):
    """Return a string that is guaranteed to parse as JSON.

    The first JSON object found by ``split_json_strings`` is used. If it does
    not parse, a repair chosen from the decoder's error message is applied and
    the result is tried again:

    * "Invalid control character": raw newlines are escaped
    * "Unterminated string starting": a closing ``"}`` is appended
    * "Expecting value" at the very start with a ``prompt: `` prefix: the
      prefix is dropped and the rest wrapped as ``{"prompt": ...}``
    * anything else: the whole text is wrapped as ``{"prompt": ...}``

    Args:
        json_data: The (possibly malformed) JSON text.

    Returns:
        A JSON string: either the first valid object in ``json_data`` or a
        repaired / wrapped version of it.
    """
    def wrap_as_prompt(raw):
        # json.dumps escapes the text, so the wrapper is always valid JSON.
        return '{"prompt": ' + json.dumps(raw) + '}'

    # Repairs are bounded so that an input no rule can fix cannot spin forever.
    max_repairs = 8
    for _ in range(max_repairs):
        try:
            result = split_json_strings(json_data)
            if len(result) > 0:
                json_data = result[0]
            json.loads(json_data)
            return json_data
        except json.decoder.JSONDecodeError as e:
            print("JSON error:", e)
            print("JSON body", repr(json_data))
            reason = str(e)
            if "Invalid control character" in reason:
                escaped = json_data.replace("\n", "\\n")
                # If the offending character was not a newline, escaping
                # changed nothing; fall back to wrapping the text.
                json_data = escaped if escaped != json_data else wrap_as_prompt(json_data)
            elif "Unterminated string starting" in reason:
                json_data += '"}'
            elif "Expecting value: line 1 column 1" in reason and json_data.startswith("prompt: "):
                json_data = wrap_as_prompt(json_data.replace("prompt: ", ""))
            else:
                # Covers the delimiter errors as well as any message without a
                # dedicated rule; the result is always a string.
                json_data = wrap_as_prompt(json_data)
    return wrap_as_prompt(json_data)
166
+
167
def is_surrounded_by_chinese(text, index):
    """Tell whether the character at ``index`` has a Chinese neighbour.

    A neighbour counts as Chinese when it lies in the range U+4E00..U+9FFF.
    Only neighbours that actually exist are examined: the first character has
    no left neighbour and the last has no right neighbour.

    Args:
        text: The string to inspect.
        index: Position of the character whose neighbours are checked.

    Returns:
        True if the character directly before or directly after ``index`` is
        Chinese; False otherwise, including when ``index`` is out of range.
    """
    def is_chinese(char):
        return '\u4e00' <= char <= '\u9fff'

    if not 0 <= index < len(text):
        return False
    # Guard the left lookup so index 0 does not wrap around to text[-1].
    if index > 0 and is_chinese(text[index - 1]):
        return True
    return index + 1 < len(text) and is_chinese(text[index + 1])
175
+
176
def replace_char(string, index, new_char):
    """Return ``string`` with the character at ``index`` swapped for ``new_char``.

    The result is built from the slice before ``index``, ``new_char``, and the
    slice after ``index``; ``string`` itself is not modified.
    """
    before = string[:index]
    after = string[index + 1:]
    return before + new_char + after
178
+
179
def claude_replace(text):
    """Convert ASCII punctuation next to Chinese characters to full-width form.

    Each of ``, : ! ? ;`` is replaced by its full-width counterpart when
    ``is_surrounded_by_chinese`` reports a Chinese neighbour for it.
    Punctuation elsewhere is left untouched.

    Args:
        text: The string to normalise.

    Returns:
        A new string with the qualifying punctuation replaced.
    """
    # Written as escapes so the full-width forms cannot be folded back to
    # ASCII by tooling that normalises the source text.
    punctuation_mapping = {
        ",": "\uff0c",
        ":": "\uff1a",
        "!": "\uff01",
        "?": "\uff1f",
        ";": "\uff1b",
    }
    characters = list(text)
    for i, char in enumerate(text):
        # The neighbour test uses the original text; the replacements are not
        # in the Chinese range, so this matches checking the updated text.
        if char in punctuation_mapping and is_surrounded_by_chinese(text, i):
            characters[i] = punctuation_mapping[char]
    return "".join(characters)
186
+
187
def safe_get(data, *keys, default=None):
    """Walk into nested containers along ``keys`` without raising.

    Dicts and lists are indexed with ``[]``; any other object is asked via its
    ``.get(key)`` method.

    Args:
        data: The outermost container.
        *keys: Keys / indexes to follow, outermost first.
        default: Returned when a step raises KeyError, IndexError,
            AttributeError or TypeError.

    Returns:
        The value reached after the last key, or ``default`` if a step failed.
    """
    current = data
    for step in keys:
        try:
            if isinstance(current, (dict, list)):
                current = current[step]
            else:
                current = current.get(step)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return current
194
+
195
+ import asyncio
196
def async_generator_to_sync(async_gen):
    """Expose an asynchronous generator as a regular (synchronous) generator.

    A fresh event loop is created and installed as the current loop. The
    asynchronous generator is run to completion on it, and only then are the
    collected values yielded one by one.

    Args:
        async_gen: The asynchronous generator to consume.

    Yields:
        Every value produced by ``async_gen``, in order.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    try:
        async def drain():
            return [item async for item in async_gen]

        for item in event_loop.run_until_complete(drain()):
            yield item

    except Exception as e:
        print(f"Error during async execution: {e}")
        raise
    finally:
        try:
            # Let unfinished tasks on this loop run out before shutting down.
            pending = [task for task in asyncio.all_tasks(event_loop) if not task.done()]
            if pending:
                event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            event_loop.run_until_complete(event_loop.shutdown_asyncgens())
            event_loop.close()
        except Exception as e:
            print(f"Error during cleanup: {e}")
233
+
234
def parse_tools_from_cursor_prompt(text):
    """Extract the JSON tool list embedded in a prompt.

    The JSON is expected between ``<tools>`` and ``</tools>``, each followed /
    preceded by a newline. It is parsed with ``strict=False``.

    Args:
        text: The prompt text to search.

    Returns:
        The parsed JSON value, or an empty list when the tags are missing or
        their content is not valid JSON.
    """
    import json
    import re

    found = re.search(r"<tools>\n(.*?)\n</tools>", text, re.DOTALL)
    if found is None:
        return []

    payload = found.group(1).strip()
    try:
        return json.loads(payload, strict=False)
    except json.JSONDecodeError as e:
        print(f"解析 JSON 时出错: {e}")
    return []
248
+
249
+ from dataclasses import dataclass
250
+ from typing import List, Callable, Optional, TypeVar, Generic, Union, Dict, Any
251
+
252
# Result type: one chunk of text produced by XmlMatcher.
@dataclass
class XmlMatcherResult:
    # True when this chunk was collected inside the tag being matched,
    # False when it is text outside of it.
    matched: bool
    # The text of the chunk.
    data: str = ""
257
+
258
# Generic type variable for the return type of ``transform``.
R = TypeVar('R')

class XmlMatcher(Generic[R]):
    """Character-level state machine that separates text inside ``<tag_name>``
    ... ``</tag_name>`` from the text around it.

    Input is fed in pieces through ``update`` / ``final``; the output is a list
    of ``XmlMatcherResult`` chunks (or whatever ``transform`` maps them to),
    each flagged as inside (``matched=True``) or outside the tag. Nested
    occurrences of the same tag are tracked with a depth counter. Recognised
    opening and closing tags themselves are dropped from the output.
    """

    def __init__(self,
                 tag_name: str,
                 transform: Optional[Callable[[XmlMatcherResult], R]] = None,
                 position: int = 0):
        """
        Args:
            tag_name: Name of the tag to look for. Its first character is
                read while parsing a tag, so it must not be empty.
            transform: Optional function applied to every chunk before it is
                returned.
            position: Number of leading input characters in which ``<`` is not
                treated as the start of a tag (unless already inside a match).
        """
        self.tag_name: str = tag_name
        self.transform: Optional[Callable[[XmlMatcherResult], R]] = transform
        self.position: int = position

        self.index: int = 0  # how many characters of tag_name have been matched in the current tag
        self.chunks: List[XmlMatcherResult] = []  # finished chunks not yet handed out
        self.cached: List[str] = []  # characters seen since the last collect
        self.matched: bool = False  # True while depth > 0
        self.state: str = "TEXT" # "TEXT", "TAG_OPEN", "TAG_CLOSE"
        self.depth: int = 0  # nesting level of the matched tag
        self.pointer: int = 0  # total number of characters consumed so far

    def _collect(self) -> None:
        """Move the cached characters into ``chunks``."""
        if not self.cached:
            return

        data = "".join(self.cached)
        # Decide whether to merge into the previous chunk: this happens only
        # when the cached text has the same matched state as that chunk.
        last = self.chunks[-1] if self.chunks else None
        current_matched_state = self.matched if self.state == "TEXT" else (self.depth > 0) # while parsing a tag, the matched state follows the depth

        if last and last.matched == current_matched_state:
            last.data += data
        else:
            # Add a new chunk only when there is data.
            if data:
                self.chunks.append(XmlMatcherResult(data=data, matched=current_matched_state))

        self.cached = []

    def _pop(self) -> List[Union[XmlMatcherResult, R]]:
        """Return the collected chunks and clear the list."""
        chunks_to_return = self.chunks
        self.chunks = []
        if not self.transform:
            # Without a transform the raw results are returned as they are;
            # the checker is silenced because R is not bound to
            # XmlMatcherResult.
            return [chunk for chunk in chunks_to_return] # type: ignore
        # Apply the transform function.
        return [self.transform(chunk) for chunk in chunks_to_return]

    def _update(self, chunk: str) -> None:
        """Core logic: run the state machine over one piece of input."""
        for char in chunk:
            current_char_processed = False # whether the state machine already consumed this character

            if self.state == "TEXT":
                if char == "<" and (self.pointer >= self.position or self.matched):
                    # Possible tag start: flush the text seen so far first.
                    self._collect()
                    self.state = "TAG_OPEN"
                    self.cached.append(char)
                    self.index = 0 # reset to start matching the tag name (or skipping spaces)
                    current_char_processed = True
                # else: stay in TEXT; the character is cached at the end of the loop body

            elif self.state == "TAG_OPEN":
                self.cached.append(char)
                current_char_processed = True

                tag_name_len = len(self.tag_name)

                # State: directly after "<"
                if self.index == 0:
                    if char == "/":
                        self.state = "TAG_CLOSE"
                        # index stays 0, ready to match the closing tag name or skip spaces
                    elif char.isspace():
                        # Skip whitespace after "<"
                        pass # index stays 0
                    elif char == self.tag_name[0]:
                        # Start matching the tag name
                        self.index = 1
                    else:
                        # Not a tag we care about (not "/", not space, not tag_name[0])
                        self.state = "TEXT"
                        current_char_processed = True
                # State: in the middle of the tag name
                elif self.index < tag_name_len:
                    if self.tag_name[self.index] == char:
                        self.index += 1
                    elif char.isspace():
                        # Whitespace before the whole name was matched.
                        # Whitespace inside a tag name (e.g. "<t hink>") is
                        # not supported, so this is treated as a mismatch and
                        # the characters cached so far stay as plain text.
                        self.state = "TEXT"
                        current_char_processed = True
                    else:
                        # Character does not match the tag name
                        self.state = "TEXT"
                        current_char_processed = True
                # State: tag name fully matched (self.index == tag_name_len)
                else: # self.index >= tag_name_len (in practice ==)
                    if char == ">":
                        # End of the opening tag
                        self.state = "TEXT"
                        self.depth += 1
                        self.matched = True
                        self.cached = [] # drop the cached "<tag ...>"
                    elif char.isspace():
                        # Ignore whitespace after the tag name
                        pass # stay in TAG_OPEN, waiting for ">" or attributes
                    else:
                        # Any other character is ignored as well (treated as
                        # part of an attribute); keep waiting for ">"
                        pass # stay in TAG_OPEN

            elif self.state == "TAG_CLOSE":
                self.cached.append(char)
                current_char_processed = True # True by default

                tag_name_len = len(self.tag_name)

                # State: directly after "</"
                if self.index == 0:
                    if char.isspace():
                        # Skip whitespace after "</"
                        pass # index stays 0
                    elif char == self.tag_name[0]:
                        # Start matching the tag name
                        self.index = 1
                    else:
                        # Not our closing tag (not space, not tag_name[0])
                        self.state = "TEXT"
                        current_char_processed = True
                # State: in the middle of the tag name
                elif self.index < tag_name_len:
                    if self.tag_name[self.index] == char:
                        self.index += 1
                    else:
                        # Character does not match the tag name
                        self.state = "TEXT"
                        current_char_processed = True
                # State: tag name fully matched (self.index == tag_name_len)
                else: # self.index == tag_name_len
                    if char == ">":
                        # Found ">"
                        was_inside_tag = self.depth > 0
                        self.state = "TEXT" # back to TEXT in either case

                        if was_inside_tag:
                            # Really inside the tag: regular closing tag
                            self.depth -= 1
                            self.matched = self.depth > 0
                            self.cached = [] # drop the cached "</tag>"
                            # current_char_processed stays True
                        else:
                            # Closing tag without a matching opening tag.
                            # The cache (which already includes this ">") is
                            # kept, so "</tag>" is emitted as ordinary text on
                            # the next collect. The flag only prevents the
                            # end of the loop body from caching ">" again.
                            current_char_processed = True

                    elif char.isspace():
                        # Allow "</tag >": keep waiting for ">"
                        pass # stay in TAG_CLOSE, current_char_processed stays True
                    else:
                        # Something other than whitespace or ">" after the closing tag name
                        self.state = "TEXT"
                        current_char_processed = True

            # The character was not consumed above, so it is ordinary text.
            if not current_char_processed:
                # Make sure the state is TEXT
                if self.state != "TEXT":
                    # A tag match was attempted and failed; what was cached is text
                    self.state = "TEXT"

                self.cached.append(char)

            self.pointer += 1

        # After the whole piece: if we are in TEXT, collect what is cached.
        # While in the middle of a tag, the cache is held back instead.
        if self.state == "TEXT":
            self._collect()


    def final(self, chunk: Optional[str] = None) -> List[Union[XmlMatcherResult, R]]:
        """Process the last piece of input (if any) and return all remaining results."""
        if chunk:
            self._update(chunk)
        # Collect whatever is still cached, even if the state is not TEXT,
        # e.g. when the input ended inside an unfinished tag.
        self._collect()
        return self._pop()

    def update(self, chunk: str) -> List[Union[XmlMatcherResult, R]]:
        """Process one piece of input and return the results available so far."""
        self._update(chunk)
        return self._pop()
459
+
460
def parse_function_xml(xml_content: str) -> Dict[str, Any]:
    """Parse an XML-style function call into a dictionary.

    The first ``<...>`` tag found in ``xml_content`` names the function; any
    attributes after the name are ignored. The text inside that tag is then
    read line by line: a line that starts with ``<name>`` opens a parameter,
    whose value is the rest of that line and/or the following lines up to the
    next line that starts with a tag.

    Multi-line values keep their relative indentation (common leading
    whitespace is removed), so values such as source code survive intact.

    Args:
        xml_content: Text that may contain a function call in XML form.

    Returns:
        ``{'function_name': <name>, 'parameter': {<param>: <value>, ...}}``.
        When no tag name or no tag content is found, the name is ``''`` and
        the parameter dictionary is empty.
    """
    import textwrap

    # Locate the first tag; its first whitespace-separated token is the
    # function name.
    tag_start = xml_content.find("<")
    tag_end = xml_content.find(">", tag_start) if tag_start != -1 else -1
    if tag_end == -1:
        return {'function_name': '', 'parameter': {}}

    tag_tokens = xml_content[tag_start+1:tag_end].split()
    # "<>" or "< >" has no name; XmlMatcher reads tag_name[0], so an empty
    # name must never reach it.
    if not tag_tokens:
        return {'function_name': '', 'parameter': {}}
    function_name = tag_tokens[0]

    # Use XmlMatcher to pull out the text inside the function tag.
    function_content = ""
    content_matcher = XmlMatcher[XmlMatcherResult](function_name)
    for result in content_matcher.final(xml_content):
        if result.matched:
            function_content = result.data
            break

    if not function_content:
        return {'function_name': '', 'parameter': {}}

    def finish_value(value_lines):
        # Remove indentation shared by all lines, keep the relative part.
        return textwrap.dedent('\n'.join(value_lines)).strip()

    parameters = {}
    current_param = None
    current_value = []

    for raw_line in function_content.strip().split('\n'):
        line = raw_line.strip()
        if line.startswith('<') and '>' in line:
            # A tag line ends the value collected so far.
            if current_param and current_value:
                parameters[current_param] = finish_value(current_value)
                current_value = []

            param_end = line.find('>', 1)
            param = line[1:param_end]
            if param.startswith('/'):
                # Closing tag: it only ends the parameter it belongs to.
                if param[1:] == current_param:
                    current_param = None
            else:
                current_param = param
                # A value may start on the same line as the opening tag.
                rest = line[param_end+1:]
                if rest and not rest.startswith('</'):
                    current_value.append(rest)
        elif current_param:
            # Continuation line: keep leading whitespace, drop trailing.
            current_value.append(raw_line.rstrip())

    # Store the parameter that was still open at the end.
    if current_param and current_value:
        parameters[current_param] = finish_value(current_value)

    def drop_end_tag(param, value):
        # A value that shared a line with its closing tag still carries it.
        end_tag = f'</{param}>'
        if value.endswith(end_tag):
            return value[:-len(end_tag)].strip()
        return value

    return {
        'function_name': function_name,
        'parameter': {param: drop_end_tag(param, value) for param, value in parameters.items()}
    }
549
+
550
if __name__ == "__main__":
    # Running this module directly only invokes the shell's `clear` command.
    os.system("clear")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.0.38
3
+ Version: 1.0.39
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Description-Content-Type: text/markdown
6
6
  License-File: LICENSE
@@ -1,235 +0,0 @@
1
- import os
2
- import json
3
- import base64
4
- import tiktoken
5
- import requests
6
- import urllib.parse
7
-
8
- from ..core.utils import get_image_message
9
-
10
- def get_encode_text(text, model_name):
11
- tiktoken.get_encoding("cl100k_base")
12
- model_name = "gpt-3.5-turbo"
13
- encoding = tiktoken.encoding_for_model(model_name)
14
- encode_text = encoding.encode(text, disallowed_special=())
15
- return encoding, encode_text
16
-
17
- def get_text_token_len(text, model_name):
18
- encoding, encode_text = get_encode_text(text, model_name)
19
- return len(encode_text)
20
-
21
- def cut_message(message: str, max_tokens: int, model_name: str):
22
- if type(message) != str:
23
- message = str(message)
24
- encoding, encode_text = get_encode_text(message, model_name)
25
- if len(encode_text) > max_tokens:
26
- encode_text = encode_text[:max_tokens]
27
- message = encoding.decode(encode_text)
28
- encode_text = encoding.encode(message, disallowed_special=())
29
- return message, len(encode_text)
30
-
31
- import imghdr
32
- def encode_image(image_path):
33
- with open(image_path, "rb") as image_file:
34
- file_content = image_file.read()
35
- file_type = imghdr.what(None, file_content)
36
- base64_encoded = base64.b64encode(file_content).decode('utf-8')
37
-
38
- if file_type == 'png':
39
- return f"data:image/png;base64,{base64_encoded}"
40
- elif file_type in ['jpeg', 'jpg']:
41
- return f"data:image/jpeg;base64,{base64_encoded}"
42
- else:
43
- raise ValueError(f"不支持的图片格式: {file_type}")
44
-
45
- def get_doc_from_url(url):
46
- filename = urllib.parse.unquote(url.split("/")[-1])
47
- response = requests.get(url, stream=True)
48
- with open(filename, 'wb') as f:
49
- for chunk in response.iter_content(chunk_size=1024):
50
- f.write(chunk)
51
- return filename
52
-
53
- def get_encode_image(image_url):
54
- filename = get_doc_from_url(image_url)
55
- image_path = os.getcwd() + "/" + filename
56
- base64_image = encode_image(image_path)
57
- os.remove(image_path)
58
- return base64_image
59
-
60
- from io import BytesIO
61
- def get_audio_message(file_bytes):
62
- try:
63
- # 创建一个字节流对象
64
- audio_stream = BytesIO(file_bytes)
65
-
66
- # 直接使用字节流对象进行转录
67
- import config
68
- transcript = config.whisperBot.generate(audio_stream)
69
- # print("transcript", transcript)
70
-
71
- return transcript
72
-
73
- except Exception as e:
74
- return f"处理音频文件时出错: {str(e)}"
75
-
76
- async def Document_extract(docurl, docpath=None, engine_type = None):
77
- filename = docpath
78
- text = None
79
- prompt = None
80
- if docpath and docurl and "paper.pdf" != docpath:
81
- filename = get_doc_from_url(docurl)
82
- docpath = os.getcwd() + "/" + filename
83
- if filename and filename[-3:] == "pdf":
84
- from pdfminer.high_level import extract_text
85
- text = extract_text(docpath)
86
- if filename and (filename[-3:] == "txt" or filename[-3:] == ".md" or filename[-3:] == ".py" or filename[-3:] == "yml"):
87
- with open(docpath, 'r') as f:
88
- text = f.read()
89
- if text:
90
- prompt = (
91
- "Here is the document, inside <document></document> XML tags:"
92
- "<document>"
93
- "{}"
94
- "</document>"
95
- ).format(text)
96
- if filename and filename[-3:] == "jpg" or filename[-3:] == "png" or filename[-4:] == "jpeg":
97
- prompt = await get_image_message(docurl, engine_type)
98
- if filename and filename[-3:] == "wav" or filename[-3:] == "mp3":
99
- with open(docpath, "rb") as file:
100
- file_bytes = file.read()
101
- prompt = get_audio_message(file_bytes)
102
- prompt = (
103
- "Here is the text content after voice-to-text conversion, inside <voice-to-text></voice-to-text> XML tags:"
104
- "<voice-to-text>"
105
- "{}"
106
- "</voice-to-text>"
107
- ).format(prompt)
108
- if os.path.exists(docpath):
109
- os.remove(docpath)
110
- return prompt
111
-
112
- def split_json_strings(input_string):
113
- # 初始化结果列表和当前 JSON 字符串
114
- json_strings = []
115
- current_json = ""
116
- brace_count = 0
117
-
118
- # 遍历输入字符串的每个字符
119
- for char in input_string:
120
- current_json += char
121
- if char == '{':
122
- brace_count += 1
123
- elif char == '}':
124
- brace_count -= 1
125
-
126
- # 如果花括号配对完成,我们找到了一个完整的 JSON 字符串
127
- if brace_count == 0:
128
- # 尝试解析当前 JSON 字符串
129
- try:
130
- json.loads(current_json)
131
- json_strings.append(current_json)
132
- current_json = ""
133
- except json.JSONDecodeError:
134
- # 如果解析失败,继续添加字符
135
- pass
136
- if json_strings == []:
137
- json_strings.append(input_string)
138
- return json_strings
139
-
140
- def check_json(json_data):
141
- while True:
142
- try:
143
- result = split_json_strings(json_data)
144
- if len(result) > 0:
145
- json_data = result[0]
146
- json.loads(json_data)
147
- break
148
- except json.decoder.JSONDecodeError as e:
149
- print("JSON error:", e)
150
- print("JSON body", repr(json_data))
151
- if "Invalid control character" in str(e):
152
- json_data = json_data.replace("\n", "\\n")
153
- elif "Unterminated string starting" in str(e):
154
- json_data += '"}'
155
- elif "Expecting ',' delimiter" in str(e):
156
- json_data = {"prompt": json_data}
157
- elif "Expecting ':' delimiter" in str(e):
158
- json_data = '{"prompt": ' + json.dumps(json_data) + '}'
159
- elif "Expecting value: line 1 column 1" in str(e):
160
- if json_data.startswith("prompt: "):
161
- json_data = json_data.replace("prompt: ", "")
162
- json_data = '{"prompt": ' + json.dumps(json_data) + '}'
163
- else:
164
- json_data = '{"prompt": ' + json.dumps(json_data) + '}'
165
- return json_data
166
-
167
- def is_surrounded_by_chinese(text, index):
168
- left_char = text[index - 1]
169
- if 0 < index < len(text) - 1:
170
- right_char = text[index + 1]
171
- return '\u4e00' <= left_char <= '\u9fff' or '\u4e00' <= right_char <= '\u9fff'
172
- if index == len(text) - 1:
173
- return '\u4e00' <= left_char <= '\u9fff'
174
- return False
175
-
176
- def replace_char(string, index, new_char):
177
- return string[:index] + new_char + string[index+1:]
178
-
179
- def claude_replace(text):
180
- Punctuation_mapping = {",": ",", ":": ":", "!": "!", "?": "?", ";": ";"}
181
- key_list = list(Punctuation_mapping.keys())
182
- for i in range(len(text)):
183
- if is_surrounded_by_chinese(text, i) and (text[i] in key_list):
184
- text = replace_char(text, i, Punctuation_mapping[text[i]])
185
- return text
186
-
187
- def safe_get(data, *keys, default=None):
188
- for key in keys:
189
- try:
190
- data = data[key] if isinstance(data, (dict, list)) else data.get(key)
191
- except (KeyError, IndexError, AttributeError, TypeError):
192
- return default
193
- return data
194
-
195
- import asyncio
196
- def async_generator_to_sync(async_gen):
197
- """
198
- 将异步生成器转换为同步生成器的工具函数
199
-
200
- Args:
201
- async_gen: 异步生成器函数
202
-
203
- Yields:
204
- 异步生成器产生的每个值
205
- """
206
- loop = asyncio.new_event_loop()
207
- asyncio.set_event_loop(loop)
208
-
209
- try:
210
- async def collect_chunks():
211
- chunks = []
212
- async for chunk in async_gen:
213
- chunks.append(chunk)
214
- return chunks
215
-
216
- chunks = loop.run_until_complete(collect_chunks())
217
- for chunk in chunks:
218
- yield chunk
219
-
220
- except Exception as e:
221
- print(f"Error during async execution: {e}")
222
- raise
223
- finally:
224
- try:
225
- # 清理所有待处理的任务
226
- tasks = [t for t in asyncio.all_tasks(loop) if not t.done()]
227
- if tasks:
228
- loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
229
- loop.run_until_complete(loop.shutdown_asyncgens())
230
- loop.close()
231
- except Exception as e:
232
- print(f"Error during cleanup: {e}")
233
-
234
- if __name__ == "__main__":
235
- os.system("clear")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes