mofox-plugin-dev-toolkit 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mofox_plugin_dev_toolkit-0.3.3.dist-info/METADATA +730 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/RECORD +46 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/WHEEL +5 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/entry_points.txt +2 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/licenses/LICENSE +674 -0
- mofox_plugin_dev_toolkit-0.3.3.dist-info/top_level.txt +1 -0
- mpdt/__init__.py +15 -0
- mpdt/__main__.py +8 -0
- mpdt/cli.py +316 -0
- mpdt/commands/__init__.py +9 -0
- mpdt/commands/check.py +498 -0
- mpdt/commands/dev.py +318 -0
- mpdt/commands/generate.py +448 -0
- mpdt/commands/init.py +686 -0
- mpdt/dev/bridge_plugin/__init__.py +17 -0
- mpdt/dev/bridge_plugin/cleanup_handler.py +65 -0
- mpdt/dev/bridge_plugin/dev_config.py +24 -0
- mpdt/dev/bridge_plugin/file_watcher.py +169 -0
- mpdt/dev/bridge_plugin/plugin.py +219 -0
- mpdt/templates/__init__.py +165 -0
- mpdt/templates/action_template.py +102 -0
- mpdt/templates/adapter_template.py +129 -0
- mpdt/templates/chatter_template.py +103 -0
- mpdt/templates/event_template.py +116 -0
- mpdt/templates/plus_command_template.py +150 -0
- mpdt/templates/prompt_template.py +92 -0
- mpdt/templates/router_template.py +175 -0
- mpdt/templates/tool_template.py +98 -0
- mpdt/utils/__init__.py +10 -0
- mpdt/utils/code_parser.py +401 -0
- mpdt/utils/color_printer.py +99 -0
- mpdt/utils/config_loader.py +171 -0
- mpdt/utils/config_manager.py +297 -0
- mpdt/utils/file_ops.py +207 -0
- mpdt/utils/license_generator.py +980 -0
- mpdt/utils/plugin_parser.py +195 -0
- mpdt/utils/template_engine.py +112 -0
- mpdt/validators/__init__.py +26 -0
- mpdt/validators/auto_fix_validator.py +990 -0
- mpdt/validators/base.py +129 -0
- mpdt/validators/component_validator.py +842 -0
- mpdt/validators/config_validator.py +119 -0
- mpdt/validators/metadata_validator.py +107 -0
- mpdt/validators/structure_validator.py +72 -0
- mpdt/validators/style_validator.py +117 -0
- mpdt/validators/type_validator.py +206 -0
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
"""
|
|
2
|
+
代码生成命令实现
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import libcst as cst
|
|
9
|
+
import questionary
|
|
10
|
+
|
|
11
|
+
from mpdt.templates import prepare_component_context
|
|
12
|
+
from mpdt.utils.code_parser import CodeParser
|
|
13
|
+
from mpdt.utils.color_printer import (
|
|
14
|
+
console,
|
|
15
|
+
print_error,
|
|
16
|
+
print_step,
|
|
17
|
+
print_success,
|
|
18
|
+
print_warning,
|
|
19
|
+
)
|
|
20
|
+
from mpdt.utils.file_ops import (
|
|
21
|
+
ensure_dir,
|
|
22
|
+
get_git_user_info,
|
|
23
|
+
safe_write_file,
|
|
24
|
+
to_snake_case,
|
|
25
|
+
validate_component_name,
|
|
26
|
+
)
|
|
27
|
+
from mpdt.utils.plugin_parser import extract_plugin_name
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def generate_component(
    component_type: str | None = None,
    component_name: str | None = None,
    description: str | None = None,
    output_dir: str | None = None,
    force: bool = False,
    verbose: bool = False,
) -> None:
    """
    Generate a plugin component (generated methods are always async).

    If the component type or name is missing, the interactive wizard is run.
    The component file is written into the plugin and ``plugin.py`` is then
    patched to import and register the new component.

    Args:
        component_type: Component type (None means ask interactively).
        component_name: Component name (None means ask interactively).
        description: Component description.
        output_dir: Output directory; defaults to the current working directory.
        force: Whether to overwrite an existing component file.
        verbose: Print verbose output.
    """
    # Work with an absolute path: a relative path such as "." has an empty
    # ``.name``, which would otherwise produce an empty plugin/import name.
    work_dir = Path(output_dir).resolve() if output_dir else Path.cwd()

    # Check for a plugin directory first, so the user is not asked for details
    # only to hit an error afterwards.
    plugin_name = _detect_plugin_name(work_dir)
    if not plugin_name:
        print_error("未检测到插件目录!请在插件根目录下运行此命令")
        print_warning("提示: 插件目录应包含 plugin.py 文件")
        return

    # _detect_plugin_name also accepts the parent directory as the plugin root.
    # Switch to that root so the component is created inside the plugin and
    # plugin.py can be found when updating the registration.
    if not (work_dir / "plugin.py").exists():
        work_dir = work_dir.parent

    if verbose:
        console.print(f"[dim]检测到插件: {plugin_name}[/dim]")

    # Gather missing component information interactively.
    use_components_folder = True  # the components/ folder is the default location
    if not component_type or not component_name:
        component_info = _interactive_generate()
        if not component_info:
            # The wizard was aborted (e.g. Ctrl-C) and returned no answers.
            print_warning("已取消组件生成")
            return
        component_type = component_info.get("component_type")
        component_name = component_info.get("component_name")
        description = component_info.get("description") or description
        use_components_folder = component_info.get("use_components_folder", True)
        force = component_info.get("force", force)

    # Explicit check instead of ``assert`` (asserts are stripped under -O).
    if not component_type or not component_name:
        print_error("组件类型和组件名称不能为空")
        return

    print_step(f"生成 {component_type.upper()} 组件: {component_name}")

    # Validate the component name.
    if not validate_component_name(component_name):
        print_error("组件名称无效!必须使用小写字母、数字和下划线,以字母开头")
        return

    # Make sure the component name is snake_case.
    component_name = to_snake_case(component_name)

    # Normalize the component type (CLI argument plus-command -> plus_command).
    normalized_type = component_type.replace("-", "_")

    # Build the template context.
    git_info = get_git_user_info()
    context = prepare_component_context(
        component_type=normalized_type,
        component_name=component_name,
        plugin_name=plugin_name,
        author=git_info.get("name", ""),
        description=description or f"{component_name} 组件",
        is_async=True,  # always generate async methods
    )

    # Write the component file.
    component_file = _generate_component_file(
        work_dir=work_dir,
        component_type=normalized_type,
        component_name=component_name,
        context=context,
        force=force,
        verbose=verbose,
        use_components_folder=use_components_folder,
    )

    if not component_file:
        return

    # Update the plugin registration in plugin.py.
    if not _update_plugin_registration(
        work_dir=work_dir,
        component_type=normalized_type,
        component_name=component_name,
        context=context,
        verbose=verbose,
        use_components_folder=use_components_folder,
    ):
        print_warning("⚠️ 自动更新插件注册失败,请手动添加到 plugin.py")

    # Report success and next steps.
    print_success(f"✨ {context['class_name']} 生成成功!")
    console.print("\n[bold cyan]生成的文件:[/bold cyan]")
    console.print(f" 📄 {component_file.relative_to(work_dir)}")

    console.print("\n[bold cyan]下一步:[/bold cyan]")
    console.print(f" 1. 编辑 {component_file.name} 实现具体逻辑")
    console.print(" 2. 运行 mpdt check 检查代码")
    console.print(" 3. 运行 mpdt test 测试功能")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _interactive_generate() -> dict[str, Any]:
    """
    Run the interactive component-generation wizard.

    Asks, in order, for the component type, component name, an optional
    description, where to place the file, and whether to overwrite an
    existing file, all in a single ``questionary`` form.

    Returns:
        The form answers keyed by ``component_type``, ``component_name``,
        ``description``, ``use_components_folder`` and ``force``.
        NOTE(review): when the user aborts the form, questionary's
        ``Form.ask()`` is believed to return an empty result — confirm.
    """
    console.print("\n[bold cyan]🔧 组件生成向导[/bold cyan]\n")

    # (label, value) pairs for the component-type selector.
    type_options = (
        ("Action 组件", "action"),
        ("Tool 组件", "tool"),
        ("Event 事件", "event"),
        ("Adapter 适配器", "adapter"),
        ("Prompt 提示词", "prompt"),
        ("Plus Command 命令", "plus-command"),
        ("Chatter 聊天组件", "chatter"),
        ("Router 路由组件", "router"),
    )
    type_choices = [questionary.Choice(label, value=value) for label, value in type_options]

    location_choices = [
        questionary.Choice("components/ 文件夹 (推荐)", value=True),
        questionary.Choice("插件根目录", value=False),
    ]

    def _check_name(text):
        # A string result is used by questionary as the validation error message.
        return validate_component_name(text) or "组件名称格式无效"

    wizard = questionary.form(
        component_type=questionary.select("选择组件类型:", choices=type_choices),
        component_name=questionary.text("组件名称 (使用下划线命名):", validate=_check_name),
        description=questionary.text("组件描述 (可选):", default=""),
        use_components_folder=questionary.select("组件文件存放位置:", choices=location_choices),
        force=questionary.confirm("如果文件存在,是否覆盖?", default=False),
    )
    return wizard.ask()
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _detect_plugin_name(work_dir: Path) -> str | None:
|
|
182
|
+
"""
|
|
183
|
+
检测插件名称
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
work_dir: 工作目录
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
插件名称,未检测到则返回 None
|
|
190
|
+
"""
|
|
191
|
+
# 检查 plugin.py 文件
|
|
192
|
+
plugin_file = work_dir / "plugin.py"
|
|
193
|
+
if not plugin_file.exists():
|
|
194
|
+
# 尝试在父目录查找
|
|
195
|
+
plugin_file = work_dir.parent / "plugin.py"
|
|
196
|
+
if not plugin_file.exists():
|
|
197
|
+
return None
|
|
198
|
+
work_dir = work_dir.parent
|
|
199
|
+
|
|
200
|
+
# 从目录名推断插件名
|
|
201
|
+
return work_dir.name
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _generate_component_file(
    work_dir: Path,
    component_type: str,
    component_name: str,
    context: dict,
    force: bool,
    verbose: bool,
    use_components_folder: bool = True,
) -> Path | None:
    """
    Render the component template and write the component file.

    The type is validated and the template rendered *before* anything is
    written to disk, so an unsupported type or a template error leaves no
    stray ``components/<type>s/`` package behind.

    Args:
        work_dir: Plugin working directory.
        component_type: Normalized component type (underscore form, e.g. ``plus_command``).
        component_name: Component name (snake_case); used as the file name.
        context: Template context passed to ``str.format``.
        force: Whether to overwrite an existing file.
        verbose: Print verbose output.
        use_components_folder: Place the file in ``components/<type>s/``;
            if False, place it in the plugin root.

    Returns:
        Path of the generated file, or None on failure (errors are printed).
    """
    from mpdt.templates import get_component_template

    # Supported component types; each type name is also its template key.
    supported_types = {
        "action",
        "tool",
        "event",
        "adapter",
        "prompt",
        "plus_command",
        "chatter",
        "router",
    }
    if component_type not in supported_types:
        print_error(f"不支持的组件类型: {component_type}")
        return None

    # Render first so template problems are reported like any other failure.
    try:
        content = get_component_template(component_type).format(**context)
    except Exception as e:
        print_error(f"生成文件失败: {e}")
        return None

    # Decide the target directory.
    if use_components_folder:
        component_dir = work_dir / "components" / f"{component_type}s"
        ensure_dir(component_dir)

        # Make the per-type folder a package.
        init_file = component_dir / "__init__.py"
        if not init_file.exists():
            safe_write_file(init_file, f'"""\n{component_type.title()}s 组件\n"""\n')
    else:
        # Generate directly in the plugin root.
        component_dir = work_dir

    component_file = component_dir / f"{component_name}.py"

    try:
        safe_write_file(component_file, content, force=force)
    except FileExistsError:
        print_error(f"文件已存在: {component_file}")
        print_warning("使用 --force 选项覆盖已存在的文件")
        return None
    except Exception as e:
        print_error(f"生成文件失败: {e}")
        return None

    if verbose:
        console.print(f"[dim]✓ 生成文件: {component_file}[/dim]")
    return component_file
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _update_plugin_registration(
    work_dir: Path,
    component_type: str,
    component_name: str,
    context: dict,
    verbose: bool,
    use_components_folder: bool = True,
) -> bool:
    """
    Add the component's import and registration to ``plugin.py``.

    ``plugin.py`` is parsed with ``CodeParser``, rewritten with
    ``PluginRegistrationTransformer`` and written back.

    Args:
        work_dir: Plugin root directory (must contain ``plugin.py``).
        component_type: Normalized component type.
        component_name: Component name (module name of the component file).
        context: Template context; ``context["class_name"]`` is the class to register.
        verbose: Print the failure reason if the update fails.
        use_components_folder: Whether the component lives in ``components/<type>s/``.

    Returns:
        True only when both the import and the registration call are present
        afterwards (newly added or already there); False otherwise, so the
        caller can ask the user to register the component manually.
    """
    plugin_file = work_dir / "plugin.py"
    if not plugin_file.exists():
        return False

    try:
        # Prefer the name parsed from plugin.py; fall back to the directory name.
        # NOTE(review): this name is used as the package name in the import
        # path — confirm extract_plugin_name returns an importable package name.
        parsed_plugin_name = extract_plugin_name(work_dir) or work_dir.name

        parser = CodeParser.from_file(plugin_file)
        transformer = PluginRegistrationTransformer(
            plugin_name=parsed_plugin_name,
            component_type=component_type,
            component_name=component_name,
            class_name=context["class_name"],
            use_components_folder=use_components_folder,
        )
        modified_tree = parser.module.visit(transformer)
        plugin_file.write_text(modified_tree.code, encoding="utf-8")
    except Exception as e:
        if verbose:
            console.print(f"[dim]更新 plugin.py 失败: {e}[/dim]")
        return False

    # An import without the registration call leaves the component unused, so
    # both must be in place to report success.
    return transformer.import_added and transformer.registration_added
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class PluginRegistrationTransformer(cst.CSTTransformer):
    """
    CST transformer that adds a component's import and registration to plugin.py.

    After ``module.visit(transformer)``:

    * ``import_added`` is True when ``from <path> import <Class>`` was inserted,
      or an import from the same module path already existed.
    * ``registration_added`` is True when a registration call was inserted into
      ``get_plugin_components``, or one was already present.
    """

    # Component type -> classmethod that returns the component's info object.
    _INFO_METHODS = {
        "action": "get_action_info",
        "tool": "get_tool_info",
        "event": "get_event_handler_info",
        "adapter": "get_adapter_info",
        "prompt": "get_prompt_info",
        "plus_command": "get_command_info",
        "chatter": "get_chatter_info",
        "router": "get_router_info",
    }

    def __init__(
        self,
        plugin_name: str,
        component_type: str,
        component_name: str,
        class_name: str,
        use_components_folder: bool = True,
    ):
        super().__init__()
        self.plugin_name = plugin_name
        self.component_type = component_type
        self.component_name = component_name
        self.class_name = class_name
        self.use_components_folder = use_components_folder
        self.import_added = False
        self.registration_added = False

    def _import_path(self) -> str:
        """Dotted module path of the generated component file."""
        if self.use_components_folder:
            return f"{self.plugin_name}.components.{self.component_type}s.{self.component_name}"
        return f"{self.plugin_name}.{self.component_name}"

    @staticmethod
    def _import_insert_index(body: list[cst.BaseStatement]) -> int:
        """
        Index at which a new import statement should be inserted.

        After the last top-level import if there is one; otherwise after a
        leading module docstring; otherwise at the very top of the module.
        """
        last_import_idx = -1
        for idx, stmt in enumerate(body):
            if isinstance(stmt, cst.SimpleStatementLine) and any(
                isinstance(small, cst.Import | cst.ImportFrom) for small in stmt.body
            ):
                last_import_idx = idx
        if last_import_idx >= 0:
            return last_import_idx + 1

        first = body[0] if body else None
        if (
            isinstance(first, cst.SimpleStatementLine)
            and len(first.body) == 1
            and isinstance(first.body[0], cst.Expr)
            and isinstance(first.body[0].value, cst.SimpleString | cst.ConcatenatedString)
        ):
            return 1
        return 0

    @staticmethod
    def _returned_name(stmt: cst.BaseStatement) -> str | None:
        """Return ``<name>`` if ``stmt`` is a ``return <name>`` line, else None."""
        if not isinstance(stmt, cst.SimpleStatementLine):
            return None
        for small in stmt.body:
            if isinstance(small, cst.Return) and isinstance(small.value, cst.Name):
                return small.value.value
        return None

    def leave_Module(  # noqa: N802
        self, original_node: cst.Module, updated_node: cst.Module
    ) -> cst.Module:
        """Add the component import at module level (once)."""
        if self.import_added:
            return updated_node

        import_path = self._import_path()

        # Skip if an import from the same module path already exists.
        for stmt in updated_node.body:
            if not isinstance(stmt, cst.SimpleStatementLine):
                continue
            for small in stmt.body:
                if (
                    isinstance(small, cst.ImportFrom)
                    and small.module
                    and cst.Module([]).code_for_node(small.module) == import_path
                ):
                    self.import_added = True
                    return updated_node

        import_statement = cst.parse_statement(f"from {import_path} import {self.class_name}")
        new_body = list(updated_node.body)
        new_body.insert(self._import_insert_index(new_body), import_statement)
        self.import_added = True
        return updated_node.with_changes(body=new_body)

    def leave_FunctionDef(  # noqa: N802
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """
        Register the component inside ``get_plugin_components``.

        The registration call is appended to the variable that the function
        returns, immediately before each top-level ``return <name>``. Returns
        of any other form (e.g. a list literal) are left untouched, and
        ``registration_added`` stays False so the caller can warn the user.
        """
        import re

        if updated_node.name.value != "get_plugin_components" or self.registration_added:
            return updated_node

        info_method = self._INFO_METHODS.get(self.component_type, "get_component_info")

        # Already registered? Match "<Class>.<info_method>" as whole identifiers
        # so that e.g. "FooAction" does not match an existing "MyFooAction".
        function_code = cst.Module([]).code_for_node(updated_node)
        pattern = rf"(?<![\w.]){re.escape(self.class_name)}\.{re.escape(info_method)}\b"
        if re.search(pattern, function_code):
            self.registration_added = True
            return updated_node

        # One-line bodies ("def f(): return x") are not handled.
        if not isinstance(updated_node.body, cst.IndentedBlock):
            return updated_node

        new_body = []
        for stmt in updated_node.body.body:
            target = self._returned_name(stmt)
            if target is not None:
                new_body.append(
                    cst.parse_statement(
                        f"{target}.append(({self.class_name}.{info_method}(), {self.class_name})) # 注册 {self.class_name}"
                    )
                )
                self.registration_added = True
            new_body.append(stmt)

        if self.registration_added:
            new_function_body = updated_node.body.with_changes(body=new_body)
            return updated_node.with_changes(body=new_function_body)

        return updated_node
|