aient 1.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/plugins/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
1
|
+
import os
import pkgutil
import importlib

# Import the registry first: the decorators used by the plugin modules below
# (register_tool / register_agent) write into it when those modules are imported.
from .registry import registry, register_tool, register_agent

# Automatically import every plugin module in this directory.
# 'registry' is already imported above, 'config' is imported explicitly below,
# and '__init__' is this module itself.
excluded_modules = ['config', 'registry', '__init__']
current_dir = os.path.dirname(__file__)

# Import all plugin modules first so their decorators run and populate the registry.
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
    if module_name not in excluded_modules:
        importlib.import_module(f'.{module_name}', package=__name__)

# Only then import from config: its module-level PLUGINS / function_call_list /
# claude_tools_list are computed from the registry contents when config is imported.
# config defines no __all__ (as far as visible), so this pulls in all its public names.
from .config import *

# Expose every registered tool function as a package-level attribute.
for tool_name, tool_func in registry.tools.items():
    globals()[tool_name] = tool_func

# Public API: configuration helpers plus one entry per registered tool name.
# NOTE(review): PLUGINS / function_call_list here are bound once at import time;
# update_tools_config() rebinds the globals inside .config only, so callers
# should use its return value — confirm this is intended.
__all__ = [
    'PLUGINS',
    'function_call_list',
    'get_tools_result_async',
    'registry',
    'register_tool',
    'register_agent',
    'update_tools_config',
] + list(registry.tools.keys())
|
aient/plugins/arXiv.py
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
import requests
|
2
|
+
|
3
|
+
from ..utils.scripts import Document_extract
|
4
|
+
from .registry import register_tool
|
5
|
+
|
6
|
+
# Tool: download an arXiv paper as PDF and return its extracted text.
# NOTE: the docstring below is read at runtime as the tool description shown to
# the model (function_to_json / Registry), so it is intentionally kept verbatim.
@register_tool()
async def download_read_arxiv_pdf(arxiv_id: str) -> str:
    """
    下载指定arXiv ID的论文PDF并提取其内容。

    此函数会下载arXiv上的论文PDF文件,保存到指定路径,
    然后使用文档提取工具读取其内容。

    Args:
        arxiv_id: arXiv论文的ID,例如'2305.12345'

    Returns:
        提取的论文内容文本或失败消息
    """
    import asyncio
    import functools
    import os
    import tempfile

    # Build the PDF download URL.
    url = f'https://arxiv.org/pdf/{arxiv_id}.pdf'

    # requests is blocking: run it in the default executor so the event loop is
    # not stalled, and bound it with a timeout so a dead connection cannot hang forever.
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None, functools.partial(requests.get, url, timeout=60)
        )
    except requests.exceptions.RequestException as e:
        print(f'下载失败: {e}')
        return "文件下载失败"

    if response.status_code != 200:
        print(f'下载失败,状态码: {response.status_code}')
        return "文件下载失败"

    # Write to a unique temporary file (keeping the .pdf suffix) instead of a
    # shared "paper.pdf" in the working directory, so concurrent calls cannot
    # overwrite each other's download; the file is removed after extraction.
    fd, save_path = tempfile.mkstemp(suffix=".pdf")
    try:
        with os.fdopen(fd, 'wb') as file:
            file.write(response.content)
        print(f'PDF下载成功,保存路径: {save_path}')
        return await Document_extract(None, save_path)
    finally:
        try:
            os.remove(save_path)
        except OSError:
            pass
|
37
|
+
|
38
|
+
if __name__ == '__main__':
    # Example usage.
    # NOTE(review): this module uses relative imports (from ..utils.scripts / .registry),
    # so it only runs as part of the package (python -m ...), not as a plain script.
    arxiv_id = '2305.12345'  # replace with a real arXiv ID

    # Test the download function.
    # NOTE(review): download_read_arxiv_pdf is a coroutine function; this call would
    # need asyncio.run(download_read_arxiv_pdf(arxiv_id)) to actually execute.
    # print(download_read_arxiv_pdf(arxiv_id))

    # Test converting the function to a JSON schema.
    # NOTE(review): function_to_json is not imported in this module.
    # json_result = function_to_json(download_read_arxiv_pdf)
    # import json
    # print(json.dumps(json_result, indent=2, ensure_ascii=False))
|
aient/plugins/config.py
ADDED
@@ -0,0 +1,178 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import inspect
|
4
|
+
|
5
|
+
from .registry import registry
|
6
|
+
from ..utils.scripts import cut_message, safe_get
|
7
|
+
from ..utils.prompt import search_key_word_prompt, arxiv_doc_user_prompt
|
8
|
+
|
9
|
+
async def get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, engine, robot, api_key, api_url, use_plugins, model, add_message, convo_id, language):
    """Run the tool the model asked for and stream its result.

    This is an async generator. The last value it yields is always
    ``"function_response:<result>"``. For ``get_search_results`` it first yields
    ``"message_search_stage_1"`` and forwards every string chunk produced by the
    search tool (progress messages) before that final value.

    Args:
        function_call_name: Name of the requested tool (looked up in ``registry.tools``).
        function_full_response: JSON string holding the tool-call arguments.
        function_call_max_tokens: Budget passed to ``cut_message`` to truncate the
            tool output (web search always uses 32000 instead).
        engine: Engine name passed to ``cut_message`` and to ``robot``.
        robot: Chat-model class; instantiated to generate search keywords.
        api_key: API key passed to ``robot``.
        api_url: Object whose ``source_api_url`` attribute is passed to ``robot``.
        use_plugins: Forwarded to ``robot``.
        model: Model name for the keyword-generation ``ask_async`` call.
        add_message: Callback ``add_message(text, role, convo_id=...)``; used to add
            the arXiv follow-up prompt as a "user" message.
        convo_id: Conversation id forwarded to ``add_message``.
        language: Language the final answer should use (search prompt only).

    Yields:
        Progress strings (search only), then the final ``"function_response:..."``
        string. Unknown tool names produce just ``"function_response:"``.
    """
    function_response = ""
    if function_call_name in registry.tools:
        function_to_call = registry.tools[function_call_name]
        function_args = registry.tools_info[function_call_name].args
        # Only the first declared parameter is ever filled from the model's arguments.
        arg = function_args[0] if function_args else None

        if function_call_name == "get_search_results":
            prompt = json.loads(function_full_response)["query"]
            yield "message_search_stage_1"
            # Ask an LLM to expand the query into additional search keywords.
            llm = robot(api_key=api_key, api_url=api_url.source_api_url, engine=engine, use_plugins=use_plugins)
            keywords = (await llm.ask_async(search_key_word_prompt.format(source=prompt), model=model)).split("\n")
            print("keywords", keywords)
            keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item if item != ""]
            # The original query always comes first; at most three search terms are used.
            keywords = ([prompt] + keywords)[:3]
            print("select keywords", keywords)
            async for chunk in function_to_call(keywords):
                if isinstance(chunk, str):
                    # String chunks are progress messages: pass them straight through.
                    yield chunk
                else:
                    # Any other chunk is joined as an iterable of search-result texts.
                    function_response = "\n\n".join(chunk)
            # Search results get a fixed, larger budget regardless of the argument.
            function_call_max_tokens = 32000
            function_response, _ = cut_message(function_response, function_call_max_tokens, engine)
            if function_response:
                # Built purely from f-string/literal pieces. The previous version
                # called .format() on text that already contained the user's query,
                # so any "{" or "}" in the question raised KeyError/IndexError/ValueError.
                function_response = (
                    f"You need to response the following question: {prompt}. Search results is provided inside <Search_results></Search_results> XML tags. Your task is to think about the question step by step and then answer the above question in {language} based on the Search results provided. Please response in {language} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks. For each sentence quoting search results, a markdown ordered superscript number url link must be used to indicate the source, e.g., [¹](https://www.example.com)"
                    "Here is the Search results, inside <Search_results></Search_results> XML tags:"
                    "<Search_results>"
                    f"{function_response}"
                    "</Search_results>"
                )
            else:
                function_response = "无法找到相关信息,停止使用 tools"
        elif arg:  # single-argument tools, e.g. generate_image, get_url_content, run_python_script
            prompt = safe_get(json.loads(function_full_response), arg, default=".")
            if inspect.iscoroutinefunction(function_to_call):
                function_response = await function_to_call(prompt)
            else:
                function_response = function_to_call(prompt)
            function_response, _ = cut_message(function_response, function_call_max_tokens, engine)
        else:  # zero-argument tools, e.g. get_date_time_weekday
            if inspect.iscoroutinefunction(function_to_call):
                function_response = await function_to_call()
            else:
                function_response = function_to_call()
            function_response, _ = cut_message(function_response, function_call_max_tokens, engine)

    if function_call_name == "download_read_arxiv_pdf":
        # After a paper has been read, add the arXiv follow-up prompt as a user message.
        add_message(arxiv_doc_user_prompt, "user", convo_id=convo_id)

    yield f"function_response:{function_response}"
|
72
|
+
|
73
|
+
def function_to_json(func) -> dict:
|
74
|
+
"""
|
75
|
+
将Python函数转换为JSON可序列化的字典,描述函数的签名,包括名称、描述和参数。
|
76
|
+
|
77
|
+
Args:
|
78
|
+
func: 要转换的函数
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
表示函数签名的JSON格式字典
|
82
|
+
"""
|
83
|
+
type_map = {
|
84
|
+
str: "string",
|
85
|
+
int: "integer",
|
86
|
+
float: "number",
|
87
|
+
bool: "boolean",
|
88
|
+
type(None): "null",
|
89
|
+
}
|
90
|
+
|
91
|
+
try:
|
92
|
+
signature = inspect.signature(func)
|
93
|
+
except ValueError as e:
|
94
|
+
raise ValueError(f"获取函数{func.__name__}签名失败: {str(e)}")
|
95
|
+
|
96
|
+
parameters = {}
|
97
|
+
for param in signature.parameters.values():
|
98
|
+
try:
|
99
|
+
if param.annotation == inspect._empty:
|
100
|
+
parameters[param.name] = {"type": "string"}
|
101
|
+
else:
|
102
|
+
parameters[param.name] = {"type": type_map.get(param.annotation, "string")}
|
103
|
+
except KeyError as e:
|
104
|
+
raise KeyError(f"未知类型注解 {param.annotation} 用于参数 {param.name}: {str(e)}")
|
105
|
+
|
106
|
+
required = [
|
107
|
+
param.name
|
108
|
+
for param in signature.parameters.values()
|
109
|
+
if param.default == inspect._empty
|
110
|
+
]
|
111
|
+
|
112
|
+
return {
|
113
|
+
"name": func.__name__,
|
114
|
+
"description": func.__doc__ or "",
|
115
|
+
"parameters": {
|
116
|
+
"type": "object",
|
117
|
+
"properties": parameters,
|
118
|
+
"required": required,
|
119
|
+
},
|
120
|
+
}
|
121
|
+
|
122
|
+
def gpt2claude_tools_json(json_dict):
    """Convert an OpenAI-style tool schema dict into Claude's tool format.

    The input is deep-copied and never modified. ``parameters`` is renamed to
    ``input_schema``; when there is no ``parameters`` key but a ``description``
    exists, an empty object schema is supplied as ``input_schema``. If the dict
    carries a ``tools`` key, ``tool_choice`` is set to ``{"type": "auto"}``.
    """
    import copy

    converted = copy.deepcopy(json_dict)
    renames = {"parameters": "input_schema"}

    for source_key, target_key in renames.items():
        if source_key not in converted:
            # Claude expects an input schema even for argument-less tools.
            if target_key and "description" in converted:
                converted[target_key] = {"type": "object", "properties": {}}
            continue
        value = converted.pop(source_key)
        if target_key:
            converted[target_key] = value

    if "tools" in converted:
        converted["tool_choice"] = {"type": "auto"}
    return converted
|
145
|
+
|
146
|
+
# print("registry.tools", json.dumps(registry.tools_info.get('get_time', {}), indent=4, ensure_ascii=False))
|
147
|
+
# print("registry.tools", json.dumps(registry.tools_info['run_python_script'].to_dict(), indent=4, ensure_ascii=False))
|
148
|
+
|
149
|
+
# Build the PLUGINS switch table from the tools in the registry.
def get_plugins():
    """Map each registered tool name to whether it is enabled.

    A tool counts as enabled when an environment variable named after the tool
    is set to anything other than the exact string ``"False"``; an unset
    variable means disabled.
    """
    enabled = {}
    for tool_name in registry.tools:
        enabled[tool_name] = os.environ.get(tool_name, "False") != "False"
    return enabled
|
155
|
+
|
156
|
+
# Build function_call_list from the tools in the registry.
def get_function_call_list():
    """Return the OpenAI-style JSON schema of every registered tool, keyed by tool name."""
    return {name: function_to_json(tool) for name, tool in registry.tools.items()}
|
162
|
+
|
163
|
+
def get_claude_tools_list():
    """Return every registered tool's schema converted to Claude's format, keyed by tool name."""
    openai_schemas = get_function_call_list()
    claude_schemas = {}
    for tool_name, schema in openai_schemas.items():
        claude_schemas[f"{tool_name}"] = gpt2claude_tools_json(schema)
    return claude_schemas
|
166
|
+
|
167
|
+
# Default configuration, computed once when this module is imported from the
# tools registered at that moment.
PLUGINS, function_call_list, claude_tools_list = (
    get_plugins(),
    get_function_call_list(),
    get_claude_tools_list(),
)
|
171
|
+
|
172
|
+
# Refresh the tool configuration at runtime.
def update_tools_config():
    """Rebuild the tool configuration from the current registry contents.

    Rebinds this module's ``PLUGINS``, ``function_call_list`` and
    ``claude_tools_list`` globals. Names copied earlier into other modules
    (e.g. via ``from .config import *``) keep their old values, so callers
    should rely on the returned tuple.

    Returns:
        tuple: ``(PLUGINS, function_call_list, claude_tools_list)``.
    """
    global PLUGINS, function_call_list, claude_tools_list
    PLUGINS = get_plugins()
    function_call_list = get_function_call_list()
    claude_tools_list = get_claude_tools_list()
    refreshed = PLUGINS, function_call_list, claude_tools_list
    return refreshed
|
aient/plugins/image.py
ADDED
@@ -0,0 +1,72 @@
|
|
1
|
+
import os
|
2
|
+
import requests
|
3
|
+
import json
|
4
|
+
from ..models.base import BaseLLM
|
5
|
+
from .registry import register_tool
|
6
|
+
|
7
|
+
# API key read once at import time from the environment (None when unset);
# generate_image passes it to dalle3 as f"{API}".
API = os.environ.get('API', None)
# NOTE(review): API_URL is not referenced by the visible code of this module
# (dalle3 reads the API_URL environment variable directly).
API_URL = os.environ.get('API_URL', None)
|
9
|
+
|
10
|
+
class dalle3(BaseLLM):
    """Client for an OpenAI-compatible image-generation endpoint (default model "dall-e-3")."""

    def __init__(
        self,
        api_key: str,
        api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/images/generations"),
        timeout: float = 20,
    ):
        super().__init__(api_key, api_url=api_url, timeout=timeout)
        self.engine: str = "dall-e-3"

    def generate(
        self,
        prompt: str,
        model: str = "",
        **kwargs,
    ):
        """Request a single 1024x1024 image for ``prompt`` and yield its URL.

        This is a generator: it yields one URL on success and yields nothing
        when the HTTP request itself fails (the error is printed instead).

        Args:
            prompt: Text description of the image.
            model: Model name. The ``IMAGE_MODEL_NAME`` environment variable takes
                precedence; ``self.engine`` is the fallback.
            **kwargs: ``api_key`` and ``timeout`` override the instance values.

        Raises:
            Exception: If the endpoint responds with a non-200 status.
        """
        # NOTE(review): assumes BaseLLM wraps api_url in an object exposing
        # `.image_url` — confirm in models/base.py.
        url = self.api_url.image_url
        headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"}

        json_post = {
            "model": os.environ.get("IMAGE_MODEL_NAME") or model or self.engine,
            "prompt": prompt,
            "n": 1,
            "size": "1024x1024",
        }
        try:
            response = self.session.post(
                url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        except (ConnectionError, requests.exceptions.ConnectionError):
            # requests raises its own ConnectionError, which does not subclass the
            # builtin one; catching only the builtin never matched real failures.
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            # Bind the exception and use an f-string so the detail is actually printed
            # (the old message printed a literal "{e}").
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            raise Exception(f"{response.status_code} {response.reason} {response.text}")
        json_data = json.loads(response.text)
        yield json_data["data"][0]["url"]
|
58
|
+
|
59
|
+
# Tool: generate an image from a text prompt and return its URL.
# NOTE: the docstring below is read at runtime as the tool description shown to
# the model, so it is intentionally kept verbatim.
@register_tool()
def generate_image(text):
    """
    生成图像

    参数:
        text: 描述图像的文本

    返回:
        图像的URL
    """
    client = dalle3(api_key=f"{API}")
    # generate() yields at most one URL; return it, or None if nothing was yielded.
    return next(iter(client.generate(text)), None)
|
aient/plugins/registry.py
ADDED
@@ -0,0 +1,116 @@
|
|
1
|
+
from typing import Callable, Dict, Literal, List, Optional
|
2
|
+
from dataclasses import dataclass, asdict
|
3
|
+
import inspect
|
4
|
+
|
5
|
+
@dataclass
class FunctionInfo:
    """Metadata recorded for a registered tool or agent function."""

    name: str  # registration name
    func: Callable  # the registered callable; excluded from to_dict()
    args: List[str]  # parameter names in signature order
    docstring: Optional[str]  # cleaned docstring (Registry.register uses inspect.getdoc)
    body: str  # source text after the decorator/def lines ("" when source is unavailable)
    return_type: Optional[str]  # str() of the return annotation, or None

    def to_dict(self) -> dict:
        """Return every field except ``func`` as a plain dict.

        Values are deep-copied as ``dataclasses.asdict`` would do, but ``func``
        is skipped *before* copying: ``asdict`` deep-copies every field first,
        which for a bound method deep-copies its instance and raises
        ``TypeError`` when that instance holds unpicklable state (locks, sessions).
        """
        import copy
        from dataclasses import fields

        return {
            info_field.name: copy.deepcopy(getattr(self, info_field.name))
            for info_field in fields(self)
            if info_field.name != 'func'
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'FunctionInfo':
        """Build an instance from a dict such as the one returned by ``to_dict``.

        ``func`` defaults to ``None`` when missing. The caller's dict is copied,
        not modified.
        """
        kwargs = dict(data)
        kwargs.setdefault('func', None)
        return cls(**kwargs)


class Registry:
    """Singleton store of registered tools and agents.

    Callables and their ``FunctionInfo`` metadata live in class-level dicts, and
    ``Registry()`` always returns the same instance, so every caller sees the
    same registrations.
    """

    _instance = None
    _registry: Dict[str, Dict[str, Callable]] = {
        "tools": {},
        "agents": {}
    }
    _registry_info: Dict[str, Dict[str, FunctionInfo]] = {
        "tools": {},
        "agents": {}
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register(self,
                 type: Literal["tool", "agent"],
                 name: Optional[str] = None):
        """Return a decorator that registers a function as a tool or an agent.

        Args:
            type: Registration kind, ``"tool"`` or ``"agent"``.
            name: Registration name; defaults to the decorated function's ``__name__``.

        Returns:
            A decorator that records the function and its ``FunctionInfo`` and
            returns the function unchanged.
        """
        def decorator(func: Callable):
            # Resolve the name per decorated function. The old code used
            # `nonlocal name`, so reusing one decorator object stored the first
            # function's name and registered every later function under it.
            func_name = name if name is not None else func.__name__

            signature = inspect.signature(func)
            args = list(signature.parameters.keys())
            docstring = inspect.getdoc(func)
            body = self._extract_body(func)

            return_type = None
            if signature.return_annotation is not inspect.Signature.empty:
                return_type = str(signature.return_annotation)

            func_info = FunctionInfo(
                name=func_name,
                func=func,
                args=args,
                docstring=docstring,
                body=body,
                return_type=return_type
            )

            registry_type = f"{type}s"
            self._registry[registry_type][func_name] = func
            self._registry_info[registry_type][func_name] = func_info
            return func
        return decorator

    @staticmethod
    def _extract_body(func: Callable) -> str:
        """Return the source of ``func`` after its decorator/def lines, or ``""`` if unavailable."""
        try:
            source = inspect.getsource(func)
        except (OSError, TypeError):
            # Builtins, C extensions and code without a source file (REPL, exec)
            # have no retrievable source; registration should still succeed.
            return ""
        # Drop the first line (normally the decorator), then any further
        # decorator or def lines. Heuristic: multi-line signatures are not stripped.
        body_lines = source.split('\n')[1:]
        while body_lines and (body_lines[0].strip().startswith('@') or 'def ' in body_lines[0]):
            body_lines = body_lines[1:]
        return '\n'.join(body_lines)

    @property
    def tools(self) -> Dict[str, Callable]:
        """Registered tool callables, keyed by registration name."""
        return self._registry["tools"]

    @property
    def agents(self) -> Dict[str, Callable]:
        """Registered agent callables, keyed by registration name."""
        return self._registry["agents"]

    @property
    def tools_info(self) -> Dict[str, FunctionInfo]:
        """``FunctionInfo`` for each registered tool, keyed by registration name."""
        return self._registry_info["tools"]

    @property
    def agents_info(self) -> Dict[str, FunctionInfo]:
        """``FunctionInfo`` for each registered agent, keyed by registration name."""
        return self._registry_info["agents"]
|
107
|
+
|
108
|
+
# Module-level shared instance; Registry() always returns this same object.
registry = Registry()
|
110
|
+
|
111
|
+
# Convenience decorator factories bound to the shared registry.
def register_tool(name: Optional[str] = None):
    """Return a decorator that registers a function as a tool, optionally under ``name``."""
    return registry.register("tool", name)


def register_agent(name: Optional[str] = None):
    """Return a decorator that registers a function as an agent, optionally under ``name``."""
    return registry.register("agent", name)
|
aient/plugins/run_python.py
ADDED
@@ -0,0 +1,156 @@
|
|
1
|
+
import os
|
2
|
+
import ast
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
import tempfile
|
6
|
+
from .registry import register_tool
|
7
|
+
|
8
|
+
def get_dangerous_attributes(node):
    """Return True if ``node`` is a name or attribute access on a blocked identifier.

    Flags ``ast.Name`` nodes whose id, and ``ast.Attribute`` nodes whose
    attribute name, is in the blocklist below. This is a simple keyword check,
    not a sandbox.
    """
    # '__import__' / '__builtins__' reach os/sys without any import statement.
    # (The former 'import' entry could never match: it is a keyword, never a
    # Name or attribute; imports are checked in check_code_safety instead.)
    dangerous_keywords = ['os', 'subprocess', 'sys', 'eval', 'exec', 'open', '__import__', '__builtins__']
    if isinstance(node, ast.Name):
        return node.id in dangerous_keywords
    elif isinstance(node, ast.Attribute):
        return node.attr in dangerous_keywords
    return False


def check_code_safety(code):
    """Heuristically decide whether ``code`` may be executed.

    Returns False when the code fails to parse, references a blocked name or
    attribute (see ``get_dangerous_attributes``), calls ``.encode()`` /
    ``.decode()``, or imports a blocked module (including aliased forms such as
    ``import os as o`` and ``from subprocess import run``). Otherwise True.

    NOTE(review): this is a blocklist, not a sandbox; determined code can still
    escape it (e.g. getattr tricks, pathlib). Real isolation needs an OS-level sandbox.
    """
    blocked_modules = {'os', 'subprocess', 'sys', 'importlib', 'builtins', 'ctypes', 'shutil'}

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: e.g. null bytes in the source on older Python versions.
        return False

    for node in ast.walk(tree):
        # Dangerous names/attributes, including callees: ast.walk also visits
        # each Call's func node, so no separate call check is needed.
        if get_dangerous_attributes(node):
            return False

        # Import statements bind names through ast.alias (never ast.Name), so
        # without this check "import os as o; o.system(...)" slipped through.
        if isinstance(node, ast.Import):
            if any(alias.name.split('.')[0] in blocked_modules for alias in node.names):
                return False
        elif isinstance(node, ast.ImportFrom):
            if node.module and node.module.split('.')[0] in blocked_modules:
                return False

        # String encode/decode calls are rejected (commonly used for obfuscation).
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
            if node.func.attr in ('encode', 'decode'):
                return False

    return True
|
42
|
+
|
43
|
+
# Tool: run a Python snippet in a subprocess and return its output.
# NOTE: the docstring below is read at runtime as the tool description shown to
# the model, so it is intentionally kept verbatim.
@register_tool()
async def run_python_script(code):
    """
    执行 Python 代码

    参数:
        code: 要执行的 Python 代码字符串

    返回:
        执行结果字符串
    """
    import sys

    timeout = 10
    # Reject code that fails the heuristic safety check.
    if not check_code_safety(code):
        return "Code contains potentially dangerous operations.\n\n"

    # Wrapper executed in the child process: runs the user code and, when the
    # last line looks like a bare expression, prints its value as "Result:".
    # The user code is embedded with repr(): pasting it inside '''...''' (as
    # before) turned escapes such as "\n" into real newlines (SyntaxError) and
    # let a ''' in the code break out of the string.
    wrapper_code = """
import sys
_result = None

def _capture_last_result(code_to_run):
    global _result
    # __name__ is provided so `if __name__ == "__main__":` blocks in user code run.
    namespace = {{"__name__": "__main__"}}
    exec(code_to_run, namespace)
    if "_last_expr" in namespace:
        _result = namespace["_last_expr"]

# User code, embedded via repr() so quotes and backslash escapes are preserved
_user_code = {code_literal}

# Try to capture the value of a trailing top-level expression
lines = _user_code.strip().split('\\n')
if lines:
    last_line = lines[-1].strip()
    # Only rewrite an unindented last line (an indented one belongs to a block)
    if last_line and lines[-1] == lines[-1].lstrip() and not last_line.startswith(('def ', 'class ', 'if ', 'for ', 'while ', 'try:', 'with ')):
        if not any(last_line.startswith(kw) for kw in ['return', 'print', 'raise', 'assert', 'import', 'from ']):
            if not last_line.endswith(':') and not last_line.endswith('='):
                candidate = '\\n'.join(lines[:-1] + ["_last_expr = " + last_line])
                # Keep the rewrite only if it still compiles (not e.g. for "x += 1")
                try:
                    compile(candidate, "<user_code>", "exec")
                except SyntaxError:
                    pass
                else:
                    _user_code = candidate

_capture_last_result(_user_code)

# Print the captured value, if any
if _result is not None:
    print("\\nResult:", repr(_result))
""".format(code_literal=repr(code))

    # Explicit UTF-8 so non-ASCII user code is written in the encoding Python
    # assumes for source files, regardless of the platform default.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as temp_file:
        temp_file.write(wrapper_code)
        temp_file_name = temp_file.name

    try:
        # Use the current interpreter; a bare 'python' may not exist on PATH.
        process = await asyncio.create_subprocess_exec(
            sys.executable or 'python', temp_file_name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            # Ask the process to stop (terminate), force-kill it if it is still
            # alive after a short grace period, and always reap it.
            try:
                process.terminate()
            except ProcessLookupError:
                pass
            try:
                await asyncio.wait_for(process.wait(), timeout=1)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()
            return "Process execution timed out."

        # errors='replace': output that is not valid UTF-8 must not crash the tool.
        stdout = stdout.decode(errors='replace')
        stderr = stderr.decode(errors='replace')
        return_code = process.returncode

        mess = (
            f"Execution result:\n{stdout}\n",
            f"Stderr:\n{stderr}\n" if stderr else "",
            f"Return Code: {return_code}\n" if return_code else "",
        )
        return "".join(mess)

    except Exception as e:
        logging.error(f"Error executing code: {str(e)}")
        return f"Error: {str(e)}"

    finally:
        try:
            os.unlink(temp_file_name)
        except Exception as e:
            logging.error(f"Error deleting temporary file: {str(e)}")
|
139
|
+
|
140
|
+
# Usage example.
async def main():
    """Run a small sample script through run_python_script and print its output."""
    # (A dead first assignment of `code` that was immediately overwritten was removed.)
    code = """
def add(a, b):
    return a + b

result = add(5, 3)
print(result)
"""
    result = await run_python_script(code)
    print(result)
|
154
|
+
|
155
|
+
if __name__ == "__main__":
    # Run the example. NOTE(review): this module imports `.registry` relatively,
    # so it must be run as part of the package (python -m ...), not as a plain script.
    asyncio.run(main())
|
aient/plugins/today.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
import pytz
|
2
|
+
import datetime
|
3
|
+
|
4
|
+
from .registry import register_tool
|
5
|
+
|
6
|
+
# Plugin: current date, time and weekday.
# NOTE: the docstring below is read at runtime as the tool description shown to
# the model, so it is intentionally kept verbatim.
@register_tool()
def get_date_time_weekday():
    """
    获取当前日期时间及星期几

    返回:
        包含当前日期时间及星期几的字符串
    """
    tz = pytz.timezone('Asia/Shanghai')  # UTC+8 time zone
    now = datetime.datetime.now(tz)  # current time in that zone
    weekday_str = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'][now.weekday()]
    # strftime instead of str(now.time())[:-7]: str() omits the ".ffffff" part
    # when microsecond == 0, and the slice then left only the first digit.
    return "今天是:" + str(now.date()) + ",现在的时间是:" + now.strftime("%H:%M:%S") + "," + weekday_str
|