mofox-plugin-dev-toolkit 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mofox_plugin_dev_toolkit-0.3.3.dist-info/METADATA +730 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/RECORD +46 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/WHEEL +5 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/entry_points.txt +2 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/licenses/LICENSE +674 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/top_level.txt +1 -0
- mpdt/__init__.py +15 -0
- mpdt/__main__.py +8 -0
- mpdt/cli.py +316 -0
- mpdt/commands/__init__.py +9 -0
- mpdt/commands/check.py +498 -0
- mpdt/commands/dev.py +318 -0
- mpdt/commands/generate.py +448 -0
- mpdt/commands/init.py +686 -0
- mpdt/dev/bridge_plugin/__init__.py +17 -0
- mpdt/dev/bridge_plugin/cleanup_handler.py +65 -0
- mpdt/dev/bridge_plugin/dev_config.py +24 -0
- mpdt/dev/bridge_plugin/file_watcher.py +169 -0
- mpdt/dev/bridge_plugin/plugin.py +219 -0
- mpdt/templates/__init__.py +165 -0
- mpdt/templates/action_template.py +102 -0
- mpdt/templates/adapter_template.py +129 -0
- mpdt/templates/chatter_template.py +103 -0
- mpdt/templates/event_template.py +116 -0
- mpdt/templates/plus_command_template.py +150 -0
- mpdt/templates/prompt_template.py +92 -0
- mpdt/templates/router_template.py +175 -0
- mpdt/templates/tool_template.py +98 -0
- mpdt/utils/__init__.py +10 -0
- mpdt/utils/code_parser.py +401 -0
- mpdt/utils/color_printer.py +99 -0
- mpdt/utils/config_loader.py +171 -0
- mpdt/utils/config_manager.py +297 -0
- mpdt/utils/file_ops.py +207 -0
- mpdt/utils/license_generator.py +980 -0
- mpdt/utils/plugin_parser.py +195 -0
- mpdt/utils/template_engine.py +112 -0
- mpdt/validators/__init__.py +26 -0
- mpdt/validators/auto_fix_validator.py +990 -0
- mpdt/validators/base.py +129 -0
- mpdt/validators/component_validator.py +842 -0
- mpdt/validators/config_validator.py +119 -0
- mpdt/validators/metadata_validator.py +107 -0
- mpdt/validators/structure_validator.py +72 -0
- mpdt/validators/style_validator.py +117 -0
- mpdt/validators/type_validator.py +206 -0
|
@@ -0,0 +1,990 @@
|
|
|
1
|
+
"""自动修复验证器
|
|
2
|
+
|
|
3
|
+
提供自动修复常见问题的功能,可以接收其他验证器的报错并尝试自动修复
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import ast
|
|
7
|
+
import re
|
|
8
|
+
import subprocess
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import libcst as cst
|
|
12
|
+
|
|
13
|
+
from .base import BaseValidator, ValidationIssue, ValidationLevel, ValidationResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AutoFixValidator(BaseValidator):
    """Validator that automatically repairs problems reported by other validators.

    Pass the results of the other validators to :meth:`fix_issues`. Every
    ERROR/WARNING issue whose message matches a known pattern is repaired in
    place in the plugin's source files. Bookkeeping lives on the instance:

    * ``fixes_applied`` - human-readable descriptions of successful fixes;
    * ``fixes_failed`` - human-readable descriptions of failed fix attempts;
    * ``fixed_issues`` - the original ``ValidationIssue`` objects that were fixed.
    """

    # Component class-attribute message, e.g. "组件 MyAction 缺少必需的类属性: action_name".
    # Every colon class below accepts both the ASCII ":" and the full-width ":".
    _COMPONENT_ATTR_PATTERN = re.compile(r"组件\s+(\w+)\s+缺少必需的类属性[::]\s*(\w+)")
    # Component method message, e.g. "组件 MyAction 缺少必需的方法: execute".
    _COMPONENT_METHOD_PATTERN = re.compile(r"组件\s+(\w+)\s+缺少必需的方法[::]\s*(\w+)")
    # Method-signature messages that name a component class and one of its methods.
    _METHOD_REF_PATTERN = re.compile(r"组件\s+(\w+)\s+的方法\s+(\w+)")
    # Ruff diagnostics start with a rule code such as "F401:", "UP006:" or "SIM102:".
    _RUFF_CODE_PATTERN = re.compile(r"^[A-Z]+\d+:")
    # Body appended to signature-only method templates. Method templates are wrapped
    # as "class Temp:\n    <template>", so body lines need 8 spaces of indentation.
    _TODO_METHOD_BODY = '\n        """TODO: 添加方法说明"""\n        raise NotImplementedError'

    def __init__(self, plugin_path: Path):
        super().__init__(plugin_path)
        self.fixes_applied: list[str] = []
        self.fixes_failed: list[str] = []
        self.fixed_issues: list[ValidationIssue] = []  # original issues that were repaired

    def validate(self) -> ValidationResult:
        """Compatibility entry point required by ``BaseValidator``.

        This method applies no fixes itself. It only reports the fixes already
        recorded in ``self.fixes_applied``. Prefer :meth:`fix_issues`.

        Returns:
            A result carrying an error message if the plugin name cannot be
            determined, otherwise info messages summarising the recorded fixes.
        """
        result = ValidationResult(validator_name="AutoFixValidator", success=True)

        if not self._get_plugin_name():
            result.add_error("无法确定插件名称")
            return result

        # NOTE: import-order fixing was planned at this point but is not implemented.

        if not self.fixes_applied:
            result.add_info("未发现可自动修复的问题")
            return result

        result.add_info(f"应用了 {len(self.fixes_applied)} 个自动修复")
        for fix in self.fixes_applied:
            result.add_info(fix)
        return result

    def fix_issues(self, validation_results: list[ValidationResult]) -> ValidationResult:
        """Try to automatically fix the issues reported by other validators.

        Only ERROR and WARNING issues are considered. Progress is recorded on
        ``self.fixes_applied``, ``self.fixes_failed`` and ``self.fixed_issues``.
        The returned result itself carries no messages.

        Args:
            validation_results: Results returned by the other validators.

        Returns:
            A ``ValidationResult`` named "AutoFixValidator".
        """
        result = ValidationResult(validator_name="AutoFixValidator", success=True)

        if not self._get_plugin_name():
            return result

        fixable_levels = (ValidationLevel.ERROR, ValidationLevel.WARNING)
        all_issues = [
            issue
            for validation_result in validation_results
            for issue in validation_result.issues
            if issue.level in fixable_levels
        ]
        if not all_issues:
            return result

        # Each fixer scans every issue and handles only the message patterns it knows.
        self._fix_missing_plugin_meta(all_issues, result)
        self._fix_missing_metadata_issues(all_issues, result)
        self._fix_missing_component_fields(all_issues, result)
        self._fix_missing_methods(all_issues, result)
        self._fix_method_signatures(all_issues, result)
        self._fix_style_issues(all_issues, result)

        return result

    def _run_fix(self, issue: ValidationIssue, fixer, *args, **kwargs) -> None:
        """Call ``fixer(*args, **kwargs)`` and record ``issue`` if a new fix was applied.

        Fixers report success by appending to ``self.fixes_applied``. Success is
        therefore detected by that list growing.
        """
        before_count = len(self.fixes_applied)
        fixer(*args, **kwargs)
        if len(self.fixes_applied) > before_count:
            self.fixed_issues.append(issue)

    def _fix_missing_plugin_meta(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Add a default ``__plugin_meta__`` to ``__init__.py`` when it is reported missing.

        Args:
            issues: Candidate issues. Only "未找到 __plugin_meta__" messages are handled.
            result: Unused. Kept so all fixers share one signature.
        """
        for issue in issues:
            # This substring also covers the longer "未找到 __plugin_meta__ 变量".
            if "未找到 __plugin_meta__" not in issue.message:
                continue
            try:
                init_file = self.plugin_path / "__init__.py"
                if not init_file.exists():
                    self.fixes_failed.append("未找到 __init__.py 文件")
                    continue
                self._run_fix(issue, self._add_plugin_meta_variable, init_file, issue)
            except Exception as e:
                self.fixes_failed.append(f"修复 __plugin_meta__ 变量失败: {e}")

    def _fix_missing_metadata_issues(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Fix missing ``PluginMetadata`` arguments and missing metadata class attributes.

        Messages in the component format "组件 X 缺少必需的类属性: y" are skipped
        here. They name their class and are handled by
        :meth:`_fix_missing_component_fields`. Handling them here as well would
        patch the first class in the file, which may be the wrong one.

        Args:
            issues: Candidate issues.
            result: Unused. Kept so all fixers share one signature.
        """
        for issue in issues:
            message = issue.message
            if "PluginMetadata 缺少必需字段" in message:
                try:
                    match = re.search(r"缺少必需字段[::]\s*(\w+)", message)
                    init_file = self.plugin_path / "__init__.py"
                    if match and init_file.exists():
                        self._run_fix(issue, self._add_plugin_meta_argument, init_file, match.group(1), issue)
                except Exception as e:
                    self.fixes_failed.append(f"修复 PluginMetadata 参数失败: {message} - {e}")
            elif "缺少必需的类属性" in message or "缺少必需元数据字段" in message:
                if self._COMPONENT_ATTR_PATTERN.search(message):
                    continue
                try:
                    # The field name follows the first colon of the message.
                    match = re.search(r"[::]\s*(\w+)", message)
                    if match:
                        file_path = self._resolve_file_path(issue.file_path)
                        if file_path and file_path.exists():
                            self._run_fix(issue, self._add_class_attribute, file_path, match.group(1), issue)
                except Exception as e:
                    self.fixes_failed.append(f"修复元数据字段失败: {message} - {e}")

    def _fix_missing_component_fields(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Add missing class attributes to the component class named in the message.

        Args:
            issues: Candidate issues, e.g. "组件 MyAction 缺少必需的类属性: action_name".
            result: Unused. Kept so all fixers share one signature.
        """
        for issue in issues:
            if "缺少必需的类属性" not in issue.message:
                continue
            try:
                match = self._COMPONENT_ATTR_PATTERN.search(issue.message)
                if not match:
                    continue
                class_name, field_name = match.groups()
                file_path = self._resolve_file_path(issue.file_path)
                if file_path and file_path.exists():
                    self._run_fix(issue, self._add_class_attribute, file_path, field_name, issue, class_name=class_name)
            except Exception as e:
                self.fixes_failed.append(f"修复组件字段失败: {issue.message} - {e}")

    def _fix_missing_methods(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Add missing methods to the component class named in the message.

        Args:
            issues: Candidate issues, e.g. "组件 MyAction 缺少必需的方法: execute".
            result: Unused. Kept so all fixers share one signature.
        """
        for issue in issues:
            if "缺少必需的方法" not in issue.message:
                continue
            try:
                match = self._COMPONENT_METHOD_PATTERN.search(issue.message)
                if not match:
                    continue
                class_name, method_name = match.groups()
                file_path = self._resolve_file_path(issue.file_path)
                if file_path and file_path.exists():
                    self._run_fix(issue, self._add_method_to_class, file_path, class_name, method_name, issue)
            except Exception as e:
                self.fixes_failed.append(f"修复缺失方法失败: {issue.message} - {e}")

    def _fix_method_signatures(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Fix method async-ness and parameter-list problems.

        Args:
            issues: Candidate issues naming "组件 <Class> 的方法 <method>".
            result: Unused. Kept so all fixers share one signature.
        """
        for issue in issues:
            message = issue.message
            # "应该是异步方法" also matches the negative form "不应该是异步方法".
            if "应该是异步方法" in message:
                try:
                    match = self._METHOD_REF_PATTERN.search(message)
                    if match:
                        class_name, method_name = match.groups()
                        file_path = self._resolve_file_path(issue.file_path)
                        # Test the negative form: the positive phrase is a substring of it.
                        should_be_async = "不应该是异步方法" not in message
                        if file_path and file_path.exists():
                            self._run_fix(
                                issue, self._fix_method_async, file_path, class_name, method_name, should_be_async, issue
                            )
                except Exception as e:
                    self.fixes_failed.append(f"修复方法签名失败: {message} - {e}")
            elif "缺少必需参数" in message or "参数过多" in message:
                try:
                    match = self._METHOD_REF_PATTERN.search(message)
                    if match:
                        class_name, method_name = match.groups()
                        file_path = self._resolve_file_path(issue.file_path)
                        if file_path and file_path.exists() and issue.suggestion:
                            self._run_fix(issue, self._fix_method_parameters, file_path, class_name, method_name, issue)
                except Exception as e:
                    self.fixes_failed.append(f"修复方法参数失败: {message} - {e}")

    def _fix_style_issues(self, issues: list[ValidationIssue], result: ValidationResult) -> None:
        """Run ``ruff check --fix`` and ``ruff format`` when style issues were reported.

        Style issues are recognised by a leading ruff rule code (``F401:``,
        ``UP006:``, ``SIM102:``...). ruff exits with 1 when unfixable violations
        remain, which is normal. Exit codes of 2 or more mean ruff aborted, and
        that is reported as a failure.

        Args:
            issues: Candidate issues.
            result: Unused. Kept so all fixers share one signature.
        """
        if not any(self._RUFF_CODE_PATTERN.match(issue.message) for issue in issues):
            return

        if not self._is_ruff_installed():
            self.fixes_failed.append("未安装 ruff,无法自动修复代码风格问题")
            return

        try:
            run_options = {"capture_output": True, "text": True, "encoding": "utf-8", "errors": "ignore"}
            target = str(self.plugin_path)
            check_proc = subprocess.run(["ruff", "check", "--fix", target], **run_options)
            format_proc = subprocess.run(["ruff", "format", target], **run_options)

            errors = [
                (proc.stderr or "").strip() or f"exit code {proc.returncode}"
                for proc in (check_proc, format_proc)
                if proc.returncode >= 2
            ]
            if errors:
                self.fixes_failed.append(f"运行 ruff 自动修复失败: {'; '.join(errors)}")
            else:
                self.fixes_applied.append("使用 ruff 自动修复了代码风格问题")
        except Exception as e:
            self.fixes_failed.append(f"运行 ruff 自动修复失败: {e}")

    def _is_ruff_installed(self) -> bool:
        """Return True if ``ruff --version`` runs successfully.

        Any ``OSError`` (missing executable, permission denied, ...) counts as
        "not installed".
        """
        try:
            subprocess.run(["ruff", "--version"], capture_output=True, check=True, encoding="utf-8", errors="ignore")
        except (subprocess.CalledProcessError, OSError):
            return False
        return True

    def _add_class_attribute(
        self, file_path: Path, field_name: str, issue: ValidationIssue, class_name: str | None = None
    ) -> None:
        """Add ``field_name = <default>`` to a class in ``file_path``.

        Args:
            file_path: File to modify.
            field_name: Attribute to add. Its value comes from ``_get_default_value_for_field``.
            issue: The issue being fixed (unused here).
            class_name: Target class. If None, the first class found by ``ast.walk`` is used.
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            tree = ast.parse(source)

            target_class = next(
                (
                    node
                    for node in ast.walk(tree)
                    if isinstance(node, ast.ClassDef) and (class_name is None or node.name == class_name)
                ),
                None,
            )
            if target_class is None:
                self.fixes_failed.append(f"未找到类定义: {class_name or '任意类'}")
                return

            module = cst.parse_module(source)
            transformer = AddClassAttributeTransformer(
                target_class.name, field_name, self._get_default_value_for_field(field_name)
            )
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能修改类 {target_class.name}")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            self.fixes_applied.append(f"在 {file_path.name} 的类 {target_class.name} 中添加属性 {field_name}")
        except Exception as e:
            self.fixes_failed.append(f"添加类属性 {field_name} 失败: {e}")

    def _add_method_to_class(self, file_path: Path, class_name: str, method_name: str, issue: ValidationIssue) -> None:
        """Append a stub method to ``class_name`` in ``file_path``.

        Args:
            file_path: File to modify.
            class_name: Class that receives the method.
            method_name: Name of the method to add.
            issue: The issue being fixed. Its ``suggestion`` may supply the signature.
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            module = cst.parse_module(source)

            method_template = self._generate_method_template(method_name, issue.suggestion)
            transformer = AddMethodTransformer(class_name, method_name, method_template)
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能在类 {class_name} 中添加方法 {method_name}")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            self.fixes_applied.append(f"在 {file_path.name} 的类 {class_name} 中添加方法 {method_name}")
        except Exception as e:
            self.fixes_failed.append(f"添加方法 {method_name} 失败: {e}")

    def _fix_method_async(
        self, file_path: Path, class_name: str, method_name: str, should_be_async: bool, issue: ValidationIssue
    ) -> None:
        """Make ``class_name.method_name`` async (or sync) via ``FixMethodAsyncTransformer``.

        Args:
            file_path: File to modify.
            class_name: Class that owns the method.
            method_name: Method to change.
            should_be_async: True to make the method async, False to make it sync.
            issue: The issue being fixed (unused here).
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            module = cst.parse_module(source)

            transformer = FixMethodAsyncTransformer(class_name, method_name, should_be_async)
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能修复方法 {class_name}.{method_name}")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            async_str = "异步" if should_be_async else "同步"
            self.fixes_applied.append(f"修复 {file_path.name} 中 {class_name}.{method_name} 为{async_str}方法")
        except Exception as e:
            self.fixes_failed.append(f"修复方法异步性失败: {e}")

    def _fix_method_parameters(
        self, file_path: Path, class_name: str, method_name: str, issue: ValidationIssue
    ) -> None:
        """Rewrite a method's parameters from the list given in ``issue.suggestion``.

        The parameter text is taken from a ``def name(self, ...)`` signature in
        the suggestion or, failing that, from text after "应包含:".

        Args:
            file_path: File to modify.
            class_name: Class that owns the method.
            method_name: Method to change.
            issue: The issue being fixed. Nothing happens without a suggestion.
        """
        try:
            if not issue.suggestion:
                return

            match = re.search(r"def\s+\w+\(self,\s*([^)]+)\)", issue.suggestion) or re.search(
                r"应包含[::]\s*([^。\n]+)", issue.suggestion
            )
            if not match:
                return
            params_str = match.group(1).strip()

            source = file_path.read_text(encoding="utf-8")
            module = cst.parse_module(source)

            transformer = FixMethodParametersTransformer(class_name, method_name, params_str)
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能修复方法 {class_name}.{method_name} 的参数")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            self.fixes_applied.append(f"修复 {file_path.name} 中 {class_name}.{method_name} 的参数")
        except Exception as e:
            self.fixes_failed.append(f"修复方法参数失败: {e}")

    def _add_plugin_meta_variable(self, file_path: Path, issue: ValidationIssue) -> None:
        """Insert a default ``__plugin_meta__ = PluginMetadata(...)`` into ``file_path``.

        Nothing happens if the text ``__plugin_meta__`` already occurs in the file.
        The ``PluginMetadata`` import is added unless a module-level
        ``from ... import PluginMetadata`` already binds the name. A plain
        ``import src.plugin_system...`` does not bind it. The definition is
        placed after the module docstring, after the leading import block, and
        after any existing ``PluginMetadata`` import, so the name is bound
        before use.

        Args:
            file_path: The plugin's ``__init__.py``.
            issue: The issue being fixed (unused here).
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            if "__plugin_meta__" in source:
                return

            plugin_name = self.plugin_path.name
            insert_line, has_import = self._plugin_meta_insertion_point(ast.parse(source))

            snippet = ""
            if not has_import:
                snippet += "from src.plugin_system.base.plugin_metadata import PluginMetadata\n\n"
            snippet += (
                "__plugin_meta__ = PluginMetadata(\n"
                f'    name="{plugin_name}",\n'
                '    version="0.1.0",\n'
                '    author="",\n'
                '    description="",\n'
                '    usage="unknown",\n'
                ")\n"
            )
            if insert_line > 0:
                snippet = "\n" + snippet  # keep a blank line after the docstring / imports

            lines = source.split("\n")
            lines.insert(insert_line, snippet)
            file_path.write_text("\n".join(lines), encoding="utf-8")
            self.fixes_applied.append(f"在 {file_path.name} 中添加 __plugin_meta__ 变量")
        except Exception as e:
            self.fixes_failed.append(f"添加 __plugin_meta__ 变量失败: {e}")

    @staticmethod
    def _plugin_meta_insertion_point(tree: ast.Module) -> tuple[int, bool]:
        """Locate where ``__plugin_meta__`` should be inserted in a module.

        Args:
            tree: The parsed module.

        Returns:
            ``(line_index, has_import)``. ``line_index`` is the 0-based index in
            ``source.split("\\n")`` to insert before; it equals the 1-based number
            of the last line that must stay above the definition. ``has_import``
            tells whether a module-level ``from ... import PluginMetadata``
            already binds the name.
        """
        body = tree.body
        position = 0
        index = 0

        # Keep the module docstring first so it remains the docstring.
        first = body[0] if body else None
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            position = first.end_lineno or 0
            index = 1

        # Skip the leading import block; this also keeps "from __future__" imports first.
        while index < len(body) and isinstance(body[index], (ast.Import, ast.ImportFrom)):
            position = body[index].end_lineno or position
            index += 1

        binding_import = next(
            (
                stmt
                for stmt in body
                if isinstance(stmt, ast.ImportFrom)
                and any((alias.asname or alias.name) == "PluginMetadata" for alias in stmt.names)
            ),
            None,
        )
        if binding_import is not None:
            # The binding import may come after other code; define the variable after it.
            position = max(position, binding_import.end_lineno or 0)

        return position, binding_import is not None

    def _add_plugin_meta_argument(self, file_path: Path, arg_name: str, issue: ValidationIssue) -> None:
        """Add a missing keyword argument to the ``__plugin_meta__ = PluginMetadata(...)`` call.

        Args:
            file_path: The plugin's ``__init__.py``.
            arg_name: Keyword to add. Its value comes from ``_get_default_value_for_metadata_field``.
            issue: The issue being fixed (unused here).
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            module = cst.parse_module(source)

            arg_value = self._get_default_value_for_metadata_field(arg_name)
            transformer = AddCallArgumentTransformer(
                variable_name="__plugin_meta__", function_name="PluginMetadata", arg_name=arg_name, arg_value=arg_value
            )
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能在 PluginMetadata 中添加参数 {arg_name}")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            self.fixes_applied.append(f"在 PluginMetadata 中添加参数 {arg_name}={arg_value}")
        except Exception as e:
            self.fixes_failed.append(f"添加 PluginMetadata 参数 {arg_name} 失败: {e}")

    def _fix_method_return_type(
        self, file_path: Path, class_name: str, method_name: str, expected_type: str, issue: ValidationIssue
    ) -> None:
        """Set the return annotation of ``class_name.method_name`` to ``expected_type``.

        Args:
            file_path: File to modify.
            class_name: Class that owns the method.
            method_name: Method to change.
            expected_type: Annotation source text, e.g. ``"tuple[bool, str]"``.
            issue: The issue being fixed (unused here).
        """
        try:
            source = file_path.read_text(encoding="utf-8")
            module = cst.parse_module(source)

            transformer = FixReturnTypeTransformer(class_name, method_name, expected_type)
            new_module = module.visit(transformer)

            if not transformer.modified:
                self.fixes_failed.append(f"未能修复方法 {class_name}.{method_name} 的返回类型")
                return

            file_path.write_text(new_module.code, encoding="utf-8")
            self.fixes_applied.append(
                f"修复 {file_path.name} 中 {class_name}.{method_name} 的返回类型注解为 {expected_type}"
            )
        except Exception as e:
            self.fixes_failed.append(f"修复返回类型注解失败: {e}")

    def _resolve_file_path(self, relative_path: str | None) -> Path | None:
        """Turn an issue's file path into a path under ``self.plugin_path``.

        Issue paths carry the plugin directory as their first component. That
        component is dropped whenever there is more than one. Both "/" and "\\"
        separators are accepted, so paths produced on Windows resolve too.

        Args:
            relative_path: Path as reported by a validator.

        Returns:
            The resolved path, or None when ``relative_path`` is empty.
        """
        if not relative_path:
            return None

        parts = re.split(r"[\\/]", relative_path)
        if len(parts) > 1:
            relative_path = "/".join(parts[1:])

        return self.plugin_path / relative_path

    def _get_default_value_for_field(self, field_name: str) -> str:
        """Return Python source text for a placeholder value of a component class attribute.

        Args:
            field_name: Attribute name.

        Returns:
            A quoted string literal. Name-like fields get a title-cased version
            of the field name, description-like fields a placeholder text,
            version-like fields "0.1.0", and everything else "".
        """
        name_fields = (
            "name",
            "action_name",
            "command_name",
            "handler_name",
            "adapter_name",
            "prompt_name",
            "chatter_name",
            "component_name",
        )
        desc_fields = (
            "description",
            "action_description",
            "command_description",
            "handler_description",
            "adapter_description",
            "chatter_description",
            "component_description",
        )

        if field_name in name_fields:
            return f'"{field_name.replace("_", " ").title()}"'
        if field_name in desc_fields:
            return '"待完善的描述"'
        if "version" in field_name.lower():
            return '"0.1.0"'
        # Author-like fields and everything else default to an empty string.
        return '""'

    def _get_default_value_for_metadata_field(self, field_name: str) -> str:
        """Return Python source text for a default ``PluginMetadata`` argument value.

        Args:
            field_name: ``PluginMetadata`` keyword name.

        Returns:
            A quoted string literal. Unknown fields get ``""``.
        """
        plugin_name = self.plugin_path.name
        defaults = {
            "name": f'"{plugin_name}"',
            "description": f'"{plugin_name} 插件"',
            "usage": '"待完善"',
            "version": '"0.1.0"',
            "author": '""',
            "license": '"MIT"',
        }
        return defaults.get(field_name, '""')

    def _generate_method_template(self, method_name: str, suggestion: str | None) -> str:
        """Build the source text of a stub method for ``AddMethodTransformer``.

        If the suggestion contains a ``def`` line, that signature is used. Any
        text before ``def`` / ``async def`` is dropped, and a TODO body is added
        when the signature has none. Otherwise a built-in template for
        well-known method names is used. The first line is unindented and body
        lines carry 8 spaces.

        Args:
            method_name: Name of the method to generate.
            suggestion: Optional suggestion text from the validator.

        Returns:
            The method template.
        """
        if suggestion and "def " in suggestion:
            def_line = next((line for line in suggestion.split("\n") if "def " in line), None)
            if def_line is not None:
                template = self._template_from_signature(def_line)
                if template is not None:
                    return template

        is_async = method_name in ("execute", "go_activate", "from_platform_message")
        async_prefix = "async " if is_async else ""

        if method_name == "execute":
            return f'{async_prefix}def execute(self):\n        """执行方法"""\n        raise NotImplementedError'
        if method_name == "go_activate":
            return f'{async_prefix}def go_activate(self, llm_judge_model=None):\n        """激活判断"""\n        return True'
        if method_name == "from_platform_message":
            return (
                f'{async_prefix}def from_platform_message(self, raw):\n'
                '        """转换平台消息"""\n'
                "        raise NotImplementedError"
            )
        if method_name == "register_endpoints":
            return 'def register_endpoints(self):\n        """注册端点"""\n        pass'
        return f'{async_prefix}def {method_name}(self):\n        """TODO: 添加方法说明"""\n        raise NotImplementedError'

    def _template_from_signature(self, line: str) -> str | None:
        """Turn a suggestion line that contains a ``def`` into a parsable method template.

        Args:
            line: A suggestion line containing ``def``.

        Returns:
            The signature as-is if it already parses as a method. Otherwise the
            signature plus a TODO body. Returns None if neither parses.
        """
        start = re.search(r"\b(?:async\s+)?def\s", line)
        if start is None:
            return None

        signature = line[start.start() :].strip()
        with_colon = signature if signature.endswith(":") else signature + ":"
        for candidate in (signature, with_colon + self._TODO_METHOD_BODY):
            try:
                # Same wrapping that AddMethodTransformer applies before parsing.
                ast.parse(f"class Temp:\n    {candidate}")
            except SyntaxError:
                continue
            return candidate
        return None
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
# ============== libcst Transformers ==============
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
class AddCallArgumentTransformer(cst.CSTTransformer):
    """Add a keyword argument to the call assigned to a variable.

    Targets statements of the form ``<variable_name> = <function_name>(...)``,
    including annotated assignments. The callee may be a bare name or an
    attribute such as ``pkg.function_name``. Calls that already pass the
    keyword are left untouched. ``modified`` becomes True once an argument
    has been added.
    """

    def __init__(self, variable_name: str, function_name: str, arg_name: str, arg_value: str):
        self.variable_name = variable_name
        self.function_name = function_name
        self.arg_name = arg_name
        self.arg_value = arg_value  # Python source text of the value expression
        self.modified = False

    def leave_SimpleStatementLine(  # noqa: N802
        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
    ) -> cst.SimpleStatementLine:
        """Patch matching assignments on a simple statement line."""
        return updated_node.with_changes(body=[self._patch_statement(stmt) for stmt in updated_node.body])

    def _patch_statement(self, statement: cst.BaseSmallStatement) -> cst.BaseSmallStatement:
        """Return ``statement`` with the argument added if it assigns the target variable."""
        if isinstance(statement, cst.Assign):
            assigns_variable = any(
                isinstance(target.target, cst.Name) and target.target.value == self.variable_name
                for target in statement.targets
            )
        elif isinstance(statement, cst.AnnAssign):
            assigns_variable = (
                statement.value is not None
                and isinstance(statement.target, cst.Name)
                and statement.target.value == self.variable_name
            )
        else:
            return statement

        if not assigns_variable:
            return statement

        new_value = self._add_argument_to_call(statement.value)
        if new_value is None:
            return statement

        self.modified = True
        return statement.with_changes(value=new_value)

    def _add_argument_to_call(self, node: cst.BaseExpression) -> cst.BaseExpression | None:
        """Return ``node`` with ``arg_name=arg_value`` appended.

        Returns None when ``node`` is not a call to ``function_name`` or the
        keyword is already present.
        """
        if not isinstance(node, cst.Call):
            return None

        if isinstance(node.func, cst.Name):
            func_name = node.func.value
        elif isinstance(node.func, cst.Attribute):
            func_name = node.func.attr.value
        else:
            func_name = None
        if func_name != self.function_name:
            return None

        if any(arg.keyword is not None and arg.keyword.value == self.arg_name for arg in node.args):
            return None

        new_arg = cst.Arg(
            keyword=cst.Name(self.arg_name),
            value=cst.parse_expression(self.arg_value),
            equal=cst.AssignEqual(whitespace_before=cst.SimpleWhitespace(""), whitespace_after=cst.SimpleWhitespace("")),
        )

        args = list(node.args)
        if args:
            last = args[-1]
            # The new argument becomes the last one, so it takes over whatever
            # closed the call (trailing comma and any newline before ")").
            # Multi-line calls therefore keep their layout.
            new_arg = new_arg.with_changes(comma=last.comma, whitespace_after_arg=last.whitespace_after_arg)
            # Separate the former last argument the way its predecessor was separated.
            if len(args) > 1 and isinstance(args[-2].comma, cst.Comma):
                separator = args[-2].comma
            else:
                separator = cst.Comma(whitespace_after=cst.SimpleWhitespace(" "))
            args[-1] = last.with_changes(comma=separator, whitespace_after_arg=cst.SimpleWhitespace(""))
        args.append(new_arg)

        return node.with_changes(args=args)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
class AddClassAttributeTransformer(cst.CSTTransformer):
    """Insert ``attr_name = attr_value`` into every class named ``class_name``.

    The assignment goes at the top of the class body, right after the class
    docstring if there is one. Classes that already assign the attribute
    (plain, chained or annotated assignment) are left untouched. A one-line
    class body (``class A: pass``) is first expanded into an indented block.
    ``modified`` becomes True once an assignment has been inserted.
    """

    def __init__(self, class_name: str, attr_name: str, attr_value: str):
        self.class_name = class_name
        self.attr_name = attr_name
        self.attr_value = attr_value  # Python source text of the value expression
        self.modified = False

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Add the attribute to the class if it matches and lacks it."""
        if updated_node.name.value != self.class_name:
            return updated_node

        body = updated_node.body
        if not isinstance(body, cst.IndentedBlock):
            # One-line class body: move its statements onto their own line so a
            # statement line can be inserted before them.
            body = cst.IndentedBlock(body=[cst.SimpleStatementLine(body=list(body.body))])

        statements = list(body.body)
        if any(self._assigns_attribute(stmt) for stmt in statements):
            return updated_node

        new_assignment = cst.SimpleStatementLine(
            body=[
                cst.Assign(
                    targets=[cst.AssignTarget(target=cst.Name(self.attr_name))],
                    value=cst.parse_expression(self.attr_value),
                )
            ]
        )

        insert_at = 1 if statements and self._is_docstring(statements[0]) else 0
        statements.insert(insert_at, new_assignment)

        self.modified = True
        return updated_node.with_changes(body=body.with_changes(body=statements))

    def _assigns_attribute(self, stmt: cst.BaseStatement) -> bool:
        """Return True if ``stmt`` assigns ``attr_name`` as a bare name."""
        if not isinstance(stmt, cst.SimpleStatementLine):
            return False
        for small in stmt.body:
            if isinstance(small, cst.Assign):
                targets = [assign_target.target for assign_target in small.targets]
            elif isinstance(small, cst.AnnAssign):
                targets = [small.target]
            else:
                continue
            if any(isinstance(target, cst.Name) and target.value == self.attr_name for target in targets):
                return True
        return False

    @staticmethod
    def _is_docstring(stmt: cst.BaseStatement) -> bool:
        """Return True if ``stmt`` is a bare string expression (a docstring)."""
        return (
            isinstance(stmt, cst.SimpleStatementLine)
            and isinstance(stmt.body[0], cst.Expr)
            and isinstance(stmt.body[0].value, cst.SimpleString | cst.ConcatenatedString)
        )
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
class AddMethodTransformer(cst.CSTTransformer):
    """Transformer that appends a method, given as source text, to a class.

    The method from ``method_template`` is parsed and appended to the end of
    every class named ``class_name`` that does not already contain a
    ``def method_name``.
    """

    def __init__(self, class_name: str, method_name: str, method_template: str):
        """Store the target class, the method name and its source template.

        Args:
            class_name: Name of the class to extend.
            method_name: Name of the method whose absence triggers the insertion.
            method_template: Python source of the method to insert.
        """
        super().__init__()
        self.class_name = class_name
        self.method_name = method_name
        self.method_template = method_template
        # True once a method has actually been inserted.
        self.modified = False

    def _parse_method(self) -> cst.BaseStatement | None:
        """Parse ``method_template`` into a single class-body statement.

        Two layouts are tried in turn, and the first one libcst accepts wins:

        1. The first line is unindented and the following lines are already
           indented for a class body. The template is pasted after a 4-space
           indent.
        2. A plain or uniformly indented function. It is dedented and then
           indented one level as a whole.

        Returns:
            The first statement of the parsed class body, or None if neither
            layout parses.
        """
        import textwrap

        template = self.method_template
        layouts = (
            lambda: f"class Temp:\n    {template}",
            lambda: "class Temp:\n" + textwrap.indent(textwrap.dedent(template), "    "),
        )
        for build_source in layouts:
            try:
                temp_module = cst.parse_module(build_source())
            except Exception:  # best effort: fall through to the next layout
                continue
            temp_class = temp_module.body[0]
            if isinstance(temp_class, cst.ClassDef) and isinstance(
                temp_class.body, cst.IndentedBlock
            ):
                return temp_class.body.body[0]
        return None

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Append the templated method to the target class if it is missing.

        ``updated_node`` is returned unchanged in three cases: the class is
        not the target, it already has a ``def method_name``, or the template
        cannot be parsed.
        """
        if updated_node.name.value != self.class_name:
            return updated_node

        body = updated_node.body
        if isinstance(body, cst.SimpleStatementSuite):
            # ``class A: pass`` keeps its small statements inline. A compound
            # statement such as ``def`` can only live in an indented block.
            body = cst.IndentedBlock(
                body=[
                    cst.SimpleStatementLine(
                        body=body.body, trailing_whitespace=body.trailing_whitespace
                    )
                ]
            )

        if any(
            isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.method_name
            for stmt in body.body
        ):
            return updated_node  # method already exists

        new_method = self._parse_method()
        if new_method is None:
            return updated_node

        body_list = list(body.body)
        if body_list:
            # Separate the new method from the preceding statement by a blank line.
            new_method = new_method.with_changes(
                leading_lines=[cst.EmptyLine(), *new_method.leading_lines]
            )
        body_list.append(new_method)

        self.modified = True
        return updated_node.with_changes(body=body.with_changes(body=body_list))
|
|
830
|
+
|
|
831
|
+
|
|
832
|
+
class FixMethodAsyncTransformer(cst.CSTTransformer):
    """Transformer that makes one method of one class ``async`` or a plain ``def``.

    Only a method defined directly in the body of ``class_name`` is rewritten.
    Same-named functions nested inside methods or inside nested classes are
    left untouched. The node is edited in place with ``with_changes``, so
    comments, blank lines, decorators and whitespace around the method are
    preserved.

    NOTE: removing ``async`` does not touch ``await`` expressions in the body.
    A body that still awaits will not be valid synchronous code.
    """

    def __init__(self, class_name: str, method_name: str, should_be_async: bool):
        """Store the target method and the desired ``async`` state.

        Args:
            class_name: Name of the class containing the method.
            method_name: Name of the method to fix.
            should_be_async: True to add ``async``, False to remove it.
        """
        super().__init__()
        self.class_name = class_name
        self.method_name = method_name
        self.should_be_async = should_be_async
        # True once the method's async-ness has actually been changed.
        self.modified = False
        # True while the traversal is somewhere inside a class named class_name.
        self.in_target_class = False
        # Enclosing scopes, innermost last: the class name for a ClassDef,
        # None for a FunctionDef.
        self._scopes: list[str | None] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: N802
        self._scopes.append(node.name.value)
        self.in_target_class = self.class_name in self._scopes

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        self._scopes.pop()
        self.in_target_class = self.class_name in self._scopes
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: N802
        self._scopes.append(None)

    def leave_FunctionDef(  # noqa: N802
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Add or remove ``async`` on the target method."""
        self._scopes.pop()
        # After popping, the innermost scope is whatever directly encloses this def.
        is_direct_method = bool(self._scopes) and self._scopes[-1] == self.class_name
        if not is_direct_method or updated_node.name.value != self.method_name:
            return updated_node

        is_async = updated_node.asynchronous is not None
        if is_async == bool(self.should_be_async):
            return updated_node  # already in the requested form

        self.modified = True
        if self.should_be_async:
            # with_changes keeps every other field of the node (leading
            # comments, blank lines, decorators, whitespace) intact.
            return updated_node.with_changes(
                asynchronous=cst.Asynchronous(whitespace_after=cst.SimpleWhitespace(" "))
            )
        return updated_node.with_changes(asynchronous=None)
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
class FixReturnTypeTransformer(cst.CSTTransformer):
    """Transformer that sets the return annotation of one method of one class.

    Only a method defined directly in the body of ``class_name`` is rewritten.
    Same-named functions nested inside methods or inside nested classes are
    left untouched.
    """

    def __init__(self, class_name: str, method_name: str, return_type: str):
        """Store the target method and the desired return annotation.

        Args:
            class_name: Name of the class containing the method.
            method_name: Name of the method to fix.
            return_type: Source text of the desired annotation, e.g. ``"bool"``.
        """
        super().__init__()
        self.class_name = class_name
        self.method_name = method_name
        self.return_type = return_type
        # True once the return annotation has actually been changed.
        self.modified = False
        # True while the traversal is somewhere inside a class named class_name.
        self.in_target_class = False
        # Enclosing scopes, innermost last: the class name for a ClassDef,
        # None for a FunctionDef.
        self._scopes: list[str | None] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: N802
        self._scopes.append(node.name.value)
        self.in_target_class = self.class_name in self._scopes

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        self._scopes.pop()
        self.in_target_class = self.class_name in self._scopes
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: N802
        self._scopes.append(None)

    def leave_FunctionDef(  # noqa: N802
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Replace the return annotation of the target method."""
        self._scopes.pop()
        # After popping, the innermost scope is whatever directly encloses this def.
        is_direct_method = bool(self._scopes) and self._scopes[-1] == self.class_name
        if not is_direct_method or updated_node.name.value != self.method_name:
            return updated_node

        try:
            new_type = cst.parse_expression(self.return_type)
        except Exception:  # best effort: an unparsable annotation leaves the method as is
            return updated_node

        current = updated_node.returns
        if current is None:
            new_returns = cst.Annotation(annotation=new_type)
        else:
            render = cst.Module(body=[]).code_for_node
            if render(current.annotation) == render(new_type):
                return updated_node  # already annotated with this type
            # Reuse the existing Annotation so the whitespace around "->" is kept.
            new_returns = current.with_changes(annotation=new_type)

        self.modified = True
        return updated_node.with_changes(returns=new_returns)
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
class FixMethodParametersTransformer(cst.CSTTransformer):
    """Transformer that rewrites the parameter list of one method of one class.

    ``params_str`` is the desired parameter list without the receiver, written
    in ordinary Python syntax, e.g. ``"message: str, *, flag: bool = False"``.
    libcst parses it as a whole (as the header of a throwaway ``def``), which
    supports:

    - annotations and defaults that contain commas, colons or ``=``, such as
      ``dict[str, int]``, ``{"a": 1}`` or ``"a=b"``;
    - ``*args``, ``**kwargs``, a bare ``*`` and ``/``.

    ``self`` is prepended unless the list already starts with ``self`` or
    ``cls``. When ``params_str`` has no ``*``/keyword-only/``**`` section, the
    method keeps its existing one.

    Only a method defined directly in the body of ``class_name`` is rewritten.
    Same-named functions nested inside methods or inside nested classes are
    left untouched.
    """

    def __init__(self, class_name: str, method_name: str, params_str: str):
        """Store the target method and its desired parameter list.

        Args:
            class_name: Name of the class containing the method.
            method_name: Name of the method to fix.
            params_str: Desired parameters (without ``self``) as Python source.
        """
        super().__init__()
        self.class_name = class_name
        self.method_name = method_name
        self.params_str = params_str
        # True once the parameter list has actually been changed.
        self.modified = False
        # True while the traversal is somewhere inside a class named class_name.
        self.in_target_class = False
        # Enclosing scopes, innermost last: the class name for a ClassDef,
        # None for a FunctionDef.
        self._scopes: list[str | None] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: N802
        self._scopes.append(node.name.value)
        self.in_target_class = self.class_name in self._scopes

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        self._scopes.pop()
        self.in_target_class = self.class_name in self._scopes
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: N802
        self._scopes.append(None)

    def _build_parameters(self, current: cst.Parameters) -> cst.Parameters | None:
        """Build the new Parameters node from ``params_str``.

        Args:
            current: The method's existing parameters. Its star section is
                reused when ``params_str`` does not define one.

        Returns:
            The new Parameters node, or None when the parsed text is not a
            function definition. libcst parse or validation errors propagate
            to the caller.
        """
        template = cst.parse_statement(f"def _({self.params_str}): pass\n")
        if not isinstance(template, cst.FunctionDef):
            return None
        parsed = template.params

        posonly = list(parsed.posonly_params)
        regular = list(parsed.params)
        # The receiver must be the very first parameter. When a "/" section
        # exists, that is the first positional-only parameter.
        leading = posonly if posonly else regular
        if not (leading and leading[0].name.value in ("self", "cls")):
            leading.insert(0, cst.Param(name=cst.Name("self")))
        new_params = parsed.with_changes(posonly_params=posonly, params=regular)

        has_star_section = (
            not isinstance(parsed.star_arg, cst.MaybeSentinel)
            or bool(parsed.kwonly_params)
            or parsed.star_kwarg is not None
        )
        if not has_star_section:
            # params_str only lists positional parameters: keep the method's
            # existing *args / keyword-only / **kwargs section.
            new_params = new_params.with_changes(
                star_arg=current.star_arg,
                kwonly_params=current.kwonly_params,
                star_kwarg=current.star_kwarg,
            )
        return new_params

    def leave_FunctionDef(  # noqa: N802
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Replace the parameter list of the target method."""
        self._scopes.pop()
        # After popping, the innermost scope is whatever directly encloses this def.
        is_direct_method = bool(self._scopes) and self._scopes[-1] == self.class_name
        if not is_direct_method or updated_node.name.value != self.method_name:
            return updated_node

        try:
            new_params = self._build_parameters(updated_node.params)
        except Exception:  # best effort: invalid params_str leaves the method unchanged
            return updated_node
        if new_params is None:
            return updated_node

        render = cst.Module(body=[]).code_for_node
        if render(new_params) == render(updated_node.params):
            return updated_node  # signature already matches

        self.modified = True
        return updated_node.with_changes(params=new_params)
|