iflow-mcp_galaxyxieyu_api-auto-test 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atf/mcp_server.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ API Auto Test MCP Server
3
+ MCP 服务器主入口
4
+
5
+ 架构说明:
6
+ - mcp/models.py: 所有 Pydantic 数据模型
7
+ - mcp/utils.py: 工具函数
8
+ - mcp/executor.py: 测试执行逻辑
9
+ - mcp/tools/*.py: 各功能工具实现
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import inspect
16
+ import json
17
+ import os
18
+ import sys
19
+
20
+ from mcp.server import FastMCP
21
+
22
+ from atf.core.log_manager import log
23
+ # 导入各工具注册函数
24
+ from atf.mcp.tools.health_tool import register_health_tool
25
+ from atf.mcp.tools.metrics_tools import register_metrics_tools
26
+ from atf.mcp.tools.testcase_tools import register_testcase_tools
27
+ from atf.mcp.tools.unittest_tools import register_unittest_tools
28
+ from atf.mcp.tools.runner_tools import register_runner_tools
29
+
30
+
31
# Single server instance that every tool group below registers itself on.
mcp = FastMCP(name="api-auto-test-mcp")


def register_all_tools() -> None:
    """Attach every tool group to the module-level ``mcp`` server.

    Groups are registered in a fixed order: health, metrics, testcase,
    unittest, runner.
    """
    registrars = (
        register_health_tool,
        register_metrics_tools,
        register_testcase_tools,
        register_unittest_tools,
        register_runner_tools,
    )
    for register in registrars:
        register(mcp)
42
+
43
+
44
+ def _build_parser() -> argparse.ArgumentParser:
45
+ parser = argparse.ArgumentParser(description="API Auto Test MCP Server")
46
+ parser.add_argument(
47
+ "--transport",
48
+ choices=["stdio", "sse"],
49
+ default=os.getenv("MCP_TRANSPORT", "stdio"),
50
+ help="传输方式,默认 stdio",
51
+ )
52
+ parser.add_argument(
53
+ "--host",
54
+ default=os.getenv("MCP_HOST", "127.0.0.1"),
55
+ help="SSE 监听地址(仅 sse 生效)",
56
+ )
57
+ parser.add_argument(
58
+ "--port",
59
+ type=int,
60
+ default=int(os.getenv("MCP_PORT", "8000")),
61
+ help="SSE 监听端口(仅 sse 生效)",
62
+ )
63
+ parser.add_argument(
64
+ "--sse-path",
65
+ default=os.getenv("MCP_SSE_PATH", "/mcp"),
66
+ help="SSE 路由路径(仅 sse 生效)",
67
+ )
68
+ parser.add_argument(
69
+ "--auth-token",
70
+ default=os.getenv("MCP_AUTH_TOKEN"),
71
+ help="SSE 鉴权 Token(仅 sse 生效,建议使用网关/反代统一鉴权)",
72
+ )
73
+ return parser
74
+
75
+
76
+ def _filter_run_kwargs(kwargs: dict[str, object]) -> dict[str, object]:
77
+ try:
78
+ signature = inspect.signature(mcp.run)
79
+ except (TypeError, ValueError):
80
+ return kwargs
81
+
82
+ if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
83
+ return kwargs
84
+
85
+ allowed = {name for name in signature.parameters.keys()}
86
+ filtered = {key: value for key, value in kwargs.items() if key in allowed}
87
+ unsupported = set(kwargs) - set(filtered)
88
+ if unsupported:
89
+ log.warning(f"SSE 参数未被当前 mcp 版本支持,将忽略: {sorted(unsupported)}")
90
+ return filtered
91
+
92
+
93
def main() -> None:
    """Entry point of the MCP server (also handles ``<prog> install``).

    ``install`` as the first argument writes the client configuration and
    exits. Otherwise the command line is parsed, all tools are registered, and
    the server runs with the selected transport (stdio or SSE).
    """
    # `install` subcommand: write the client config and stop.
    if len(sys.argv) > 1 and sys.argv[1] == "install":
        install_mcp_config()
        return

    # Parse first so --help / invalid arguments exit before any tool is registered.
    args = _build_parser().parse_args()
    register_all_tools()

    if args.transport == "stdio":
        mcp.run("stdio")
        return

    if not args.auth_token:
        log.warning("SSE 模式未设置 auth token,建议在网关或反代层做鉴权")

    run_kwargs = {
        "host": args.host,
        "port": args.port,
        "path": args.sse_path,
    }
    if args.auth_token:
        run_kwargs["auth_token"] = args.auth_token

    # NOTE: the official `mcp` SDK's FastMCP reads host/port/sse_path from
    # `mcp.settings`, and its run() does not accept them as keyword arguments.
    # Passed only to run(), they would be filtered out and the server would
    # listen on its defaults. Apply them to the settings object when it exposes
    # the field, and only hand the remaining options to run().
    settings = getattr(mcp, "settings", None)
    for run_key, settings_key in (("host", "host"), ("port", "port"), ("path", "sse_path")):
        if settings is not None and hasattr(settings, settings_key):
            setattr(settings, settings_key, run_kwargs.pop(run_key))

    try:
        mcp.run("sse", **_filter_run_kwargs(run_kwargs))
    except TypeError as exc:
        log.error(f"SSE 启动失败,请检查 mcp 版本与参数兼容性: {exc}")
        raise
126
+
127
+
128
def install_mcp_config() -> None:
    """Register this server in the Claude Code MCP configuration file.

    The first existing file among ``~/.claude/.mcp.json`` and
    ``~/.config/claude/mcp_settings.json`` is updated. If neither exists,
    ``~/.claude/.mcp.json`` is created. Other top-level keys and other entries
    under ``mcpServers`` are preserved.

    An existing file is never overwritten if it cannot be read or parsed as a
    JSON object. Previously a bare ``except`` replaced it with an empty config,
    silently discarding the user's settings. An empty file counts as an empty
    configuration.
    """
    mcp_config = {
        "command": "api-auto-test-mcp",
        "args": ["--workspace", "${workspace}"],
    }

    def abort(reason: str) -> None:
        print(reason, file=sys.stderr)
        print("为避免覆盖已有配置,安装已中止,请修复该文件后重试。", file=sys.stderr)

    # Look for an existing Claude Code MCP configuration file.
    candidates = (
        os.path.expanduser("~/.claude/.mcp.json"),
        os.path.expanduser("~/.config/claude/mcp_settings.json"),
    )
    config_path = next((path for path in candidates if os.path.exists(path)), None)

    if config_path is None:
        config_path = candidates[0]
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        config = {}
    else:
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                raw = f.read()
            config = json.loads(raw) if raw.strip() else {}
        except (OSError, ValueError) as exc:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            abort(f"无法读取或解析 MCP 配置文件 {config_path}: {exc}")
            return

    if not isinstance(config, dict):
        abort(f"MCP 配置文件顶层必须是 JSON 对象: {config_path}")
        return

    mcp_servers = config.get("mcpServers") or {}
    if not isinstance(mcp_servers, dict):
        abort(f"MCP 配置文件中的 mcpServers 必须是 JSON 对象: {config_path}")
        return

    mcp_servers["api-auto-test-mcp"] = mcp_config
    config["mcpServers"] = mcp_servers

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print(f"已配置 MCP 服务器到: {config_path}")
    print(f"配置内容: {json.dumps(mcp_config, indent=2)}")
    print("\n请重启 Claude Code 以加载新的 MCP 服务器")
166
+
167
+
168
if __name__ == "__main__":
    # Running the module directly behaves like the console-script entry point.
    main()
atf/runner.py ADDED
@@ -0,0 +1,134 @@
1
+ # @time: 2024-08-16
2
+ # @author: xiaoqq
3
+
4
+ import os
5
+ import pytest
6
+ from atf.core.config_manager import ConfigManager
7
+ from atf.core.log_manager import logger
8
+ from atf.core.globals import Globals
9
+ from atf.core.log_manager import log
10
+ from atf.handlers.notification_handler import NotificationHandler
11
+ from atf.handlers.report_generator import ReportGenerator
12
+ from atf.core.login_handler import LoginHandler
13
+
14
def execute_test_cases(testcases, env, report_type):
    """
    Run the given test-case paths with pytest for one environment.

    Steps performed:
    1. Store ``env`` in Globals.
    2. Merge the env config of every project the paths belong to.
    3. Log in for projects that need a token.
    4. Prepare the report output and run pytest.
    5. Build the Allure report when requested.

    :param testcases: list of test case paths (directories or files)
    :param env: environment name
    :param report_type: 'allure', 'pytest-html', or any other value for no report
    """
    config_manager = ConfigManager()
    login_handler = LoginHandler()

    # Store the current environment in the globals
    Globals.set('env', env)

    projects_config = _collect_projects_config(config_manager, testcases, env)

    # Log in for projects that need a token and store the token in the globals
    for project_name in projects_config:
        project_env_config = config_manager.get_project_env_config(project_name, env)
        if project_env_config is None:
            # Without this guard the .get() below raised AttributeError.
            log.warning(f"{project_name} 的配置未找到,跳过登录。")
            continue
        token = login_handler.login_if_needed(project_name, project_env_config.get(project_name), env)
        if token:
            Globals.update(project_name, "token", token)

    # NOTE(review): Globals.get_data() presumably contains the tokens stored
    # above; consider redacting secrets before logging.
    log.info(f"工程配置信息更新:{Globals.get_data()}")

    # Prepare report directories/files for the chosen report type
    report_generator = ReportGenerator(report_type, env)
    if report_type == 'allure':
        report_generator.prepare_allure_report()
        pytest_args = [f'--alluredir={report_generator.allure_results_dir}']
    elif report_type == 'pytest-html':
        report_generator.prepare_pytest_html_report()
        pytest_args = [f'--html={report_generator.html_report_path}', '--self-contained-html']
    else:
        pytest_args = []
    pytest_args.extend(testcases)

    try:
        exit_code = pytest.main(pytest_args)
        logger.info(f"所有测试用例执行完成。pytest 退出码: {int(exit_code)}")
    except Exception as e:
        logger.error(f"执行所有测试用例时出错: {str(e)}")

    if report_type == 'allure':
        report_generator.generate_allure_report()


def _collect_projects_config(config_manager, testcases, env):
    """Merge the env configs of the projects that ``testcases`` belong to.

    A directory maps to the project named by its own path; a file maps to the
    project of its parent directory. Anything else (e.g. a pytest node id such as
    ``path::test``) is not used for config lookup, matching the previous behavior.

    :return: dict built by ``update``-ing every non-None project env config
    """
    projects_config = {}
    for testcase_path in testcases:
        if os.path.isdir(testcase_path):
            project_dir = testcase_path
        elif os.path.isfile(testcase_path):
            project_dir = os.path.dirname(testcase_path)
        else:
            continue
        project_name = get_project_name_from_path(project_dir)
        project_env_config = config_manager.get_project_env_config(project_name, env)
        if project_env_config is not None:
            projects_config.update(project_env_config)
        else:
            log.warning(f"{project_name} 的配置未找到,跳过该项目。")
    return projects_config
76
+
77
def get_project_name_from_path(path, root='test_cases'):
    """
    Derive the project name from a test-case directory path.

    The project is the directory directly below ``root``. For example, both
    ``test_cases/<project>/...`` and ``/abs/repo/test_cases/<project>/...`` give
    ``<project>``; previously absolute paths returned their first directory.
    If ``root`` is not followed by another component, the old rule applies:
    the second path component, or the only component of a single-part path.

    :param path: path string (directory, or a file's parent directory)
    :param root: name of the directory that holds the project folders
    :return: project name
    """
    parts = os.path.normpath(path).split(os.sep)
    if root in parts:
        index = parts.index(root)
        if index + 1 < len(parts):
            return parts[index + 1]
    # str.split always yields at least one element, so parts is never empty.
    if len(parts) > 1:
        return parts[1]
    return parts[0]
90
+
91
def run_tests(testcases=None, env=None, report_type=None):
    """
    Main entry: run the selected test cases, then send the DingTalk summary.

    :param testcases: list of test case paths (directories or files).
        - ``None`` runs every ``.py`` file found under ``test_cases``.
        - ``['test_cases/']`` runs each project directory under ``test_cases``.
          Equivalent spellings such as ``['test_cases']`` also work.
    :param env: environment name, defaults to 'pre'
    :param report_type: report type ('allure' or 'pytest-html'), defaults to 'pytest-html'
    """
    if testcases is None:
        # Collect every nested .py file below test_cases
        testcases = [os.path.join(root, file)
                     for root, dirs, files in os.walk('test_cases')
                     for file in files if file.endswith('.py')]
    elif [os.path.normpath(p) for p in testcases] == ['test_cases']:
        # The test_cases root itself: run each project sub-directory (sorted for a stable order)
        testcases = [os.path.join('test_cases', d)
                     for d in sorted(os.listdir('test_cases'))
                     if os.path.isdir(os.path.join('test_cases', d))]

    if env is None:
        env = 'pre'

    if report_type is None:
        report_type = 'pytest-html'

    execute_test_cases(testcases, env, report_type)

    results = Globals.get('test_results')
    dingtalk = Globals.get('dingtalk')
    if not results or not dingtalk:
        # Notification is best-effort; previously a missing config/result crashed
        # here after all tests had already run.
        log.warning("未获取到测试结果或钉钉配置,跳过发送钉钉通知。")
        return

    NotificationHandler(dingtalk.get('webhook'), dingtalk.get('secret')).send_markdown_msg(
        conclusion=results['conclusion'],
        total=results['total'],
        passed=results['passed'],
        failed=results['failed'],
        error=results['error'],
        skipped=results['skipped'],
        start_time=results['start_time'],
        duration=results['duration']
    )
130
+
131
+
132
if __name__ == '__main__':
    # report_type may be 'allure' or 'pytest-html' (run_tests falls back to pytest-html)
    run_tests(
        testcases=['test_cases/'],
        env='pre',
        report_type='allure',
    )
@@ -0,0 +1,337 @@
1
+ # @time: 2024-12-25
2
+ # @author: auto-generated
3
+
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import yaml
8
+
9
+ from atf.core.log_manager import log
10
+
11
+
12
+ class UnitCaseGenerator:
13
+ """单元测试用例文件生成器"""
14
+
15
+ def generate_unit_tests(self, yaml_file: str, output_dir: str | None = None, overwrite: bool = False) -> str | None:
16
+ """
17
+ 根据 YAML 文件生成单元测试用例
18
+ :param yaml_file: YAML 文件路径
19
+ :param output_dir: 输出目录
20
+ :param overwrite: 是否覆盖已存在的文件
21
+ :return: 生成的文件路径
22
+ """
23
+ test_data = self._load_yaml(yaml_file)
24
+ if not test_data:
25
+ return None
26
+
27
+ if not self._validate_unittest_data(test_data):
28
+ log.warning(f"{yaml_file} 数据校验不通过,跳过生成。")
29
+ return None
30
+
31
+ unittest_data = test_data["unittest"]
32
+ file_path = self._get_output_path(yaml_file, unittest_data["name"], output_dir)
33
+
34
+ if os.path.exists(file_path):
35
+ if overwrite:
36
+ log.info(f"覆盖已存在的测试用例文件: {file_path}")
37
+ else:
38
+ log.info(f"测试用例文件已存在,跳过生成: {file_path}")
39
+ return None
40
+
41
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
42
+ self._generate_file(file_path, yaml_file, unittest_data)
43
+ log.info(f"已生成单元测试文件: {file_path}")
44
+ return file_path
45
+
46
+ def _load_yaml(self, yaml_file: str) -> dict | None:
47
+ """加载 YAML 文件"""
48
+ try:
49
+ with open(yaml_file, "r", encoding="utf-8") as f:
50
+ return yaml.safe_load(f)
51
+ except FileNotFoundError:
52
+ log.error(f"未找到文件: {yaml_file}")
53
+ except yaml.YAMLError as e:
54
+ log.error(f"YAML 解析错误: {e}")
55
+ return None
56
+
57
+ def _validate_unittest_data(self, data: dict) -> bool:
58
+ """校验单元测试数据"""
59
+ if not data.get("unittest"):
60
+ log.error("数据必须包含 'unittest' 键")
61
+ return False
62
+
63
+ ut = data["unittest"]
64
+ if not ut.get("name"):
65
+ log.error("unittest.name 不能为空")
66
+ return False
67
+
68
+ if not ut.get("target") or not ut["target"].get("module"):
69
+ log.error("unittest.target.module 不能为空")
70
+ return False
71
+
72
+ if not ut.get("cases"):
73
+ log.error("unittest.cases 不能为空")
74
+ return False
75
+
76
+ for case in ut["cases"]:
77
+ if not case.get("id"):
78
+ log.error("unittest.cases.id 不能为空")
79
+ return False
80
+
81
+ return True
82
+
83
+ def _get_output_path(self, yaml_file: str, name: str, output_dir: str | None) -> str:
84
+ """获取输出文件路径"""
85
+ relative_path = os.path.relpath(yaml_file, "tests")
86
+ path_parts = relative_path.split(os.sep)
87
+ if path_parts:
88
+ path_parts.pop() # 移除文件名
89
+ dir_path = os.path.join(*path_parts) if path_parts else ""
90
+
91
+ file_name = f"test_{name}.py"
92
+ base_dir = output_dir or "test_cases"
93
+ return os.path.join(base_dir, dir_path, file_name)
94
+
95
+ def _generate_file(self, file_path: str, yaml_file: str, ut: dict) -> None:
96
+ """生成测试文件"""
97
+ with open(file_path, "w", encoding="utf-8") as f:
98
+ self._write_imports(f, ut)
99
+ self._write_class_header(f, ut)
100
+ self._write_test_methods(f, ut, yaml_file)
101
+
102
+ def _write_imports(self, f, ut: dict) -> None:
103
+ """写入导入语句"""
104
+ f.write("# Auto-generated unit test module\n\n")
105
+
106
+ # 写入运行方式注释(根据 env_type)
107
+ env_type = ut.get("env_type", "venv")
108
+ if env_type == "venv":
109
+ f.write("# 运行方式: source .venv/bin/activate && pytest test_cases/ -v\n\n")
110
+ elif env_type == "conda":
111
+ f.write("# 运行方式: conda activate <env_name> && pytest test_cases/ -v\n\n")
112
+ elif env_type == "uv":
113
+ f.write("# 运行方式: uv run pytest test_cases/ -v\n\n")
114
+
115
+ f.write("import pytest\n")
116
+ f.write("from unittest.mock import patch, MagicMock, call\n")
117
+ f.write("import allure\n")
118
+ f.write("import yaml\n\n")
119
+
120
+ # 导入被测模块
121
+ target = ut["target"]
122
+ module = target["module"]
123
+ class_name = target.get("class")
124
+ func_name = target.get("function")
125
+
126
+ if class_name:
127
+ f.write(f"from {module} import {class_name}\n\n")
128
+ elif func_name:
129
+ f.write(f"from {module} import {func_name}\n\n")
130
+ else:
131
+ f.write(f"import {module}\n\n")
132
+
133
+ def _write_class_header(self, f, ut: dict) -> None:
134
+ """写入类定义和装饰器"""
135
+ name = ut["name"]
136
+ class_name = "".join(s.capitalize() for s in name.split("_"))
137
+ allure_cfg = ut.get("allure", {})
138
+
139
+ epic = allure_cfg.get("epic", "单元测试")
140
+ feature = allure_cfg.get("feature")
141
+
142
+ f.write(f"@allure.epic('{epic}')\n")
143
+ if feature:
144
+ f.write(f"@allure.feature('{feature}')\n")
145
+ f.write(f"class Test{class_name}:\n")
146
+ f.write(f' """单元测试: {ut.get("description", name)}"""\n\n')
147
+
148
+ def _write_test_methods(self, f, ut: dict, yaml_file: str) -> None:
149
+ """写入测试方法"""
150
+ target = ut["target"]
151
+ allure_cfg = ut.get("allure", {})
152
+ story = allure_cfg.get("story", ut["name"])
153
+
154
+ for case in ut["cases"]:
155
+ self._write_single_test(f, case, target, story)
156
+
157
+ def _write_single_test(self, f, case: dict, target: dict, story: str) -> None:
158
+ """写入单个测试方法"""
159
+ case_id = case["id"]
160
+ desc = case.get("description", case_id)
161
+ mocks = case.get("mocks", [])
162
+
163
+ # 写入装饰器
164
+ f.write(f" @allure.story('{story}')\n")
165
+ self._write_mock_decorators(f, mocks)
166
+
167
+ # 写入方法签名
168
+ mock_params = self._get_mock_params(mocks)
169
+ f.write(f" def {case_id}(self{mock_params}):\n")
170
+ f.write(f' """{desc}"""\n')
171
+
172
+ # 写入 mock 配置
173
+ self._write_mock_setup(f, mocks)
174
+
175
+ # 写入执行代码
176
+ self._write_execution(f, case, target)
177
+
178
+ # 写入断言
179
+ self._write_assertions(f, case, mocks)
180
+
181
+ f.write("\n")
182
+
183
+ def _write_mock_decorators(self, f, mocks: list) -> None:
184
+ """写入 mock 装饰器"""
185
+ for mock in reversed(mocks):
186
+ target = mock["target"]
187
+ method = mock.get("method")
188
+ if method:
189
+ f.write(f" @patch('{target}.{method}')\n")
190
+ else:
191
+ f.write(f" @patch('{target}')\n")
192
+
193
+ def _get_mock_params(self, mocks: list) -> str:
194
+ """获取 mock 参数列表"""
195
+ if not mocks:
196
+ return ""
197
+ params = []
198
+ for mock in mocks:
199
+ name = self._get_mock_var_name(mock)
200
+ params.append(name)
201
+ return ", " + ", ".join(params)
202
+
203
+ def _get_mock_var_name(self, mock: dict) -> str:
204
+ """获取 mock 变量名"""
205
+ target = mock["target"]
206
+ method = mock.get("method")
207
+ if method:
208
+ return f"mock_{method.lower()}"
209
+ parts = target.split(".")
210
+ return f"mock_{parts[-1].lower()}"
211
+
212
+ def _write_mock_setup(self, f, mocks: list) -> None:
213
+ """写入 mock 配置"""
214
+ for mock in mocks:
215
+ var_name = self._get_mock_var_name(mock)
216
+ ret_val = mock.get("return_value")
217
+ side_effect = mock.get("side_effect")
218
+
219
+ if ret_val is not None:
220
+ f.write(f" {var_name}.return_value = {repr(ret_val)}\n")
221
+ if side_effect is not None:
222
+ f.write(f" {var_name}.side_effect = {repr(side_effect)}\n")
223
+
224
+ def _write_execution(self, f, case: dict, target: dict) -> None:
225
+ """写入执行代码"""
226
+ inputs = case.get("inputs", {})
227
+ args = inputs.get("args", [])
228
+ kwargs = inputs.get("kwargs", {})
229
+
230
+ class_name = target.get("class")
231
+ func_name = target.get("function")
232
+
233
+ f.write("\n # 执行\n")
234
+
235
+ # 构建参数字符串
236
+ args_str = ", ".join(repr(a) for a in args)
237
+ kwargs_str = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items())
238
+ all_args = ", ".join(filter(None, [args_str, kwargs_str]))
239
+
240
+ if class_name and func_name:
241
+ f.write(f" instance = {class_name}()\n")
242
+ f.write(f" result = instance.{func_name}({all_args})\n")
243
+ elif class_name:
244
+ f.write(f" result = {class_name}({all_args})\n")
245
+ elif func_name:
246
+ f.write(f" result = {func_name}({all_args})\n")
247
+ else:
248
+ f.write(f" result = None # TODO: 补充执行逻辑\n")
249
+
250
+ def _write_assertions(self, f, case: dict, mocks: list) -> None:
251
+ """写入断言"""
252
+ asserts = case.get("assert", [])
253
+ if not asserts:
254
+ return
255
+
256
+ f.write("\n # 断言\n")
257
+ mock_var_map = {self._get_mock_key(m): self._get_mock_var_name(m) for m in mocks}
258
+
259
+ for assertion in asserts:
260
+ self._write_single_assertion(f, assertion, mock_var_map)
261
+
262
+ def _get_mock_key(self, mock: dict) -> str:
263
+ """获取 mock 键名"""
264
+ target = mock["target"]
265
+ method = mock.get("method")
266
+ parts = target.split(".")
267
+ if method:
268
+ return f"{parts[-1]}.{method}"
269
+ return parts[-1]
270
+
271
+ def _write_single_assertion(self, f, assertion: dict, mock_var_map: dict) -> None:
272
+ """写入单个断言"""
273
+ assert_type = assertion["type"]
274
+ field = assertion.get("field")
275
+ expected = assertion.get("expected")
276
+ mock_name = assertion.get("mock")
277
+ args = assertion.get("args", [])
278
+ kwargs = assertion.get("kwargs", {})
279
+ exception = assertion.get("exception")
280
+ message = assertion.get("message")
281
+
282
+ if assert_type == "equals":
283
+ if field and field != "result":
284
+ f.write(f" assert result{self._parse_field(field)} == {repr(expected)}\n")
285
+ else:
286
+ f.write(f" assert result == {repr(expected)}\n")
287
+
288
+ elif assert_type == "not_equals":
289
+ if field and field != "result":
290
+ f.write(f" assert result{self._parse_field(field)} != {repr(expected)}\n")
291
+ else:
292
+ f.write(f" assert result != {repr(expected)}\n")
293
+
294
+ elif assert_type == "contains":
295
+ f.write(f" assert {repr(expected)} in result\n")
296
+
297
+ elif assert_type == "is_none":
298
+ f.write(f" assert result is None\n")
299
+
300
+ elif assert_type == "is_not_none":
301
+ f.write(f" assert result is not None\n")
302
+
303
+ elif assert_type == "called_once" and mock_name:
304
+ var = mock_var_map.get(mock_name, f"mock_{mock_name.lower()}")
305
+ f.write(f" {var}.assert_called_once()\n")
306
+
307
+ elif assert_type == "called_with" and mock_name:
308
+ var = mock_var_map.get(mock_name, f"mock_{mock_name.lower()}")
309
+ args_str = ", ".join(repr(a) for a in args)
310
+ kwargs_str = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items())
311
+ all_args = ", ".join(filter(None, [args_str, kwargs_str]))
312
+ f.write(f" {var}.assert_called_with({all_args})\n")
313
+
314
+ elif assert_type == "not_called" and mock_name:
315
+ var = mock_var_map.get(mock_name, f"mock_{mock_name.lower()}")
316
+ f.write(f" {var}.assert_not_called()\n")
317
+
318
+ elif assert_type == "raises" and exception:
319
+ match_str = f", match={repr(message)}" if message else ""
320
+ f.write(f" # 注意: raises 断言需要包裹执行代码\n")
321
+ f.write(f" # with pytest.raises({exception}{match_str}):\n")
322
+ f.write(f" # 执行代码\n")
323
+
324
+ def _parse_field(self, field: str) -> str:
325
+ """解析字段路径为 Python 访问语法"""
326
+ if not field:
327
+ return ""
328
+ if field.startswith("$."):
329
+ field = field[2:]
330
+ parts = field.split(".")
331
+ result = ""
332
+ for part in parts:
333
+ if part.isdigit():
334
+ result += f"[{part}]"
335
+ else:
336
+ result += f"['{part}']"
337
+ return result
atf/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # 工具函数模块
2
+ # 可以在这里导入自定义工具函数