aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ import os
2
+ import pkgutil
3
+ import importlib
4
+
5
+ # 首先导入registry,因为其他模块中的装饰器依赖它
6
+ from .registry import registry, register_tool, register_agent
7
+
8
# Automatically import every plugin module in this directory.
# 'registry' is already imported above, 'config' is imported explicitly
# further down (after the plugins), and '__init__' is this package module.
excluded_modules = ['config', 'registry', '__init__']
current_dir = os.path.dirname(__file__)

# Import all plugin modules first so that their @register_tool /
# @register_agent decorators run and fill the registry before config reads it
# (config builds PLUGINS / function_call_list from registry.tools when it is
# imported).
# NOTE(review): an exception raised while importing any plugin (e.g. a missing
# optional dependency) propagates and aborts the whole package import.
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
    if module_name not in excluded_modules:
        importlib.import_module(f'.{module_name}', package=__name__)

# Then import config's definitions. config defines no __all__, so this
# star-import copies every public name of config (including names it imported
# itself, such as os/json/inspect) into the package namespace.
# NOTE(review): these are snapshots of config's bindings; calling
# update_tools_config() later rebinds config.PLUGINS etc. but not the
# package-level names imported here.
from .config import *

# Expose every registered tool function as a package-level attribute.
for tool_name, tool_func in registry.tools.items():
    globals()[tool_name] = tool_func

# Public API: the config/registry helpers plus the names of all registered tools.
__all__ = [
    'PLUGINS',
    'function_call_list',
    'get_tools_result_async',
    'registry',
    'register_tool',
    'register_agent',
    'update_tools_config',
] + list(registry.tools.keys())
aient/plugins/arXiv.py ADDED
@@ -0,0 +1,48 @@
1
+ import requests
2
+
3
+ from ..utils.scripts import Document_extract
4
+ from .registry import register_tool
5
+
6
@register_tool()
async def download_read_arxiv_pdf(arxiv_id: str) -> str:
    """
    下载指定arXiv ID的论文PDF并提取其内容。

    此函数会下载arXiv上的论文PDF文件,保存到指定路径,
    然后使用文档提取工具读取其内容。

    Args:
        arxiv_id: arXiv论文的ID,例如'2305.12345'

    Returns:
        提取的论文内容文本或失败消息
    """
    # NOTE: the docstring above is also the tool description sent to the model
    # (function_to_json reads __doc__), so it is kept verbatim.
    #
    # Flow: fetch https://arxiv.org/pdf/<arxiv_id>.pdf, write it to a unique
    # temporary .pdf file, extract its text with Document_extract, then remove
    # the file. Returns the extracted text, or "文件下载失败" when the download
    # fails (network error or non-200 status).
    import asyncio
    import os
    import tempfile

    url = f'https://arxiv.org/pdf/{str(arxiv_id).strip()}.pdf'

    # requests is blocking: run it in the default executor so the event loop
    # keeps serving other tasks, and bound each socket wait so a stalled
    # connection cannot hang the tool call forever.
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None, lambda: requests.get(url, timeout=60)
        )
    except requests.exceptions.RequestException as e:
        print(f'下载失败,错误: {e}')
        return "文件下载失败"

    if response.status_code != 200:
        print(f'下载失败,状态码: {response.status_code}')
        return "文件下载失败"

    # A unique temp file (instead of a fixed "paper.pdf" in the working
    # directory) keeps concurrent downloads from overwriting each other and
    # avoids leaving files behind.
    fd, save_path = tempfile.mkstemp(suffix='.pdf')
    try:
        with os.fdopen(fd, 'wb') as file:
            file.write(response.content)
        print(f'PDF下载成功,保存路径: {save_path}')
        return await Document_extract(None, save_path)
    finally:
        try:
            os.remove(save_path)
        except OSError:
            pass  # best-effort cleanup; the extraction result is already final
37
+
38
if __name__ == '__main__':
    # Example usage. This module uses relative imports, so it must be run as a
    # module (python -m ...), not as a plain script.
    arxiv_id = '2305.12345'  # replace with a real arXiv ID

    # Try the download. download_read_arxiv_pdf is a coroutine function, so the
    # call below only creates a coroutine; wrap it in asyncio.run(...) to run it.
    # print(download_read_arxiv_pdf(arxiv_id))

    # Try converting the function to its JSON tool schema. function_to_json is
    # defined in .config and is not imported in this module.
    # json_result = function_to_json(download_read_arxiv_pdf)
    # import json
    # print(json.dumps(json_result, indent=2, ensure_ascii=False))
@@ -0,0 +1,178 @@
1
+ import os
2
+ import json
3
+ import inspect
4
+
5
+ from .registry import registry
6
+ from ..utils.scripts import cut_message, safe_get
7
+ from ..utils.prompt import search_key_word_prompt, arxiv_doc_user_prompt
8
+
9
async def get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, engine, robot, api_key, api_url, use_plugins, model, add_message, convo_id, language):
    """
    Execute a registered tool call and stream its result.

    This is an async generator. For ``get_search_results`` it first yields the
    stage marker ``"message_search_stage_1"`` and then every string chunk the
    search tool yields. In all cases the last item yielded is the tool output
    prefixed with ``"function_response:"`` (with an empty output when the tool
    name is not registered).

    Args:
        function_call_name: Name of the tool requested by the model.
        function_full_response: JSON string with the tool-call arguments.
        function_call_max_tokens: Budget passed to ``cut_message`` to truncate
            the tool output (the search branch overrides it with 32000).
        engine: Engine name, passed to the keyword LLM and to ``cut_message``.
        robot: Factory called to build the LLM client that expands a search
            query into keywords (its ``ask_async`` is awaited).
        api_key, use_plugins, model: Forwarded to that client.
        api_url: Object whose ``source_api_url`` is forwarded to that client.
        add_message: Callback used to append an extra user message to the
            conversation (only after ``download_read_arxiv_pdf``).
        convo_id: Conversation id forwarded to ``add_message``.
        language: Language the model is asked to answer in (search branch).
    """
    function_response = ""
    if function_call_name in registry.tools:
        function_to_call = registry.tools[function_call_name]
        function_args = registry.tools_info[function_call_name].args
        # Only the tool's first declared parameter is filled from the model's arguments.
        arg = function_args[0] if function_args else None

        if function_call_name == "get_search_results":
            prompt = json.loads(function_full_response)["query"]
            yield "message_search_stage_1"
            llm = robot(api_key=api_key, api_url=api_url.source_api_url, engine=engine, use_plugins=use_plugins)
            keywords = (await llm.ask_async(search_key_word_prompt.format(source=prompt), model=model)).split("\n")
            print("keywords", keywords)
            # Drop empty lines and lines containing a literal "\x", and strip the
            # label the keyword prompt may echo back.
            keywords = [
                item.replace("三行关键词是:", "")
                for item in keywords
                if "\\x" not in item and item != ""
            ]
            # The original query is always searched first; at most three queries in total.
            keywords = ([prompt] + keywords)[:3]
            print("select keywords", keywords)
            async for chunk in function_to_call(keywords):
                # String chunks are progress messages forwarded to the caller;
                # a non-string chunk (presumably the list of result texts) becomes
                # the search output.
                if isinstance(chunk, str):
                    yield chunk
                else:
                    function_response = "\n\n".join(chunk)
            function_call_max_tokens = 32000
            function_response, _ = cut_message(function_response, function_call_max_tokens, engine)
            if function_response:
                # Interpolate everything in a single pass. The previous
                # f-string + .format() combination re-parsed the user's query, so
                # any "{" or "}" in it raised KeyError/IndexError.
                function_response = (
                    f"You need to response the following question: {prompt}. Search results is provided inside <Search_results></Search_results> XML tags. Your task is to think about the question step by step and then answer the above question in {language} based on the Search results provided. Please response in {language} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks. For each sentence quoting search results, a markdown ordered superscript number url link must be used to indicate the source, e.g., [¹](https://www.example.com)"
                    "Here is the Search results, inside <Search_results></Search_results> XML tags:"
                    "<Search_results>"
                    f"{function_response}"
                    "</Search_results>"
                )
            else:
                function_response = "无法找到相关信息,停止使用 tools"
        else:
            # Tools with a parameter (e.g. generate_image, run_python_script) get
            # the model-supplied value of their first parameter ("." if missing);
            # parameterless tools (e.g. get_date_time_weekday) are called bare.
            call_args = ()
            if arg:
                call_args = (safe_get(json.loads(function_full_response), arg, default="."),)
            if inspect.iscoroutinefunction(function_to_call):
                function_response = await function_to_call(*call_args)
            else:
                function_response = function_to_call(*call_args)
            function_response, _ = cut_message(function_response, function_call_max_tokens, engine)

        if function_call_name == "download_read_arxiv_pdf":
            add_message(arxiv_doc_user_prompt, "user", convo_id=convo_id)

    yield f"function_response:{function_response}"
72
+
73
def function_to_json(func) -> dict:
    """
    Describe a Python function as a JSON-serializable tool schema.

    The result has the layout used for function/tool calling::

        {"name": ..., "description": ...,
         "parameters": {"type": "object", "properties": {...}, "required": [...]}}

    Args:
        func: The function to describe.

    Returns:
        A dict with the function's ``__name__``, its raw docstring as the
        description ("" when absent), and its parameters. Annotations
        str/int/float/bool/NoneType map to "string"/"integer"/"number"/
        "boolean"/"null"; unannotated or other annotations fall back to
        "string". Parameters without a default value are listed as required.
        ``*args``/``**kwargs`` catch-alls are omitted: they have no name the
        model could supply a value for.

    Raises:
        ValueError: If the signature of ``func`` cannot be inspected.
    """
    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        type(None): "null",
    }

    try:
        signature = inspect.signature(func)
    except ValueError as e:
        raise ValueError(f"获取函数{func.__name__}签名失败: {str(e)}") from e

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    params = [p for p in signature.parameters.values() if p.kind not in variadic]

    parameters = {}
    for param in params:
        # A missing annotation is inspect.Parameter.empty, which is not in
        # type_map and therefore also falls back to "string".
        try:
            json_type = type_map.get(param.annotation, "string")
        except TypeError:
            # Unhashable annotation objects cannot be looked up.
            json_type = "string"
        parameters[param.name] = {"type": json_type}

    required = [
        param.name
        for param in params
        if param.default is inspect.Parameter.empty
    ]

    return {
        "name": func.__name__,
        "description": func.__doc__ or "",
        "parameters": {
            "type": "object",
            "properties": parameters,
            "required": required,
        },
    }
121
+
122
def gpt2claude_tools_json(json_dict):
    """
    Convert an OpenAI-style function/tool dict into the Claude tool layout.

    Works on a deep copy, so the input is never modified:

    * "parameters" is renamed to "input_schema" (the key moves to the end);
    * a dict without "parameters" but with a "description" gets an empty
      object schema as "input_schema";
    * a dict containing "tools" additionally gets
      ``tool_choice = {"type": "auto"}``.

    Args:
        json_dict: Schema dict, e.g. as produced by ``function_to_json``.

    Returns:
        The converted copy.
    """
    import copy

    converted = copy.deepcopy(json_dict)
    renames = (("parameters", "input_schema"),)
    for source_key, target_key in renames:
        if source_key not in converted:
            # No schema supplied: a described tool still needs one.
            if target_key and "description" in converted:
                converted[target_key] = {"type": "object", "properties": {}}
            continue
        moved = converted.pop(source_key)
        if target_key:
            converted[target_key] = moved
    if "tools" in converted:
        converted["tool_choice"] = {"type": "auto"}
    return converted
145
+
146
+ # print("registry.tools", json.dumps(registry.tools_info.get('get_time', {}), indent=4, ensure_ascii=False))
147
+ # print("registry.tools", json.dumps(registry.tools_info['run_python_script'].to_dict(), indent=4, ensure_ascii=False))
148
+
149
# PLUGINS is derived from the tools currently in the registry.
def get_plugins():
    """
    Map every registered tool name to whether it is enabled.

    A tool is enabled when the environment variable with the same name holds
    anything other than the exact string "False". An unset variable counts as
    "False", so tools are disabled by default. The comparison is
    case-sensitive: for example, "false" enables the tool.
    """
    enabled = {}
    for tool_name in registry.tools:
        enabled[tool_name] = os.environ.get(tool_name, "False") != "False"
    return enabled
155
+
156
# function_call_list is derived from the tools currently in the registry.
def get_function_call_list():
    """Return {tool name: JSON tool schema (function_to_json)} for every registered tool."""
    return {name: function_to_json(func) for name, func in registry.tools.items()}
162
+
163
def get_claude_tools_list():
    """Return {tool name: Claude-format tool schema} for every registered tool."""
    return {
        name: gpt2claude_tools_json(schema)
        for name, schema in get_function_call_list().items()
    }
166
+
167
# Default configuration, computed once at import time from the tools that are
# registered at that moment; update_tools_config() recomputes it.
PLUGINS, function_call_list, claude_tools_list = (
    get_plugins(),
    get_function_call_list(),
    get_claude_tools_list(),
)
171
+
172
# Refresh the tool configuration at runtime.
def update_tools_config():
    """
    Recompute the module-level tool configuration from the current registry.

    Rebinds this module's ``PLUGINS``, ``function_call_list`` and
    ``claude_tools_list`` globals, in that order, and returns them as a tuple.
    Modules that copied these names earlier (e.g. via a star-import) keep
    their old objects.
    """
    global PLUGINS, function_call_list, claude_tools_list
    PLUGINS = get_plugins()
    function_call_list = get_function_call_list()
    claude_tools_list = get_claude_tools_list()
    refreshed = (PLUGINS, function_call_list, claude_tools_list)
    return refreshed
aient/plugins/image.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import requests
3
+ import json
4
+ from ..models.base import BaseLLM
5
+ from .registry import register_tool
6
+
7
# Credentials/endpoint read from the environment at import time (None when unset).
API = os.environ.get('API')
API_URL = os.environ.get('API_URL')
9
+
10
class dalle3(BaseLLM):
    """
    Minimal client for an image-generation endpoint.

    Uses "dall-e-3" as its default model; the model can be overridden per call
    or through the IMAGE_MODEL_NAME environment variable.
    """

    def __init__(
        self,
        api_key: str,
        api_url: str = (os.environ.get("API_URL") or "https://api.openai.com/v1/images/generations"),
        timeout: float = 20,
    ):
        super().__init__(api_key, api_url=api_url, timeout=timeout)
        self.engine: str = "dall-e-3"

    def generate(
        self,
        prompt: str,
        model: str = "",
        **kwargs,
    ):
        """
        Request one 1024x1024 image for *prompt* and yield its URL.

        This is a generator that yields at most one item. Model precedence:
        the IMAGE_MODEL_NAME environment variable, then *model*, then
        ``self.engine``. ``api_key`` and ``timeout`` can be overridden via
        *kwargs*. If sending the request fails (connection error, read
        timeout or any other exception), a message is printed and nothing is
        yielded.

        Raises:
            Exception: If the endpoint answers with a non-200 status.
        """
        # NOTE(review): api_url is annotated as str in __init__ but read here as
        # an object with an `image_url` attribute — presumably BaseLLM wraps it;
        # confirm against BaseLLM.
        request_url = self.api_url.image_url
        headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"}

        json_post = {
            "model": os.environ.get("IMAGE_MODEL_NAME") or model or self.engine,
            "prompt": prompt,
            "n": 1,
            "size": "1024x1024",
        }
        try:
            response = self.session.post(
                request_url,
                headers=headers,
                json=json_post,
                timeout=kwargs.get("timeout", self.timeout),
                stream=True,
            )
        # requests raises its own ConnectionError, which is not a subclass of
        # the builtin one; catch both.
        except (ConnectionError, requests.exceptions.ConnectionError):
            print("连接错误,请检查服务器状态或网络连接。")
            return
        except requests.exceptions.ReadTimeout as e:
            print(f"请求超时,请检查网络连接或增加超时时间。{e}")
            return
        except Exception as e:
            print(f"发生了未预料的错误: {e}")
            return

        if response.status_code != 200:
            raise Exception(f"{response.status_code} {response.reason} {response.text}")
        json_data = json.loads(response.text)
        yield json_data["data"][0]["url"]
58
+
59
@register_tool()
def generate_image(text):
    """
    生成图像

    参数:
        text: 描述图像的文本

    返回:
        图像的URL
    """
    # The docstring above is also the tool description sent to the model
    # (function_to_json reads __doc__), so it is kept verbatim.
    # Returns the first URL yielded by dalle3.generate, or None when it yields
    # nothing (the request failed before producing a URL). When the API
    # environment variable is unset, the f-string passes the text "None" as key.
    image_urls = dalle3(api_key=f"{API}").generate(text)
    return next(image_urls, None)
@@ -0,0 +1,116 @@
1
+ from typing import Callable, Dict, Literal, List, Optional
2
+ from dataclasses import dataclass, asdict
3
+ import inspect
4
+
5
@dataclass
class FunctionInfo:
    """
    Metadata recorded for a function registered in the Registry.

    Fields:
        name: Name the function was registered under.
        func: The callable itself (dropped by to_dict; None after from_dict
            when the dict has no 'func').
        args: Parameter names, in signature order.
        docstring: Cleaned docstring (inspect.getdoc), or None.
        body: Source text after the leading decorator/def lines ("" when the
            source is unavailable).
        return_type: str() of the return annotation, or None if unannotated.
    """
    name: str
    func: Callable
    args: List[str]
    docstring: Optional[str]
    body: str
    return_type: Optional[str]

    def to_dict(self) -> dict:
        """Return the fields as a plain dict, without the non-serializable ``func``."""
        d = asdict(self)
        d.pop('func')
        return d

    @classmethod
    def from_dict(cls, data: dict) -> 'FunctionInfo':
        """
        Build a FunctionInfo from a dict such as one produced by ``to_dict``.

        ``func`` defaults to None when absent. The input dict is not modified.
        """
        return cls(**{'func': None, **data})


class Registry:
    """
    Process-wide singleton holding registered tools and agents.

    Callables are stored per kind ("tools"/"agents") together with a
    FunctionInfo describing each one. The storage dicts are class attributes
    and ``__new__`` always returns the same instance, so all users share them.
    """
    _instance = None
    _registry: Dict[str, Dict[str, Callable]] = {
        "tools": {},
        "agents": {}
    }
    _registry_info: Dict[str, Dict[str, FunctionInfo]] = {
        "tools": {},
        "agents": {}
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register(self,
                 type: Literal["tool", "agent"],
                 name: Optional[str] = None):
        """
        Return a decorator that registers a function as a tool or an agent.

        Args:
            type: "tool" or "agent"; selects the "tools"/"agents" table.
            name: Registration name. Defaults to the decorated function's
                ``__name__``, resolved separately for every function the
                returned decorator is applied to.

        Returns:
            A decorator that records the function and returns it unchanged.
        """
        def decorator(func: Callable):
            # Resolve into a local: rebinding the enclosing `name` (nonlocal)
            # made a reused decorator register every later function under the
            # first function's name.
            reg_name = func.__name__ if name is None else name

            signature = inspect.signature(func)
            args = list(signature.parameters.keys())
            docstring = inspect.getdoc(func)
            body = self._extract_body(func)

            return_type = None
            if signature.return_annotation is not inspect.Signature.empty:
                return_type = str(signature.return_annotation)

            func_info = FunctionInfo(
                name=reg_name,
                func=func,
                args=args,
                docstring=docstring,
                body=body,
                return_type=return_type
            )

            registry_type = f"{type}s"
            self._registry[registry_type][reg_name] = func
            self._registry_info[registry_type][reg_name] = func_info
            return func
        return decorator

    @staticmethod
    def _extract_body(func: Callable) -> str:
        """Return func's source without its leading decorator/def lines, or "" if unavailable."""
        try:
            source_lines = inspect.getsource(func)
        except (OSError, TypeError):
            # No source (functions created via exec, in a REPL, builtins...).
            # The body is informational only, so registration still proceeds.
            return ""
        # The first line is assumed to be the decorator line and skipped; then
        # further decorator lines and the line containing "def " are dropped.
        body_lines = source_lines.split('\n')[1:]
        while body_lines and (body_lines[0].strip().startswith('@') or 'def ' in body_lines[0]):
            body_lines = body_lines[1:]
        return '\n'.join(body_lines)

    @property
    def tools(self) -> Dict[str, Callable]:
        """Registered tools, keyed by registration name."""
        return self._registry["tools"]

    @property
    def agents(self) -> Dict[str, Callable]:
        """Registered agents, keyed by registration name."""
        return self._registry["agents"]

    @property
    def tools_info(self) -> Dict[str, FunctionInfo]:
        """FunctionInfo for each registered tool."""
        return self._registry_info["tools"]

    @property
    def agents_info(self) -> Dict[str, FunctionInfo]:
        """FunctionInfo for each registered agent."""
        return self._registry_info["agents"]
107
+
108
# Module-wide registry instance shared by every plugin module.
registry = Registry()


def register_tool(name: str = None):
    """Return a decorator that registers a function as a tool (optionally under *name*)."""
    return registry.register("tool", name)


def register_agent(name: str = None):
    """Return a decorator that registers a function as an agent (optionally under *name*)."""
    return registry.register("agent", name)
@@ -0,0 +1,156 @@
1
+ import os
2
+ import ast
3
+ import asyncio
4
+ import logging
5
+ import tempfile
6
+ from .registry import register_tool
7
+
8
def get_dangerous_attributes(node):
    """
    Return True if a single AST node refers to a deny-listed name.

    Flags:
    * bare names such as ``os`` or ``eval``;
    * attribute accesses ending in a deny-listed name (``x.exec``);
    * import statements whose top-level module is deny-listed
      (``import os as o``, ``from subprocess import run``). Without this
      check, such imports would reach those modules under another name.

    This is a heuristic screen, not a sandbox. It can still be bypassed
    (e.g. via getattr or string tricks) and must not be relied on as a
    security boundary.
    """
    # 'import' can never occur as a Name (it is a keyword); kept for compatibility.
    dangerous_keywords = {
        'os', 'subprocess', 'sys', 'import', 'eval', 'exec', 'open',
        '__import__', 'importlib',
    }
    if isinstance(node, ast.Name):
        return node.id in dangerous_keywords
    if isinstance(node, ast.Attribute):
        return node.attr in dangerous_keywords
    if isinstance(node, ast.Import):
        return any(alias.name.split('.')[0] in dangerous_keywords for alias in node.names)
    if isinstance(node, ast.ImportFrom):
        # node.module is None for purely relative imports ("from . import x").
        return bool(node.module) and node.module.split('.')[0] in dangerous_keywords
    return False


def check_code_safety(code):
    """
    Heuristically screen Python source before it is executed.

    Returns False when the code does not parse, when any AST node is flagged
    by get_dangerous_attributes, or when it calls a method named ``encode`` or
    ``decode`` (presumably to block encoding-based obfuscation). Otherwise
    returns True.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return False

    for node in ast.walk(tree):
        # ast.walk also visits the callee of every Call, so this one check
        # covers dangerous calls as well as dangerous attribute accesses.
        if get_dangerous_attributes(node):
            return False
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr in ('encode', 'decode')):
            return False
    return True
42
+
43
@register_tool()
async def run_python_script(code):
    """
    执行 Python 代码

    参数:
        code: 要执行的 Python 代码字符串

    返回:
        执行结果字符串
    """
    # NOTE: the docstring above is also the tool description sent to the model
    # (function_to_json reads __doc__), so it is kept verbatim.
    #
    # Runs `code` in a separate Python process (10 s limit) after the heuristic
    # check_code_safety screen, which is a deny-list and not a sandbox. Returns a
    # report with stdout, stderr (if any) and a non-zero return code (if any),
    # or a rejection / timeout / error message.
    import sys

    timeout = 10
    if not check_code_safety(code):
        return "Code contains potentially dangerous operations.\n\n"

    # Runner executed in the child process. Like an interactive session, when
    # the last top-level statement is a bare expression its value is evaluated
    # separately and echoed as "Result: ...". Doing this on the AST (instead of
    # rewriting the last source line) keeps indented blocks and multi-line
    # expressions intact. __name__ is "__main__" so `if __name__ == "__main__":`
    # guards in the user's code run. Raw string: the "\n" escape must reach the
    # child script unexpanded.
    runner = r'''
import ast

_tree = ast.parse(_user_code)
_namespace = {"__name__": "__main__"}
_last_expr = None
if _tree.body and isinstance(_tree.body[-1], ast.Expr):
    _last_expr = ast.Expression(body=_tree.body.pop().value)
exec(compile(_tree, "<user_code>", "exec"), _namespace)
if _last_expr is not None:
    _result = eval(compile(_last_expr, "<user_code>", "eval"), _namespace)
    if _result is not None:
        print("\nResult:", repr(_result))
'''
    # repr() turns the user's code into an exact string literal, so quotes,
    # backslashes or ''' inside it can neither corrupt the runner nor smuggle
    # statements past check_code_safety.
    wrapper_code = f"_user_code = {code!r}\n" + runner

    # UTF-8 explicitly: Python reads source files as UTF-8, while the default
    # text encoding follows the locale.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as temp_file:
        temp_file.write(wrapper_code)
        temp_file_name = temp_file.name

    try:
        # Use the running interpreter: a bare "python" may be missing or differ.
        process = await asyncio.create_subprocess_exec(
            sys.executable or 'python', temp_file_name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env={**os.environ, "PYTHONIOENCODING": "utf-8"},
        )

        try:
            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            # Ask the child to exit (SIGTERM), escalate to SIGKILL if it lingers,
            # and always reap it so no zombie process is left behind.
            try:
                process.terminate()
                try:
                    await asyncio.wait_for(process.wait(), timeout=0.1)
                except asyncio.TimeoutError:
                    process.kill()
                    await process.wait()
            except ProcessLookupError:
                pass  # the child already exited on its own
            return "Process execution timed out."

        stdout = stdout.decode('utf-8', errors='replace')
        stderr = stderr.decode('utf-8', errors='replace')
        return_code = process.returncode

        mess = (
            f"Execution result:\n{stdout}\n",
            f"Stderr:\n{stderr}\n" if stderr else "",
            f"Return Code: {return_code}\n" if return_code else "",
        )
        return "".join(mess)

    except Exception as e:
        logging.error(f"Error executing code: {str(e)}")
        return f"Error: {str(e)}"

    finally:
        try:
            os.unlink(temp_file_name)
        except Exception as e:
            logging.error(f"Error deleting temporary file: {str(e)}")
139
+
140
# Usage example
async def main():
    """Run run_python_script on each sample snippet and print its report."""
    # Previously the first snippet was overwritten before use and never ran.
    examples = [
        """
print("Hello, World!")
""",
        """
def add(a, b):
    return a + b

result = add(5, 3)
print(result)
""",
    ]
    for code in examples:
        result = await run_python_script(code)
        print(result)


if __name__ == "__main__":
    asyncio.run(main())
aient/plugins/today.py ADDED
@@ -0,0 +1,19 @@
1
+ import pytz
2
+ import datetime
3
+
4
+ from .registry import register_tool
5
+
6
# Plugin: current date, time and weekday
@register_tool()
def get_date_time_weekday():
    """
    获取当前日期时间及星期几

    返回:
        包含当前日期时间及星期几的字符串
    """
    # NOTE: the docstring above is also the tool description sent to the model
    # (function_to_json reads __doc__), so it is kept verbatim.
    # Returns the current date, time (HH:MM:SS) and weekday in the
    # Asia/Shanghai (UTC+8) time zone as a Chinese sentence.
    tz = pytz.timezone('Asia/Shanghai')
    now = datetime.datetime.now(tz)
    weekday_str = ['星期一', '星期二', '星期三', '星期四', '星期五', '星期六', '星期日'][now.weekday()]
    # strftime always gives HH:MM:SS. The former str(now.time())[:-7] only
    # worked when microseconds were non-zero; otherwise str() has no
    # ".ffffff" suffix and the slice cut into the time itself.
    time_str = now.strftime('%H:%M:%S')
    return "今天是:" + str(now.date()) + ",现在的时间是:" + time_str + "," + weekday_str