beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/aient/main.py +50 -0
- beswarm/aient/setup.py +15 -0
- beswarm/aient/src/aient/__init__.py +1 -0
- beswarm/aient/src/aient/core/__init__.py +1 -0
- beswarm/aient/src/aient/core/log_config.py +6 -0
- beswarm/aient/src/aient/core/models.py +232 -0
- beswarm/aient/src/aient/core/request.py +1665 -0
- beswarm/aient/src/aient/core/response.py +617 -0
- beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
- beswarm/aient/src/aient/core/test/test_image.py +15 -0
- beswarm/aient/src/aient/core/test/test_payload.py +92 -0
- beswarm/aient/src/aient/core/utils.py +715 -0
- beswarm/aient/src/aient/models/__init__.py +9 -0
- beswarm/aient/src/aient/models/audio.py +63 -0
- beswarm/aient/src/aient/models/base.py +251 -0
- beswarm/aient/src/aient/models/chatgpt.py +938 -0
- beswarm/aient/src/aient/models/claude.py +640 -0
- beswarm/aient/src/aient/models/duckduckgo.py +241 -0
- beswarm/aient/src/aient/models/gemini.py +357 -0
- beswarm/aient/src/aient/models/groq.py +268 -0
- beswarm/aient/src/aient/models/vertex.py +420 -0
- beswarm/aient/src/aient/plugins/__init__.py +33 -0
- beswarm/aient/src/aient/plugins/arXiv.py +48 -0
- beswarm/aient/src/aient/plugins/config.py +172 -0
- beswarm/aient/src/aient/plugins/excute_command.py +35 -0
- beswarm/aient/src/aient/plugins/get_time.py +19 -0
- beswarm/aient/src/aient/plugins/image.py +72 -0
- beswarm/aient/src/aient/plugins/list_directory.py +50 -0
- beswarm/aient/src/aient/plugins/read_file.py +79 -0
- beswarm/aient/src/aient/plugins/registry.py +116 -0
- beswarm/aient/src/aient/plugins/run_python.py +156 -0
- beswarm/aient/src/aient/plugins/websearch.py +394 -0
- beswarm/aient/src/aient/plugins/write_file.py +51 -0
- beswarm/aient/src/aient/prompt/__init__.py +1 -0
- beswarm/aient/src/aient/prompt/agent.py +280 -0
- beswarm/aient/src/aient/utils/__init__.py +0 -0
- beswarm/aient/src/aient/utils/prompt.py +143 -0
- beswarm/aient/src/aient/utils/scripts.py +721 -0
- beswarm/aient/test/chatgpt.py +161 -0
- beswarm/aient/test/claude.py +32 -0
- beswarm/aient/test/test.py +2 -0
- beswarm/aient/test/test_API.py +6 -0
- beswarm/aient/test/test_Deepbricks.py +20 -0
- beswarm/aient/test/test_Web_crawler.py +262 -0
- beswarm/aient/test/test_aiwaves.py +25 -0
- beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
- beswarm/aient/test/test_ask_gemini.py +8 -0
- beswarm/aient/test/test_class.py +17 -0
- beswarm/aient/test/test_claude.py +23 -0
- beswarm/aient/test/test_claude_zh_char.py +26 -0
- beswarm/aient/test/test_ddg_search.py +50 -0
- beswarm/aient/test/test_download_pdf.py +56 -0
- beswarm/aient/test/test_gemini.py +97 -0
- beswarm/aient/test/test_get_token_dict.py +21 -0
- beswarm/aient/test/test_google_search.py +35 -0
- beswarm/aient/test/test_jieba.py +32 -0
- beswarm/aient/test/test_json.py +65 -0
- beswarm/aient/test/test_langchain_search_old.py +235 -0
- beswarm/aient/test/test_logging.py +32 -0
- beswarm/aient/test/test_ollama.py +55 -0
- beswarm/aient/test/test_plugin.py +16 -0
- beswarm/aient/test/test_py_run.py +26 -0
- beswarm/aient/test/test_requests.py +162 -0
- beswarm/aient/test/test_search.py +18 -0
- beswarm/aient/test/test_tikitoken.py +19 -0
- beswarm/aient/test/test_token.py +94 -0
- beswarm/aient/test/test_url.py +33 -0
- beswarm/aient/test/test_whisper.py +14 -0
- beswarm/aient/test/test_wildcard.py +20 -0
- beswarm/aient/test/test_yjh.py +21 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
- beswarm-0.1.13.dist-info/RECORD +131 -0
- beswarm-0.1.12.dist-info/RECORD +0 -61
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,721 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import base64
|
4
|
+
import tiktoken
|
5
|
+
import requests
|
6
|
+
import urllib.parse
|
7
|
+
|
8
|
+
from ..core.utils import get_image_message
|
9
|
+
|
10
|
+
def get_encode_text(text, model_name):
    """Return the tiktoken encoding and the encoded tokens for *text*.

    NOTE: the ``model_name`` argument is deliberately ignored and
    "gpt-3.5-turbo" (cl100k_base) is always used, so token counts are
    consistent regardless of the caller's model string.

    Returns:
        (encoding, encode_text): the tiktoken Encoding object and the
        list of token ids for *text*.
    """
    # The original also called tiktoken.get_encoding("cl100k_base") and
    # discarded the result — that dead call is removed here.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    encode_text = encoding.encode(text, disallowed_special=())
    return encoding, encode_text
|
16
|
+
|
17
|
+
def get_text_token_len(text, model_name):
    """Number of tokens *text* encodes to (see get_encode_text)."""
    _, tokens = get_encode_text(text, model_name)
    return len(tokens)
|
20
|
+
|
21
|
+
def cut_message(message: str, max_tokens: int, model_name: str):
    """Truncate *message* to at most *max_tokens* tokens.

    Args:
        message: text to truncate (non-str values are stringified first).
        max_tokens: maximum number of tokens to keep.
        model_name: forwarded to get_encode_text (currently ignored there).

    Returns:
        (message, token_count): the possibly-truncated text and its
        final token length.
    """
    if not isinstance(message, str):
        message = str(message)
    encoding, encode_text = get_encode_text(message, model_name)
    if len(encode_text) > max_tokens:
        encode_text = encode_text[:max_tokens]
        # Decode then re-encode: decoding may alter token boundaries, so
        # the final count is recomputed from the truncated text.
        message = encoding.decode(encode_text)
        encode_text = encoding.encode(message, disallowed_special=())
    return message, len(encode_text)
|
30
|
+
|
31
|
+
def encode_image(image_path):
    """Read an image file and return it as a base64 data URI.

    Only PNG and JPEG are supported; any other content raises ValueError.
    Detection uses the file's magic bytes instead of the ``imghdr``
    module, which was deprecated (PEP 594) and removed in Python 3.13.

    Args:
        image_path: path to the image file on disk.

    Returns:
        A "data:image/...;base64,..." URI string.

    Raises:
        ValueError: if the content is neither PNG nor JPEG.
    """
    with open(image_path, "rb") as image_file:
        file_content = image_file.read()

    base64_encoded = base64.b64encode(file_content).decode('utf-8')

    # PNG files start with the 8-byte PNG signature.
    if file_content.startswith(b"\x89PNG\r\n\x1a\n"):
        return f"data:image/png;base64,{base64_encoded}"
    # JPEG files start with the SOI marker FF D8 FF.
    if file_content.startswith(b"\xff\xd8\xff"):
        return f"data:image/jpeg;base64,{base64_encoded}"
    raise ValueError(f"不支持的图片格式: {file_content[:4]!r}")
|
44
|
+
|
45
|
+
def get_doc_from_url(url):
    """Download *url* into the current directory and return the filename.

    The filename is the percent-decoded last path segment of the URL.
    The body is streamed to disk in 1 KiB chunks.
    """
    filename = urllib.parse.unquote(url.split("/")[-1])
    # Use the response as a context manager so the underlying connection
    # is released even if writing fails (the original leaked it).
    with requests.get(url, stream=True) as response:
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024):
                f.write(chunk)
    return filename
|
52
|
+
|
53
|
+
def get_encode_image(image_url):
    """Fetch an image URL and return it as a base64 data URI.

    The image is downloaded into the working directory, encoded, then
    the temporary local copy is deleted.
    """
    local_name = get_doc_from_url(image_url)
    local_path = os.getcwd() + "/" + local_name
    data_uri = encode_image(local_path)
    os.remove(local_path)
    return data_uri
|
59
|
+
|
60
|
+
from io import BytesIO
def get_audio_message(file_bytes):
    """Transcribe raw audio bytes via the project's whisper bot.

    Returns the transcript on success; on any failure (including a
    missing ``config`` module) returns a human-readable error string
    instead of raising.
    """
    try:
        # Wrap the raw bytes in a stream so the transcriber can read it.
        audio_stream = BytesIO(file_bytes)
        import config  # imported lazily; only this helper needs it
        return config.whisperBot.generate(audio_stream)
    except Exception as e:
        return f"处理音频文件时出错: {str(e)}"
|
75
|
+
|
76
|
+
async def Document_extract(docurl, docpath=None, engine_type = None):
    """Extract a prompt payload from a document (pdf / text / image / audio).

    Args:
        docurl: remote URL of the document, or None.
        docpath: local path of the document, or None.
        engine_type: forwarded to get_image_message for image documents.

    Returns:
        A prompt string wrapping the extracted content (or an
        engine-specific image message), or None for unsupported types.
        Any local copy of the file is deleted before returning.
    """
    filename = docpath
    text = None
    prompt = None
    # Download only when both a URL and a non-default local name are given.
    if docpath and docurl and "paper.pdf" != docpath:
        filename = get_doc_from_url(docurl)
        docpath = os.getcwd() + "/" + filename
    if filename and filename[-3:] == "pdf":
        from pdfminer.high_level import extract_text
        text = extract_text(docpath)
    if filename and filename[-3:] in ("txt", ".md", ".py", "yml"):
        with open(docpath, 'r') as f:
            text = f.read()
    if text:
        prompt = (
            "Here is the document, inside <document></document> XML tags:"
            "<document>"
            "{}"
            "</document>"
        ).format(text)
    # Parenthesized: the original `filename and a or b or c` precedence
    # evaluated `filename[-3:]` even when filename was None and crashed.
    if filename and (filename[-3:] in ("jpg", "png") or filename[-4:] == "jpeg"):
        prompt = await get_image_message(docurl, engine_type)
    if filename and filename[-3:] in ("wav", "mp3"):
        with open(docpath, "rb") as file:
            file_bytes = file.read()
        prompt = get_audio_message(file_bytes)
        prompt = (
            "Here is the text content after voice-to-text conversion, inside <voice-to-text></voice-to-text> XML tags:"
            "<voice-to-text>"
            "{}"
            "</voice-to-text>"
        ).format(prompt)
    # docpath may be None when nothing was supplied or downloaded; the
    # original passed None to os.path.exists and raised TypeError.
    if docpath and os.path.exists(docpath):
        os.remove(docpath)
    return prompt
|
111
|
+
|
112
|
+
def split_json_strings(input_string):
    """Split a string of concatenated JSON objects into individual ones.

    Scans for balanced ``{``/``}`` spans and keeps each span that parses
    as valid JSON. If nothing could be extracted, the whole input is
    returned as a single-element list.
    """
    results = []
    buffer = ""
    depth = 0

    for ch in input_string:
        buffer += ch
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
        # A balanced point: keep the buffer only if it is valid JSON.
        if depth == 0:
            try:
                json.loads(buffer)
            except json.JSONDecodeError:
                continue  # not valid yet — keep accumulating
            results.append(buffer)
            buffer = ""

    if not results:
        results.append(input_string)
    return results
|
139
|
+
|
140
|
+
def check_json(json_data):
    """Repeatedly repair *json_data* until it parses as JSON.

    Applies heuristic fixes keyed off the JSONDecodeError message
    (escape raw newlines, close an unterminated string, or wrap the raw
    text in a {"prompt": ...} object), then retries. Returns the
    repaired value — usually a str.
    """
    while True:
        try:
            # Prefer the first complete JSON object found in the input.
            result = split_json_strings(json_data)
            if len(result) > 0:
                json_data = result[0]
            json.loads(json_data)
            break
        except json.decoder.JSONDecodeError as e:
            print("JSON error:", e)
            print("JSON body", repr(json_data))
            if "Invalid control character" in str(e):
                # Raw newlines inside strings are invalid JSON — escape them.
                json_data = json_data.replace("\n", "\\n")
            elif "Unterminated string starting" in str(e):
                # Best-effort close of a truncated string/object.
                json_data += '"}'
            elif "Expecting ',' delimiter" in str(e):
                # NOTE(review): this assigns a dict (not a JSON string); the
                # next json.loads(dict) raises TypeError, which is NOT caught
                # here, so this branch exits the loop by raising — confirm
                # whether that is intentional.
                json_data = {"prompt": json_data}
            elif "Expecting ':' delimiter" in str(e):
                json_data = '{"prompt": ' + json.dumps(json_data) + '}'
            elif "Expecting value: line 1 column 1" in str(e):
                # Plain text: strip an optional "prompt: " prefix, then wrap.
                if json_data.startswith("prompt: "):
                    json_data = json_data.replace("prompt: ", "")
                json_data = '{"prompt": ' + json.dumps(json_data) + '}'
            else:
                # Fallback: wrap anything else as a prompt payload.
                json_data = '{"prompt": ' + json.dumps(json_data) + '}'
    return json_data
|
166
|
+
|
167
|
+
def is_surrounded_by_chinese(text, index):
    """True when a neighbour of text[index] is a CJK ideograph.

    Interior positions check both neighbours; the final position checks
    only the left neighbour; index 0 of a multi-character string is
    always False. (For index 0, ``text[index - 1]`` wraps to the last
    character — that quirk is preserved deliberately.)
    """
    def _is_cjk(ch):
        return '\u4e00' <= ch <= '\u9fff'

    left_char = text[index - 1]
    last = len(text) - 1
    if 0 < index < last:
        return _is_cjk(left_char) or _is_cjk(text[index + 1])
    if index == last:
        return _is_cjk(left_char)
    return False
|
175
|
+
|
176
|
+
def replace_char(string, index, new_char):
    """Return *string* with the character at *index* replaced by *new_char*."""
    prefix, suffix = string[:index], string[index + 1:]
    return prefix + new_char + suffix
|
178
|
+
|
179
|
+
def claude_replace(text):
    """Convert ASCII punctuation adjacent to Chinese text into full-width.

    Each mapped ASCII mark is replaced in place by its Chinese
    counterpart whenever a neighbouring character is a CJK ideograph.
    """
    punctuation_mapping = {",": ",", ":": ":", "!": "!", "?": "?", ";": ";"}
    for position in range(len(text)):
        ch = text[position]
        if ch in punctuation_mapping and is_surrounded_by_chinese(text, position):
            text = replace_char(text, position, punctuation_mapping[ch])
    return text
|
186
|
+
|
187
|
+
def safe_get(data, *keys, default=None):
    """Walk nested containers by successive keys/indices, never raising.

    dict/list levels are indexed directly; any other object is queried
    via its ``.get`` method. Returns *default* on any lookup failure.
    """
    current = data
    for key in keys:
        try:
            if isinstance(current, (dict, list)):
                current = current[key]
            else:
                current = current.get(key)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return current
|
194
|
+
|
195
|
+
import asyncio
def async_generator_to_sync(async_gen):
    """Drain an async generator on a private event loop, yielding its items.

    All items are collected first via run_until_complete, then yielded
    synchronously. Pending tasks and lingering async generators are
    cleaned up and the loop is closed afterwards.

    Args:
        async_gen: the async generator (already instantiated) to drain.

    Yields:
        Each value produced by *async_gen*, in order.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        async def _drain():
            collected = []
            async for item in async_gen:
                collected.append(item)
            return collected

        for item in loop.run_until_complete(_drain()):
            yield item

    except Exception as e:
        print(f"Error during async execution: {e}")
        raise
    finally:
        try:
            # Wait for any unfinished tasks, then shut down remaining
            # async generators before closing the loop.
            pending = [t for t in asyncio.all_tasks(loop) if not t.done()]
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
        except Exception as e:
            print(f"Error during cleanup: {e}")
|
233
|
+
|
234
|
+
def parse_tools_from_cursor_prompt(text):
    """Extract the JSON tool list embedded in a <tools>…</tools> block.

    Returns the parsed list on success. Returns [] when the block is
    missing or its JSON is invalid — the original implicitly returned
    None for a missing block, which crashed callers iterating the result.
    """
    import json
    import re

    # Grab the JSON payload between <tools> and </tools>.
    tools_match = re.search(r"<tools>\n(.*?)\n</tools>", text, re.DOTALL)
    if tools_match:
        tools_json_string = tools_match.group(1).strip()
        try:
            return json.loads(tools_json_string, strict=False)
        except json.JSONDecodeError as e:
            print(f"解析 JSON 时出错: {e}")
    return []
|
248
|
+
|
249
|
+
from dataclasses import dataclass
|
250
|
+
from typing import List, Callable, Optional, TypeVar, Generic, Union, Dict, Any
|
251
|
+
|
252
|
+
# 定义结果类型
|
253
|
+
@dataclass
class XmlMatcherResult:
    # One span of scanned text: `matched` is True when `data` lay inside
    # the target XML tag, False for surrounding plain text.
    matched: bool
    data: str = ""
|
257
|
+
|
258
|
+
# 泛型类型变量,用于 transform 的返回类型
|
259
|
+
R = TypeVar('R')
|
260
|
+
|
261
|
+
class XmlMatcher(Generic[R]):
    """Streaming matcher that splits text into inside/outside-tag spans.

    Feed text incrementally with update(); each returned
    XmlMatcherResult carries a span of text and whether it lay inside
    the target ``tag_name`` element. Nested occurrences of the same tag
    are tracked via a depth counter. An optional ``transform`` callable
    maps each result before it is returned.
    """
    def __init__(self,
                 tag_name: str,
                 transform: Optional[Callable[[XmlMatcherResult], R]] = None,
                 position: int = 0):
        # Tag to match (without angle brackets).
        self.tag_name: str = tag_name
        # Optional per-chunk mapping applied in _pop().
        self.transform: Optional[Callable[[XmlMatcherResult], R]] = transform
        # Characters before this offset are never treated as tag openers.
        self.position: int = position

        # How many characters of tag_name have been matched so far.
        self.index: int = 0
        # Completed spans awaiting _pop().
        self.chunks: List[XmlMatcherResult] = []
        # Characters accumulated for the current span.
        self.cached: List[str] = []
        # Whether the scanner is currently inside the target tag.
        self.matched: bool = False
        self.state: str = "TEXT"  # one of "TEXT", "TAG_OPEN", "TAG_CLOSE"
        # Nesting depth of the target tag.
        self.depth: int = 0
        # Absolute character offset consumed so far.
        self.pointer: int = 0

    def _collect(self):
        """Flush the cached characters into chunks."""
        if not self.cached:
            return

        data = "".join(self.cached)
        # Merge into the previous chunk only when its matched state is the
        # same as the current cache's state.
        last = self.chunks[-1] if self.chunks else None
        # While mid-tag, the matched state is derived from depth instead.
        current_matched_state = self.matched if self.state == "TEXT" else (self.depth > 0)

        if last and last.matched == current_matched_state:
            last.data += data
        else:
            # Only add a new chunk when there is data to add.
            if data:
                self.chunks.append(XmlMatcherResult(data=data, matched=current_matched_state))

        self.cached = []

    def _pop(self) -> List[Union[XmlMatcherResult, R]]:
        """Return accumulated chunks (transformed if configured) and reset."""
        chunks_to_return = self.chunks
        self.chunks = []
        if not self.transform:
            # Without a transform, return raw results. The explicit copy
            # keeps the internal list isolated from the caller.
            return [chunk for chunk in chunks_to_return]  # type: ignore
        # Apply the transform to each chunk.
        return [self.transform(chunk) for chunk in chunks_to_return]

    def _update(self, chunk: str):
        """Core state machine: consume *chunk* one character at a time."""
        for char in chunk:
            # Marks whether the state machine already consumed this char.
            current_char_processed = False

            if self.state == "TEXT":
                # A "<" only starts tag parsing past `position` (or once
                # we are already inside the tag).
                if char == "<" and (self.pointer >= self.position or self.matched):
                    self._collect()
                    self.state = "TAG_OPEN"
                    self.cached.append(char)
                    self.index = 0  # reset to start matching the tag name
                    current_char_processed = True
                # else: stay in TEXT; char is cached at the end of the loop

            elif self.state == "TAG_OPEN":
                self.cached.append(char)
                current_char_processed = True

                tag_name_len = len(self.tag_name)

                # Just after "<".
                if self.index == 0:
                    if char == "/":
                        self.state = "TAG_CLOSE"
                        # index stays 0: match the closing tag name next
                    elif char.isspace():
                        # Skip whitespace after "<".
                        pass  # index stays 0
                    elif char == self.tag_name[0]:
                        # Start matching the tag name.
                        self.index = 1
                    else:
                        # Invalid tag start (not "/", space, or tag_name[0]).
                        self.state = "TEXT"
                        current_char_processed = True
                # Matching the tag name.
                elif self.index < tag_name_len:
                    if self.tag_name[self.index] == char:
                        self.index += 1
                    elif char.isspace():
                        # Whitespace inside a partially-matched tag name is
                        # not allowed — treat the whole thing as text.
                        self.state = "TEXT"
                        current_char_processed = True
                    else:
                        # Character does not match the tag name.
                        self.state = "TEXT"
                        current_char_processed = True
                # Tag name fully matched (self.index == tag_name_len).
                else:
                    if char == ">":
                        # Opening tag complete: enter the element.
                        self.state = "TEXT"
                        self.depth += 1
                        self.matched = True
                        self.cached = []  # discard the "<tag ...>" itself
                    elif char.isspace():
                        # Ignore whitespace after the tag name.
                        pass  # stay in TAG_OPEN, waiting for ">"
                    else:
                        # Attribute characters: ignore, keep waiting for ">".
                        pass  # stay in TAG_OPEN

            elif self.state == "TAG_CLOSE":
                self.cached.append(char)
                current_char_processed = True  # default

                tag_name_len = len(self.tag_name)

                # Just after "</".
                if self.index == 0:
                    if char.isspace():
                        # Skip whitespace after "</".
                        pass  # index stays 0
                    elif char == self.tag_name[0]:
                        # Start matching the closing tag name.
                        self.index = 1
                    else:
                        # Invalid closing tag.
                        self.state = "TEXT"
                        current_char_processed = True
                # Matching the closing tag name.
                elif self.index < tag_name_len:
                    if self.tag_name[self.index] == char:
                        self.index += 1
                    else:
                        # Character does not match the tag name.
                        self.state = "TEXT"
                        current_char_processed = True
                # Closing tag name fully matched (self.index == tag_name_len).
                else:
                    if char == ">":
                        was_inside_tag = self.depth > 0
                        self.state = "TEXT"  # return to TEXT either way

                        if was_inside_tag:
                            # Genuinely inside the tag: close one level.
                            self.depth -= 1
                            self.matched = self.depth > 0
                            self.cached = []  # discard the "</tag>" itself
                            # current_char_processed remains True
                        else:
                            # Unexpected closing tag outside the element —
                            # keep the cached "</tag" as plain text but do
                            # not re-append this ">" at the loop end.
                            current_char_processed = True
                    elif char.isspace():
                        # Allow "</tag >": keep waiting for ">".
                        pass  # stay in TAG_CLOSE
                    else:
                        # Non-space, non-">" after the closing tag name.
                        self.state = "TEXT"
                        current_char_processed = True

            # Character was not consumed by the state machine: treat as text.
            if not current_char_processed:
                # Ensure the state is TEXT (a failed tag match falls back
                # here and its cached prefix becomes plain text).
                if self.state != "TEXT":
                    self.state = "TEXT"

                self.cached.append(char)

            self.pointer += 1

        # After the whole chunk, flush the cache while in TEXT state.
        if self.state == "TEXT":
            self._collect()


    def final(self, chunk: Optional[str] = None) -> List[Union[XmlMatcherResult, R]]:
        """Process an optional last chunk, flush everything, return results."""
        if chunk:
            self._update(chunk)
        # Collect even mid-tag residue (e.g. an unclosed tag).
        self._collect()
        return self._pop()

    def update(self, chunk: str) -> List[Union[XmlMatcherResult, R]]:
        """Process one chunk and return the results produced so far."""
        self._update(chunk)
        return self._pop()
|
459
|
+
|
460
|
+
def parse_function_xml(xml_content: str) -> List[Dict[str, Any]]:
    """Parse XML-style function calls into a list of dicts.

    Only the innermost two levels of tags are interpreted: an outer tag
    is the function name and its child tags are parameters. Wrapper tags
    (tool_call / function_call / tool / function) are unwrapped
    recursively.

    Args:
        xml_content: text containing zero or more XML function calls.

    Returns:
        A list of {'function_name': str, 'parameter': dict} entries.
    """
    result_functions = []

    # Step 1: scan for top-level tags (candidate function calls).
    position = 0
    while position < len(xml_content):
        # Find the next opening angle bracket.
        tag_start = xml_content.find("<", position)
        if tag_start == -1:
            break  # no more tags

        # Skip closing tags ("</...").
        if tag_start + 1 < len(xml_content) and xml_content[tag_start + 1] == '/':
            # This is an end tag — skip it.
            position = tag_start + 1
            continue

        # Find where this tag's bracket closes.
        tag_end = xml_content.find(">", tag_start)
        if tag_end == -1:
            break  # unterminated tag

        # Extract the tag name (drop any attributes).
        tag_content = xml_content[tag_start+1:tag_end].strip()
        tag_name = tag_content.split()[0] if " " in tag_content else tag_content

        if not tag_name:
            position = tag_end + 1
            continue  # empty tag name — skip

        # Build the literal start/end tag strings to delimit the element.
        full_start_tag = f"<{tag_name}"
        full_end_tag = f"</{tag_name}>"

        # Locate the start tag from the current position.
        start_pos = xml_content.find(full_start_tag, position)
        if start_pos == -1:
            position = tag_end + 1
            continue

        # Locate the matching end tag.
        end_pos = xml_content.find(full_end_tag, start_pos)
        if end_pos == -1:
            # No end tag: possibly unclosed — skip.
            position = tag_end + 1
            continue

        # Element body (between start and end tags).
        tag_inner_content = xml_content[tag_end+1:end_pos]

        # Wrapper tags (e.g. tool_call): recurse into their body.
        if tag_name in ["tool_call", "function_call", "tool", "function"]:
            # Recursively parse the wrapped content.
            nested_functions = parse_function_xml(tag_inner_content)
            result_functions.extend(nested_functions)
        else:
            # The tag itself is the function name; its child tags are the
            # parameters.
            parameters = {}

            # Parse child tags as parameters.
            param_position = 0
            while param_position < len(tag_inner_content):
                param_tag_start = tag_inner_content.find("<", param_position)
                if param_tag_start == -1:
                    break

                # Skip closing tags.
                if param_tag_start + 1 < len(tag_inner_content) and tag_inner_content[param_tag_start + 1] == '/':
                    param_position = param_tag_start + 1
                    continue

                param_tag_end = tag_inner_content.find(">", param_tag_start)
                if param_tag_end == -1:
                    break

                # Extract the parameter name (drop attributes if present).
                param_name = tag_inner_content[param_tag_start+1:param_tag_end].strip()
                if " " in param_name:  # tag carries attributes
                    param_name = param_name.split()[0]

                if not param_name:
                    param_position = param_tag_end + 1
                    continue

                # Find the parameter's closing tag.
                param_end_tag = f"</{param_name}>"
                param_end_pos = tag_inner_content.find(param_end_tag, param_tag_end)

                if param_end_pos == -1:
                    # Unclosed parameter tag — skip it.
                    param_position = param_tag_end + 1
                    continue

                # Extract the parameter value.
                param_value = tag_inner_content[param_tag_end+1:param_end_pos].strip()
                parameters[param_name] = param_value

                # Advance past this parameter element.
                param_position = param_end_pos + len(param_end_tag)

            # Record the parsed function call.
            result_functions.append({
                'function_name': tag_name,
                'parameter': parameters
            })

        # Advance past the whole element.
        position = end_pos + len(full_end_tag)

    return result_functions
|
581
|
+
|
582
|
+
def parse_continuous_json(json_str: str, function_name: str = "") -> List[Dict[str, Any]]:
    """Parse one JSON object or a run of concatenated JSON objects.

    Args:
        json_str: a single JSON object or several back-to-back objects.
        function_name: name to attach to each parsed call (optional).

    Returns:
        A list of {'function_name', 'parameter', 'function_call_id'}
        dicts — one per successfully parsed object.
    """
    if not json_str or not json_str.strip():
        return []

    # Fast path: the whole string is one JSON value.
    try:
        single_obj = json.loads(json_str)
    except json.JSONDecodeError:
        pass  # fall through to the concatenated-object scan
    else:
        call_id = f"{function_name}_single" if function_name else "tool_single"
        return [{
            'function_name': function_name or "default_function",
            'parameter': single_obj,
            'function_call_id': call_id
        }]

    calls = []
    pos = 0
    total = len(json_str)

    while pos < total:
        # Seek the start of the next object.
        if json_str[pos] != '{':
            pos += 1
            continue

        # Track brace balance to find the matching closer.
        depth = 1
        obj_start = pos
        pos += 1
        while pos < total and depth > 0:
            if json_str[pos] == '{':
                depth += 1
            elif json_str[pos] == '}':
                depth -= 1
            pos += 1

        if depth == 0:
            candidate = json_str[obj_start:pos]
            try:
                obj = json.loads(candidate)
            except json.JSONDecodeError:
                continue  # ignore unparseable spans
            call_id = (f"{function_name}_{len(calls)}" if function_name
                       else f"tool_{len(calls)}")
            calls.append({
                'function_name': function_name or "default_function",
                'parameter': obj,
                'function_call_id': call_id
            })

    return calls
|
651
|
+
|
652
|
+
def convert_functions_to_xml(functions_list):
    """Render a list of function-call dicts as XML text.

    Args:
        functions_list: a list of {'function_name': ..., 'parameter': {...}}
            dicts, or a JSON string encoding such a list.

    Returns:
        The XML string, one ``<name>…</name>`` element per call with one
        child element per parameter.
    """
    if isinstance(functions_list, str):
        try:
            # Decode the JSON string form into a real list.
            functions_list = json.loads(functions_list)
            if not isinstance(functions_list, list):
                print(f"提取的工具调用不是列表格式: {functions_list}")
        except json.JSONDecodeError as e:
            print(f"从文本中提取的工具调用JSON解析失败: {e}")

    # Assemble with a parts list + join instead of string concatenation.
    parts = []
    for call in functions_list:
        name = call.get('function_name', '')
        params = call.get('parameter', {})

        parts.append(f"<{name}>\n")
        for key, value in params.items():
            parts.append(f"<{key}>{value}</{key}>\n")
        parts.append(f"</{name}>\n")

    return "".join(parts)
|
690
|
+
|
691
|
+
if __name__ == "__main__":

    # Run this file with: python -m aient.utils.scripts
    os.system("clear")
    # Sample 1: one call wrapped in <tool_call>, with surrounding prose.
    test_xml = """
✅ 好的,我现在读取 `README.md` 文件。
<tool_call>
<read_file>
<file_path>/Users/yanyuming/Downloads/GitHub/llama3_interpretability_sae/README.md</file_path>
</read_file>
</tool_call>好的,我现在读取 `README.md` 文件。
"""
    # Sample 2 (overwrites sample 1): repeated bare calls plus a wrapped one.
    test_xml = """
✅ 好的,我现在读取 `README.md` 文件。
<read_file>
<file_path>README.md</file_path>
</read_file>
<read_file>
<file_path>README.md</file_path>
</read_file>

<tool_call>
<read_file>
<file_path>README.md</file_path>
</read_file>
</tool_call>
好的,我现在读取 `README.md` 文件。
"""

    # Sample 3 (overwrites sample 2; only this one is actually parsed):
    # calls embedded in a fenced code block within prose.
    test_xml = """首先使用read_file工具读取论文内容,然后使用excute_command工具克隆代码仓库到本地。\n```xml\n<read_file>\n<file_path>/Users/yanyuming/Downloads/GitHub/OceanSynthesis/papers/2412.06410v1.pdf</file_path>\n</read_file>\n\n<excute_command>\n<command>git clone https://github.com/bartbussmann/BatchTopK.git</command>\n</excute_command>\n```"""
    print(parse_function_xml(test_xml))
|