mofox-plugin-dev-toolkit 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mofox_plugin_dev_toolkit-0.3.3.dist-info/METADATA +730 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/RECORD +46 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/WHEEL +5 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/entry_points.txt +2 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/licenses/LICENSE +674 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/top_level.txt +1 -0
- mpdt/__init__.py +15 -0
- mpdt/__main__.py +8 -0
- mpdt/cli.py +316 -0
- mpdt/commands/__init__.py +9 -0
- mpdt/commands/check.py +498 -0
- mpdt/commands/dev.py +318 -0
- mpdt/commands/generate.py +448 -0
- mpdt/commands/init.py +686 -0
- mpdt/dev/bridge_plugin/__init__.py +17 -0
- mpdt/dev/bridge_plugin/cleanup_handler.py +65 -0
- mpdt/dev/bridge_plugin/dev_config.py +24 -0
- mpdt/dev/bridge_plugin/file_watcher.py +169 -0
- mpdt/dev/bridge_plugin/plugin.py +219 -0
- mpdt/templates/__init__.py +165 -0
- mpdt/templates/action_template.py +102 -0
- mpdt/templates/adapter_template.py +129 -0
- mpdt/templates/chatter_template.py +103 -0
- mpdt/templates/event_template.py +116 -0
- mpdt/templates/plus_command_template.py +150 -0
- mpdt/templates/prompt_template.py +92 -0
- mpdt/templates/router_template.py +175 -0
- mpdt/templates/tool_template.py +98 -0
- mpdt/utils/__init__.py +10 -0
- mpdt/utils/code_parser.py +401 -0
- mpdt/utils/color_printer.py +99 -0
- mpdt/utils/config_loader.py +171 -0
- mpdt/utils/config_manager.py +297 -0
- mpdt/utils/file_ops.py +207 -0
- mpdt/utils/license_generator.py +980 -0
- mpdt/utils/plugin_parser.py +195 -0
- mpdt/utils/template_engine.py +112 -0
- mpdt/validators/__init__.py +26 -0
- mpdt/validators/auto_fix_validator.py +990 -0
- mpdt/validators/base.py +129 -0
- mpdt/validators/component_validator.py +842 -0
- mpdt/validators/config_validator.py +119 -0
- mpdt/validators/metadata_validator.py +107 -0
- mpdt/validators/structure_validator.py +72 -0
- mpdt/validators/style_validator.py +117 -0
- mpdt/validators/type_validator.py +206 -0
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Router 组件模板
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Source template for a Router component; rendered with str.format, so literal
# braces in the generated code are written doubled ({{ }}).
# Fixes vs. previous version: the /hello message is now an f-string (it used to
# emit the literal text "{self.component_name}"), the unused APIRouter import is
# gone, and HTTPException re-raises chain the original error with `from e`.
ROUTER_TEMPLATE = '''"""
{description}

Created by: {author}
Created at: {date}
"""

from fastapi import HTTPException
from src.common.logger import get_logger
from src.plugin_system import BaseRouterComponent

logger = get_logger(__name__)


class {class_name}(BaseRouterComponent):
    """
    {description}

    Router 组件用于对外暴露 HTTP 接口。

    使用场景:
    - 提供 RESTful API
    - Webhook 接收端点
    - 自定义 HTTP 服务
    - 与外部系统集成
    """

    # Router 元数据
    component_name: str = "{router_name}"
    component_description: str = "{description}"
    component_version: str = "1.0.0"

    def register_endpoints(self) -> None:
        """
        注册 HTTP 端点

        使用 self.router 来添加路由:
        - @self.router.get("/path")
        - @self.router.post("/path")
        - @self.router.put("/path")
        - @self.router.delete("/path")
        """

        @self.router.get("/hello")
        async def hello():
            """
            示例 GET 端点
            """
            return {{"message": f"Hello from {{self.component_name}}"}}

        @self.router.get("/status")
        async def get_status():
            """
            获取状态
            """
            try:
                # TODO: 实现状态检查逻辑
                return {{
                    "status": "ok",
                    "component": self.component_name,
                    "version": self.component_version
                }}
            except Exception as e:
                logger.error(f"获取状态失败: {{e}}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.router.post("/webhook")
        async def webhook(data: dict):
            """
            Webhook 接收端点

            Args:
                data: 接收的数据
            """
            try:
                logger.info(f"收到 webhook 数据: {{data}}")

                # TODO: 处理 webhook 数据
                result = await self._process_webhook(data)

                return {{
                    "success": True,
                    "result": result
                }}
            except Exception as e:
                logger.error(f"处理 webhook 失败: {{e}}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.router.get("/data/{{item_id}}")
        async def get_item(item_id: str):
            """
            获取指定项目

            Args:
                item_id: 项目ID
            """
            try:
                # TODO: 实现获取逻辑
                item = await self._get_item(item_id)
                if not item:
                    raise HTTPException(status_code=404, detail="Item not found")
                return item
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"获取项目失败: {{e}}")
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.router.post("/data")
        async def create_item(data: dict):
            """
            创建新项目

            Args:
                data: 项目数据
            """
            try:
                # TODO: 实现创建逻辑
                result = await self._create_item(data)
                return {{
                    "success": True,
                    "item_id": result
                }}
            except Exception as e:
                logger.error(f"创建项目失败: {{e}}")
                raise HTTPException(status_code=500, detail=str(e)) from e

    async def _process_webhook(self, data: dict) -> dict:
        """
        处理 webhook 数据

        Args:
            data: webhook 数据

        Returns:
            处理结果
        """
        # TODO: 实现 webhook 处理逻辑
        return {{"processed": True}}

    async def _get_item(self, item_id: str) -> dict | None:
        """
        获取项目

        Args:
            item_id: 项目ID

        Returns:
            项目数据或 None
        """
        # TODO: 实现获取逻辑
        return {{"id": item_id, "name": "示例项目"}}

    async def _create_item(self, data: dict) -> str:
        """
        创建项目

        Args:
            data: 项目数据

        Returns:
            新项目ID
        """
        # TODO: 实现创建逻辑
        return "new_item_id"
'''
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def get_router_template() -> str:
    """Return the raw (unformatted) source template for a Router component.

    Callers fill its ``{...}`` placeholders with ``str.format``; doubled
    braces inside the template are the escaped literal braces.
    """
    template: str = ROUTER_TEMPLATE
    return template
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tool 组件模板
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Source template for a Tool component; rendered with str.format, so literal
# braces in the generated code are written doubled ({{ }}).
# Fix vs. previous version: the generated execute() read the "format" argument
# into `output_format` and then never used it (ruff F841); it is now passed to
# _process_query so the scaffold is lint-clean and shows where to use it.
TOOL_TEMPLATE = '''"""
{description}

Created by: {author}
Created at: {date}
"""

from typing import Any

from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType

logger = get_logger(__name__)


class {class_name}(BaseTool):
    """
    {description}

    Tool 组件可以被 LLM 调用来执行特定功能。
    """

    # Tool 元数据
    name: str = "{tool_name}"
    description: str = "{description}"
    available_for_llm: bool = True  # 是否可供 LLM 使用

    # 定义工具参数
    # 格式: [("参数名", 参数类型, "参数描述", 是否必填, 枚举值列表)]
    parameters = [
        ("query", ToolParamType.STRING, "查询内容", True, None),
        ("limit", ToolParamType.INTEGER, "返回结果数量限制", False, None),
        ("format", ToolParamType.STRING, "输出格式", False, ["json", "text", "markdown"]),
    ]

    # 缓存配置(可选)
    enable_cache: bool = False  # 是否启用缓存
    cache_ttl: int = 3600  # 缓存过期时间(秒)

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        """
        执行工具功能(供 LLM 调用)

        Args:
            function_args: LLM 传入的参数,格式符合 parameters 定义

        Returns:
            执行结果字典
        """
        try:
            logger.info(f"执行 Tool: {{self.name}}")
            logger.debug(f"参数: {{function_args}}")

            # 获取参数
            query = function_args.get("query")
            limit = function_args.get("limit", 10)
            output_format = function_args.get("format", "text")

            # TODO: 实现工具的核心逻辑
            result_data = self._process_query(query, limit, output_format)

            # 格式化返回结果
            return {{
                "status": "success",
                "data": result_data,
                "message": "执行成功"
            }}

        except Exception as e:
            logger.error(f"Tool 执行失败: {{e}}")
            return {{
                "status": "error",
                "message": str(e)
            }}

    def _process_query(self, query: str, limit: int, output_format: str = "text") -> Any:
        """
        处理查询的核心逻辑

        Args:
            query: 查询内容
            limit: 结果数量限制
            output_format: 输出格式(json / text / markdown)

        Returns:
            处理结果
        """
        # TODO: 实现具体的处理逻辑
        return {{"query": query, "count": limit, "format": output_format}}
'''
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_tool_template() -> str:
    """Return the raw (unformatted) source template for a Tool component.

    Callers fill its ``{...}`` placeholders with ``str.format``; doubled
    braces inside the template are the escaped literal braces.
    """
    template: str = TOOL_TEMPLATE
    return template
|
mpdt/utils/__init__.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
"""
|
|
2
|
+
代码解析器 - 使用 libcst 保留注释和格式
|
|
3
|
+
|
|
4
|
+
这是一个统一的代码解析工具,用于替代直接使用 ast 模块。
|
|
5
|
+
与 ast 不同,libcst(Concrete Syntax Tree)会保留所有的注释、空白和格式。
|
|
6
|
+
|
|
7
|
+
主要功能:
|
|
8
|
+
- 解析 Python 代码而不丢失注释
|
|
9
|
+
- 提取类定义、函数定义、赋值语句等
|
|
10
|
+
- 查找特定的类属性或方法
|
|
11
|
+
- 支持代码修改并保留原有格式
|
|
12
|
+
|
|
13
|
+
使用示例:
|
|
14
|
+
>>> from mpdt.utils.code_parser import CodeParser
|
|
15
|
+
>>> parser = CodeParser.from_file("plugin.py")
|
|
16
|
+
>>> plugin_name = parser.find_class_attribute("BasePlugin", "plugin_name")
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import libcst as cst
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CodeParser:
    """Comment-preserving Python source parser built on libcst.

    Unlike ``ast``, libcst keeps comments, whitespace and formatting in its
    tree, so ``get_code()`` reproduces the source text exactly.

    Attributes:
        module: The libcst module tree parsed from ``source``.
        source: The original source code string.
    """

    def __init__(self, source: str):
        """Parse ``source`` into a libcst module.

        Args:
            source: Python source code.
        """
        self.source = source
        self.module = cst.parse_module(source)

    @classmethod
    def from_file(cls, file_path: Path | str, encoding: str = "utf-8") -> "CodeParser":
        """Read a file and build a parser from its contents.

        Args:
            file_path: Path of the file to read.
            encoding: Text encoding of the file, ``utf-8`` by default.

        Returns:
            A ``CodeParser`` for the file's source.
        """
        if isinstance(file_path, str):
            file_path = Path(file_path)

        with open(file_path, encoding=encoding) as f:
            source = f.read()

        return cls(source)

    def find_class(self, class_name: str | None = None, base_class: str | None = None) -> list[cst.ClassDef]:
        """Find class definitions anywhere in the module (nested ones included).

        Args:
            class_name: Exact class name to match; ``None`` matches any name.
            base_class: Base class name the class must list; ``None`` skips
                the base-class check.

        Returns:
            Matching ``ClassDef`` nodes in visit order.
        """
        visitor = ClassFinderVisitor(class_name=class_name, base_class=base_class)
        self.module.visit(visitor)
        return visitor.found_classes

    @staticmethod
    def _collect_class_assignments(classes: list[cst.ClassDef]) -> list[tuple[str, cst.BaseExpression]]:
        """Return ``(name, value_node)`` for each class-level assignment, in source order.

        Covers ``name = value`` (every simple-name target of a chained
        assignment) and ``name: T = value`` directly in each class body,
        including one-line bodies such as ``class A(B): x = 1``.
        Annotation-only declarations (``name: T``) carry no value and are skipped.
        """
        found: list[tuple[str, cst.BaseExpression]] = []
        for class_def in classes:
            body = class_def.body
            if isinstance(body, cst.SimpleStatementSuite):
                # `class A(B): x = 1` -- the small statements sit directly on the suite
                small_statements = list(body.body)
            else:
                small_statements = [
                    node
                    for statement in body.body
                    if isinstance(statement, cst.SimpleStatementLine)
                    for node in statement.body
                ]

            for node in small_statements:
                if isinstance(node, cst.Assign):
                    for target in node.targets:
                        if isinstance(target.target, cst.Name):
                            found.append((target.target.value, node.value))
                elif isinstance(node, cst.AnnAssign):
                    if isinstance(node.target, cst.Name) and node.value is not None:
                        found.append((node.target.value, node.value))
        return found

    def find_class_attribute(
        self,
        base_class: str | None = None,
        attribute_name: str | None = None,
        class_name: str | None = None,
    ) -> Any:
        """Return the literal value of a class attribute.

        Args:
            base_class: Base class name used to filter classes.
            attribute_name: Attribute to look up; ``None`` returns the first
                assigned attribute of the first matching class.
            class_name: Exact class name used to filter classes.

        Returns:
            The value of the first matching assignment, converted by
            ``_extract_value`` (so non-literal values come back as ``None``),
            or ``None`` when no assignment matches.
        """
        classes = self.find_class(class_name=class_name, base_class=base_class)
        for name, value in self._collect_class_assignments(classes):
            if attribute_name is None or name == attribute_name:
                return self._extract_value(value)
        return None

    def find_all_class_attributes(
        self,
        base_class: str | None = None,
        class_name: str | None = None,
    ) -> dict[str, Any]:
        """Return every assigned class attribute of the matching classes.

        Args:
            base_class: Base class name used to filter classes.
            class_name: Exact class name used to filter classes.

        Returns:
            Mapping of attribute name to literal value (see ``_extract_value``).
            When a name is assigned more than once, the last assignment wins.
        """
        classes = self.find_class(class_name=class_name, base_class=base_class)
        return {name: self._extract_value(value) for name, value in self._collect_class_assignments(classes)}

    def has_class_attribute(
        self,
        attribute_name: str,
        base_class: str | None = None,
        class_name: str | None = None,
    ) -> bool:
        """Check whether a matching class assigns ``attribute_name``.

        An attribute counts as present whenever it is assigned, even if its
        value is ``None`` or an expression ``_extract_value`` cannot evaluate
        (e.g. ``x = build()``). Annotation-only declarations do not count.

        Args:
            attribute_name: Attribute to look for.
            base_class: Base class name used to filter classes.
            class_name: Exact class name used to filter classes.

        Returns:
            ``True`` if the attribute is assigned in any matching class.
        """
        classes = self.find_class(class_name=class_name, base_class=base_class)
        return any(name == attribute_name for name, _ in self._collect_class_assignments(classes))

    def find_assignments(self, variable_name: str) -> list[Any]:
        """Find assignments to ``variable_name``.

        Note: the underlying visitor inspects every simple statement it walks,
        so assignments inside class and function bodies are returned too, not
        only module-level ones.

        Args:
            variable_name: Variable name to match.

        Returns:
            The raw libcst expression nodes on the right-hand side of each
            match, in visit order. They are not converted to Python values;
            pass them to ``_extract_value`` if literal values are needed.
        """
        visitor = AssignmentFinderVisitor(variable_name)
        self.module.visit(visitor)
        return visitor.found_values

    def find_call_arguments(self, variable_name: str, function_name: str | None = None) -> dict[str, Any] | None:
        """Return the keyword arguments of a call assigned to a variable.

        Args:
            variable_name: Variable name, e.g. ``__plugin_meta__``.
            function_name: Callee name, e.g. ``PluginMetadata``; ``None`` accepts any call.

        Returns:
            ``{keyword: literal value}`` for the keyword arguments found, or
            ``None`` when no keyword arguments were found.
        """
        visitor = CallArgumentsFinderVisitor(variable_name, function_name)
        self.module.visit(visitor)
        if not visitor.found_arguments:
            return None
        return {name: self._extract_value(value) for name, value in visitor.found_arguments.items()}

    def get_missing_call_arguments(
        self,
        variable_name: str,
        required_args: list[str],
        function_name: str | None = None,
    ) -> list[str]:
        """List required keyword arguments that are absent, ``None`` or ``""``.

        Args:
            variable_name: Variable the call is assigned to.
            required_args: Keyword names that must be present.
            function_name: Callee name filter (see ``find_call_arguments``).

        Returns:
            A new list of missing argument names, in ``required_args`` order.
        """
        current_args = self.find_call_arguments(variable_name, function_name)
        if current_args is None:
            # copy so callers never receive (and later mutate) their own list
            return list(required_args)

        missing = []
        for arg in required_args:
            value = current_args.get(arg)
            if value is None or value == "":
                missing.append(arg)
        return missing

    def _extract_value(self, node: cst.BaseExpression) -> Any:
        """Convert a literal CST expression into the equivalent Python value.

        Supported: strings, ints in any base, floats, ``True``/``False``/``None``,
        unary ``+``/``-`` applied to numbers, and dict/list/tuple/set displays
        built from those. Anything else (calls, other names, f-strings, sets
        with unhashable members) yields ``None``; dict entries whose key is
        unsupported or unhashable are skipped.

        Args:
            node: CST expression node.

        Returns:
            The extracted Python value, or ``None``.
        """
        if isinstance(node, (cst.SimpleString, cst.ConcatenatedString)):
            try:
                return node.evaluated_value
            except Exception:
                # best effort: strings libcst cannot evaluate are treated as unknown
                return None

        if isinstance(node, cst.Integer):
            try:
                # base 0 honours 0x/0o/0b prefixes and digit underscores, like Python source
                return int(node.value, 0)
            except ValueError:
                return None

        if isinstance(node, cst.Float):
            return float(node.value)

        # Signed literals such as -1 or +2.5 parse as a unary operation
        if isinstance(node, cst.UnaryOperation) and isinstance(node.operator, (cst.Minus, cst.Plus)):
            operand = self._extract_value(node.expression)
            if isinstance(operand, (int, float)) and not isinstance(operand, bool):
                return -operand if isinstance(node.operator, cst.Minus) else operand
            return None

        if isinstance(node, cst.Name):
            if node.value == "True":
                return True
            if node.value == "False":
                return False
            # "None" and any other bare name
            return None

        if isinstance(node, cst.Dict):
            result = {}
            for element in node.elements:
                # `**spread` entries cannot be evaluated statically
                if not isinstance(element, cst.DictElement):
                    continue
                key = self._extract_value(element.key)
                if key is None:
                    continue
                try:
                    result[key] = self._extract_value(element.value)
                except TypeError:
                    # unhashable key, e.g. a list literal
                    continue
            return result

        if isinstance(node, cst.List):
            return [self._extract_value(el.value) for el in node.elements if isinstance(el, cst.Element)]

        if isinstance(node, cst.Tuple):
            return tuple(self._extract_value(el.value) for el in node.elements if isinstance(el, cst.Element))

        if isinstance(node, cst.Set):
            try:
                return {self._extract_value(el.value) for el in node.elements if isinstance(el, cst.Element)}
            except TypeError:
                # a member evaluated to an unhashable value
                return None

        return None

    def get_code(self) -> str:
        """Return the current source text of the module, including any modifications.

        Returns:
            Source code string.
        """
        return self.module.code

    def save_to_file(self, file_path: Path | str, encoding: str = "utf-8") -> None:
        """Write the current source text to a file, overwriting it.

        Args:
            file_path: Destination path.
            encoding: Text encoding to write with.
        """
        if isinstance(file_path, str):
            file_path = Path(file_path)

        with open(file_path, "w", encoding=encoding) as f:
            f.write(self.get_code())
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class ClassFinderVisitor(cst.CSTVisitor):
    """libcst visitor that collects class definitions by name and/or base class.

    ``visit_ClassDef`` returns ``None``, so libcst keeps descending and nested
    classes are inspected as well. Matches accumulate in ``found_classes``.
    """

    def __init__(self, class_name: str | None = None, base_class: str | None = None):
        """
        Args:
            class_name: Exact class name to match; ``None`` accepts any name.
            base_class: Name a class must list among its bases, either bare
                (``Base``) or as the final attribute of a dotted path
                (``module.Base``); ``None`` disables the check.
        """
        # Let libcst's base classes initialise their own state (e.g. `metadata`).
        super().__init__()
        self.class_name = class_name
        self.base_class = base_class
        self.found_classes: list[cst.ClassDef] = []

    def _has_base(self, node: cst.ClassDef) -> bool:
        """Return True if ``node`` lists ``self.base_class`` among its bases."""
        for arg in node.bases:
            base = arg.value
            if isinstance(base, cst.Name) and base.value == self.base_class:
                return True
            # dotted base such as module.BaseClass
            if isinstance(base, cst.Attribute) and base.attr.value == self.base_class:
                return True
        return False

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Record ``node`` when it passes both the name and base-class filters."""
        if self.class_name is not None and node.name.value != self.class_name:
            return
        if self.base_class is not None and not self._has_base(node):
            return
        self.found_classes.append(node)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class AssignmentFinderVisitor(cst.CSTVisitor):
    """libcst visitor that collects right-hand sides of assignments to one name.

    Matches ``name = value`` (each matching target of a chained assignment
    adds one entry) and ``name: T = value`` on every ``SimpleStatementLine``
    the visitor reaches. That includes class and function bodies, not only
    module level. ``found_values`` holds raw CST expression nodes; turning
    them into Python values is left to the caller.
    """

    def __init__(self, variable_name: str):
        """
        Args:
            variable_name: Assignment target name to match.
        """
        # Let libcst's base classes initialise their own state (e.g. `metadata`).
        super().__init__()
        self.variable_name = variable_name
        self.found_values: list[Any] = []
        # Not used by this visitor; kept so existing code touching it still works.
        self.parser = None

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        """Append the value node of every matching assignment on this line."""
        for statement in node.body:
            if isinstance(statement, cst.Assign):
                for target in statement.targets:
                    if isinstance(target.target, cst.Name) and target.target.value == self.variable_name:
                        self.found_values.append(statement.value)

            elif isinstance(statement, cst.AnnAssign):
                # annotation-only declarations (`name: T`) have no value to record
                if (
                    isinstance(statement.target, cst.Name)
                    and statement.target.value == self.variable_name
                    and statement.value is not None
                ):
                    self.found_values.append(statement.value)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class CallArgumentsFinderVisitor(cst.CSTVisitor):
    """libcst visitor that records keyword arguments of ``name = func(...)``.

    For each assignment to ``variable_name`` (plain or annotated) whose value
    is a call, optionally restricted to callee ``function_name``, every
    ``keyword=value`` argument is stored in ``found_arguments`` as a raw CST
    expression. Positional and ``**kwargs`` arguments have no keyword and are
    ignored. If several assignments match, their keywords are merged and
    later ones overwrite earlier ones.
    """

    def __init__(self, variable_name: str, function_name: str | None = None):
        """
        Args:
            variable_name: Assignment target name to match, e.g. ``__plugin_meta__``.
            function_name: Callee name to require (bare ``Func`` or the final
                attribute of ``module.Func``); ``None`` accepts any call.
        """
        # Let libcst's base classes initialise their own state (e.g. `metadata`).
        super().__init__()
        self.variable_name = variable_name
        self.function_name = function_name
        self.found_arguments: dict[str, cst.BaseExpression] = {}

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        """Inspect assignments to ``variable_name`` on this line."""
        for statement in node.body:
            # var = FunctionCall(...)
            if isinstance(statement, cst.Assign):
                for target in statement.targets:
                    if isinstance(target.target, cst.Name) and target.target.value == self.variable_name:
                        self._extract_call_arguments(statement.value)

            # var: Type = FunctionCall(...)
            elif isinstance(statement, cst.AnnAssign):
                if (
                    isinstance(statement.target, cst.Name)
                    and statement.target.value == self.variable_name
                    and statement.value is not None
                ):
                    self._extract_call_arguments(statement.value)

    def _extract_call_arguments(self, node: cst.BaseExpression) -> None:
        """Store the keyword arguments of ``node`` if it is a matching call."""
        if not isinstance(node, cst.Call):
            return

        if self.function_name is not None:
            callee = None
            if isinstance(node.func, cst.Name):
                callee = node.func.value
            elif isinstance(node.func, cst.Attribute):
                callee = node.func.attr.value
            if callee != self.function_name:
                return

        for arg in node.args:
            # only `name=value` arguments carry a keyword
            if arg.keyword:
                self.found_arguments[arg.keyword.value] = arg.value
|