pytest-dsl 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytest_dsl/__init__.py +10 -0
- pytest_dsl/cli.py +44 -0
- pytest_dsl/conftest_adapter.py +4 -0
- pytest_dsl/core/__init__.py +0 -0
- pytest_dsl/core/auth_provider.py +409 -0
- pytest_dsl/core/auto_decorator.py +181 -0
- pytest_dsl/core/auto_directory.py +81 -0
- pytest_dsl/core/context.py +23 -0
- pytest_dsl/core/custom_auth_example.py +425 -0
- pytest_dsl/core/dsl_executor.py +329 -0
- pytest_dsl/core/dsl_executor_utils.py +84 -0
- pytest_dsl/core/global_context.py +103 -0
- pytest_dsl/core/http_client.py +411 -0
- pytest_dsl/core/http_request.py +810 -0
- pytest_dsl/core/keyword_manager.py +109 -0
- pytest_dsl/core/lexer.py +139 -0
- pytest_dsl/core/parser.py +197 -0
- pytest_dsl/core/parsetab.py +76 -0
- pytest_dsl/core/plugin_discovery.py +187 -0
- pytest_dsl/core/utils.py +146 -0
- pytest_dsl/core/variable_utils.py +267 -0
- pytest_dsl/core/yaml_loader.py +62 -0
- pytest_dsl/core/yaml_vars.py +75 -0
- pytest_dsl/docs/custom_keywords.md +140 -0
- pytest_dsl/examples/__init__.py +5 -0
- pytest_dsl/examples/assert/assertion_example.auto +44 -0
- pytest_dsl/examples/assert/boolean_test.auto +34 -0
- pytest_dsl/examples/assert/expression_test.auto +49 -0
- pytest_dsl/examples/http/__init__.py +3 -0
- pytest_dsl/examples/http/builtin_auth_test.auto +79 -0
- pytest_dsl/examples/http/csrf_auth_test.auto +64 -0
- pytest_dsl/examples/http/custom_auth_test.auto +76 -0
- pytest_dsl/examples/http/file_reference_test.auto +111 -0
- pytest_dsl/examples/http/http_advanced.auto +91 -0
- pytest_dsl/examples/http/http_example.auto +147 -0
- pytest_dsl/examples/http/http_length_test.auto +55 -0
- pytest_dsl/examples/http/http_retry_assertions.auto +91 -0
- pytest_dsl/examples/http/http_retry_assertions_enhanced.auto +94 -0
- pytest_dsl/examples/http/http_with_yaml.auto +58 -0
- pytest_dsl/examples/http/new_retry_test.auto +22 -0
- pytest_dsl/examples/http/retry_assertions_only.auto +52 -0
- pytest_dsl/examples/http/retry_config_only.auto +49 -0
- pytest_dsl/examples/http/retry_debug.auto +22 -0
- pytest_dsl/examples/http/retry_with_fix.auto +21 -0
- pytest_dsl/examples/http/simple_retry.auto +20 -0
- pytest_dsl/examples/http/vars.yaml +55 -0
- pytest_dsl/examples/http_clients.yaml +48 -0
- pytest_dsl/examples/keyword_example.py +70 -0
- pytest_dsl/examples/test_assert.py +16 -0
- pytest_dsl/examples/test_http.py +168 -0
- pytest_dsl/keywords/__init__.py +10 -0
- pytest_dsl/keywords/assertion_keywords.py +610 -0
- pytest_dsl/keywords/global_keywords.py +51 -0
- pytest_dsl/keywords/http_keywords.py +430 -0
- pytest_dsl/keywords/system_keywords.py +17 -0
- pytest_dsl/main_adapter.py +7 -0
- pytest_dsl/plugin.py +44 -0
- pytest_dsl-0.1.0.dist-info/METADATA +537 -0
- pytest_dsl-0.1.0.dist-info/RECORD +63 -0
- pytest_dsl-0.1.0.dist-info/WHEEL +5 -0
- pytest_dsl-0.1.0.dist-info/entry_points.txt +5 -0
- pytest_dsl-0.1.0.dist-info/licenses/LICENSE +21 -0
- pytest_dsl-0.1.0.dist-info/top_level.txt +1 -0
pytest_dsl/__init__.py
ADDED
pytest_dsl/cli.py
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
"""
|
2
|
+
pytest-dsl命令行入口
|
3
|
+
|
4
|
+
提供独立的命令行工具,用于执行DSL文件。
|
5
|
+
"""
|
6
|
+
|
7
|
+
import sys
|
8
|
+
import pytest
|
9
|
+
from pathlib import Path
|
10
|
+
|
11
|
+
from pytest_dsl.core.lexer import get_lexer
|
12
|
+
from pytest_dsl.core.parser import get_parser
|
13
|
+
from pytest_dsl.core.dsl_executor import DSLExecutor
|
14
|
+
|
15
|
+
|
16
|
+
def read_file(filename):
    """Return the whole text of the DSL file at *filename*, decoded as UTF-8."""
    with open(filename, mode='r', encoding='utf-8') as handle:
        content = handle.read()
    return content
|
20
|
+
|
21
|
+
|
22
|
+
def main():
    """Command-line entry point: parse and execute a single DSL file.

    Usage: ``python -m pytest_dsl.cli <dsl_file>``.

    Exits with status 1 in two cases: no file argument is given, or reading,
    parsing or executing the file raises. Usage text and failure messages are
    written to stderr so they do not mix with output produced by the DSL itself.
    """
    if len(sys.argv) < 2:
        print("用法: python -m pytest_dsl.cli <dsl_file>", file=sys.stderr)
        sys.exit(1)

    filename = sys.argv[1]

    lexer = get_lexer()
    parser = get_parser()
    executor = DSLExecutor()

    try:
        dsl_code = read_file(filename)
        ast = parser.parse(dsl_code, lexer=lexer)
        executor.execute(ast)
    except Exception as e:
        # Top-level CLI boundary: report any failure and exit non-zero.
        print(f"执行失败: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|
File without changes
|
@@ -0,0 +1,409 @@
|
|
1
|
+
"""认证提供者模块
|
2
|
+
|
3
|
+
该模块提供了用于HTTP请求认证的接口和实现。
|
4
|
+
"""
|
5
|
+
|
6
|
+
import abc
|
7
|
+
import base64
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
import time
|
11
|
+
from typing import Dict, Any, Optional, Callable, Union, Tuple, Type
|
12
|
+
import requests
|
13
|
+
from requests.auth import HTTPBasicAuth
|
14
|
+
|
15
|
+
# Module-level logger; used to report provider registration, OAuth2 token
# refreshes and invalid authentication configurations.
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class AuthProvider(abc.ABC):
    """Base class for authentication providers.

    Subclasses implement :meth:`apply_auth` to inject credentials into the
    keyword arguments of an outgoing HTTP request. They may override
    :meth:`clean_auth_state` to strip whatever state they add.
    """

    @abc.abstractmethod
    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Apply authentication information to the request arguments.

        Args:
            request_kwargs: Request keyword-argument dict.

        Returns:
            The updated request keyword-argument dict.
        """
        pass

    def clean_auth_state(self, request_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Remove authentication state from the request arguments.

        The default implementation does two things:
        - drops the common auth headers (``Authorization``, ``X-API-Key``,
          ``api-key``), matched case-insensitively;
        - removes the ``auth`` entry.

        Subclasses may override this to provide custom cleanup logic.

        Args:
            request_kwargs: Request keyword-argument dict. It is modified in place.

        Returns:
            The same dict after cleanup, or a new empty dict when
            ``request_kwargs`` is None.
        """
        if request_kwargs is None:
            return {}

        headers = request_kwargs.get("headers")
        if headers:
            # HTTP header names are case-insensitive. Match ignoring case so that
            # e.g. "authorization" or "X-API-KEY" in a plain dict are removed too.
            auth_header_names = {"authorization", "x-api-key", "api-key"}
            # Collect the keys first; popping while iterating would mutate the dict mid-loop.
            matching_keys = [
                key for key in headers
                if isinstance(key, str) and key.lower() in auth_header_names
            ]
            for key in matching_keys:
                headers.pop(key, None)

        # Drop any requests-style auth object (e.g. HTTPBasicAuth).
        request_kwargs.pop('auth', None)

        return request_kwargs

    @property
    def name(self) -> str:
        """Return the provider name, which is the concrete class name."""
        return self.__class__.__name__
|
65
|
+
|
66
|
+
|
67
|
+
class BasicAuthProvider(AuthProvider):
    """HTTP Basic authentication provider.

    Holds a username/password pair. It attaches them to each request as a
    ``requests`` ``HTTPBasicAuth`` object stored under the ``auth`` key.
    """

    def __init__(self, username: str, password: str):
        """Store the Basic-auth credentials.

        Args:
            username: User name.
            password: Password.
        """
        self.username, self.password = username, password

    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Set ``request_kwargs["auth"]`` to Basic-auth credentials.

        Args:
            request_kwargs: Request keyword-argument dict. It is mutated in place.

        Returns:
            The same dict, with ``auth`` set.
        """
        credentials = HTTPBasicAuth(self.username, self.password)
        request_kwargs["auth"] = credentials
        return request_kwargs
|
92
|
+
|
93
|
+
|
94
|
+
class TokenAuthProvider(AuthProvider):
    """Token authentication provider.

    Sends a token in a request header. By default this is
    ``Authorization: Bearer <token>``.
    """

    def __init__(self, token: str, scheme: str = "Bearer", header: str = "Authorization"):
        """Configure token authentication.

        Args:
            token: Authentication token.
            scheme: Auth scheme prefixed to the token (e.g. "Bearer"). If it is
                empty, the raw token is sent.
            header: Name of the header carrying the token (default "Authorization").
        """
        self.token = token
        self.scheme = scheme
        self.header = header

    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add the token header to the request arguments.

        Args:
            request_kwargs: Request keyword-argument dict. It is mutated in place.

        Returns:
            The same dict, with the token header set.
        """
        # Treat a present-but-None "headers" entry like a missing one; indexing
        # None would otherwise raise TypeError.
        if request_kwargs.get("headers") is None:
            request_kwargs["headers"] = {}

        if self.scheme:
            header_value = f"{self.scheme} {self.token}"
        else:
            header_value = self.token
        request_kwargs["headers"][self.header] = header_value

        return request_kwargs
|
129
|
+
|
130
|
+
|
131
|
+
class ApiKeyAuthProvider(AuthProvider):
    """API-key authentication provider.

    Sends the key in a request header, in a query parameter, or both.
    """

    def __init__(self, api_key: str, key_name: str = "X-API-Key", in_header: bool = True,
                 in_query: bool = False, query_param_name: Optional[str] = None):
        """Configure API-key authentication.

        Args:
            api_key: The API key.
            key_name: Header name used for the key (default "X-API-Key").
            in_header: Whether to send the key as a request header.
            in_query: Whether to send the key as a query parameter.
            query_param_name: Query-parameter name. If not given, defaults to
                ``key_name``.
        """
        self.api_key = api_key
        self.key_name = key_name
        self.in_header = in_header
        self.in_query = in_query
        self.query_param_name = query_param_name or key_name

    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add the API key to the request headers and/or query parameters.

        Args:
            request_kwargs: Request keyword-argument dict. It is mutated in place.

        Returns:
            The same dict, with the key added where configured.
        """
        if self.in_header:
            # A present-but-None "headers" entry is treated like a missing one.
            if request_kwargs.get("headers") is None:
                request_kwargs["headers"] = {}
            request_kwargs["headers"][self.key_name] = self.api_key

        if self.in_query:
            # Same for "params".
            if request_kwargs.get("params") is None:
                request_kwargs["params"] = {}
            request_kwargs["params"][self.query_param_name] = self.api_key

        return request_kwargs
|
173
|
+
|
174
|
+
|
175
|
+
class OAuth2Provider(AuthProvider):
    """OAuth2 authentication provider.

    Fetches an access token from ``token_url``, using the client-credentials
    grant by default or the password grant. It caches the token, refreshes it
    shortly before it expires, and sends it as
    ``Authorization: Bearer <token>``.
    """

    # Token lifetime (seconds) assumed when the response has no usable "expires_in".
    _DEFAULT_EXPIRES_IN = 3600

    def __init__(self, token_url: str, client_id: str, client_secret: str,
                 scope: Optional[str] = None, grant_type: str = "client_credentials",
                 username: Optional[str] = None, password: Optional[str] = None,
                 token_refresh_window: int = 60,
                 request_timeout: Optional[float] = 30):
        """Configure OAuth2 authentication.

        Args:
            token_url: URL of the token endpoint.
            client_id: Client ID.
            client_secret: Client secret.
            scope: Requested scope.
            grant_type: Grant type (default "client_credentials").
            username: User name (used when grant_type is "password").
            password: Password (used when grant_type is "password").
            token_refresh_window: Refresh the token this many seconds before it expires.
            request_timeout: Timeout in seconds for the token request, passed
                to ``requests.post``. ``None`` waits indefinitely.
        """
        self.token_url = token_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope
        self.grant_type = grant_type
        self.username = username
        self.password = password
        self.token_refresh_window = token_refresh_window
        self.request_timeout = request_timeout

        # Token cache
        self._access_token = None
        self._token_expires_at = 0

    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add a valid bearer token to the request headers.

        A new token is fetched first if none is cached or the cached one is
        about to expire.

        Args:
            request_kwargs: Request keyword-argument dict. It is mutated in place.

        Returns:
            The same dict, with the ``Authorization`` header set.
        """
        self._ensure_valid_token()

        # A present-but-None "headers" entry is treated like a missing one.
        if request_kwargs.get("headers") is None:
            request_kwargs["headers"] = {}

        request_kwargs["headers"]["Authorization"] = f"Bearer {self._access_token}"
        return request_kwargs

    def _ensure_valid_token(self) -> None:
        """Refresh the cached token if it is missing or within the refresh window of expiry."""
        current_time = time.time()

        if not self._access_token or current_time + self.token_refresh_window >= self._token_expires_at:
            self._refresh_token()

    def _refresh_token(self) -> None:
        """Request a new token from ``token_url`` and cache it with its expiry time.

        Raises:
            Any error from the HTTP request or response parsing, after logging it.
            ValueError: If the response contains no ``access_token``.
        """
        data = {
            "grant_type": self.grant_type,
            "client_id": self.client_id,
            "client_secret": self.client_secret
        }

        if self.scope:
            data["scope"] = self.scope

        # Password grant: credentials are only sent when both are provided.
        if self.grant_type == "password" and self.username and self.password:
            data["username"] = self.username
            data["password"] = self.password

        try:
            # Without a timeout an unresponsive token endpoint would block forever.
            response = requests.post(self.token_url, data=data, timeout=self.request_timeout)
            response.raise_for_status()

            token_data = response.json()
            self._access_token = token_data.get("access_token")
            expires_in = self._parse_expires_in(token_data.get("expires_in"))

            if not self._access_token:
                raise ValueError("响应中缺少access_token字段")

            # Compute the absolute expiry time.
            self._token_expires_at = time.time() + expires_in
            logger.info(f"成功获取OAuth2令牌,有效期{expires_in}秒")

        except Exception as e:
            logger.error(f"获取OAuth2令牌失败: {str(e)}")
            raise

    @classmethod
    def _parse_expires_in(cls, raw_value: Any) -> float:
        """Convert the token response's ``expires_in`` to a number of seconds.

        Numbers are returned unchanged. Numeric strings (which some servers
        send) are converted with ``float``. A missing, null or unparsable value
        falls back to ``_DEFAULT_EXPIRES_IN`` instead of breaking the expiry
        arithmetic.
        """
        if raw_value is None:
            return cls._DEFAULT_EXPIRES_IN
        if isinstance(raw_value, (int, float)):
            return raw_value
        try:
            return float(raw_value)
        except (TypeError, ValueError):
            logger.warning(
                f"Invalid expires_in value {raw_value!r}; assuming {cls._DEFAULT_EXPIRES_IN} seconds"
            )
            return cls._DEFAULT_EXPIRES_IN
|
269
|
+
|
270
|
+
|
271
|
+
class CustomAuthProvider(AuthProvider):
    """Base class for user-defined authentication providers.

    Subclass it and implement :meth:`apply_auth`. Instances registered through
    ``register_auth_provider`` can then be selected by a ``type: custom``
    configuration via its ``provider_name``.
    """

    def __init__(self):
        """Initialise the provider; this base class keeps no state of its own."""

    @abc.abstractmethod
    def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Apply custom authentication to the request arguments.

        Args:
            request_kwargs: Request keyword-argument dict.

        Returns:
            The updated request keyword-argument dict.
        """
|
291
|
+
|
292
|
+
|
293
|
+
# Registry of authentication provider *instances*, keyed by the name passed to
# register_auth_provider; looked up by get_auth_provider and by
# create_auth_provider for "custom" configs.
auth_provider_registry: Dict[str, AuthProvider] = {}
|
295
|
+
|
296
|
+
|
297
|
+
def register_auth_provider(name: str, provider_class: Type[AuthProvider], *args, **kwargs) -> None:
    """Instantiate *provider_class* and register the instance under *name*.

    Any provider already registered under the same name is replaced.

    Args:
        name: Registry key for the provider.
        provider_class: Provider class; must be a subclass of AuthProvider.
        *args: Positional arguments forwarded to the provider constructor.
        **kwargs: Keyword arguments forwarded to the provider constructor.

    Raises:
        ValueError: If provider_class is not a subclass of AuthProvider.
            (``issubclass`` itself raises TypeError if it is not a class at all.)
    """
    if issubclass(provider_class, AuthProvider):
        auth_provider_registry[name] = provider_class(*args, **kwargs)
        logger.info(f"Registered auth provider '{name}' with class {provider_class.__name__}")
        return

    raise ValueError(f"Provider class must be a subclass of AuthProvider, got {provider_class.__name__}")
|
312
|
+
|
313
|
+
|
314
|
+
def get_auth_provider(name: str) -> Optional[AuthProvider]:
    """Look up a registered authentication provider by name.

    Args:
        name: Name the provider was registered under.

    Returns:
        The registered provider instance, or None if no provider has that name.
    """
    try:
        return auth_provider_registry[name]
    except KeyError:
        return None
|
324
|
+
|
325
|
+
|
326
|
+
def create_auth_provider(auth_config: Dict[str, Any]) -> Optional[AuthProvider]:
    """Build an authentication provider from a configuration dict.

    The ``type`` key selects the provider. It is case-insensitive and
    surrounding whitespace is ignored. Supported values:
    ``basic``, ``token``, ``api_key``, ``oauth2`` and ``custom``. The ``custom``
    type returns the instance registered under ``provider_name`` in
    ``auth_provider_registry``.

    Args:
        auth_config: Authentication configuration. May be None or empty.

    Returns:
        The provider instance, or None in these cases:
        - the config is empty or has no type;
        - required fields are missing (logged as an error);
        - the type is unsupported (logged as an error).
    """
    # An absent config or a null/empty "type" means "no authentication", not a crash.
    if not auth_config:
        return None

    raw_type = auth_config.get("type")
    auth_type = str(raw_type).strip().lower() if raw_type else ""

    if not auth_type:
        return None

    if auth_type == "basic":
        username = auth_config.get("username")
        password = auth_config.get("password")

        if not username or not password:
            logger.error("基本认证配置缺少username或password参数")
            return None

        return BasicAuthProvider(username, password)

    elif auth_type == "token":
        token = auth_config.get("token")
        scheme = auth_config.get("scheme", "Bearer")
        header = auth_config.get("header", "Authorization")

        if not token:
            logger.error("令牌认证配置缺少token参数")
            return None

        return TokenAuthProvider(token, scheme, header)

    elif auth_type == "api_key":
        api_key = auth_config.get("api_key")
        key_name = auth_config.get("key_name", "X-API-Key")
        in_header = auth_config.get("in_header", True)
        in_query = auth_config.get("in_query", False)
        query_param_name = auth_config.get("query_param_name")

        if not api_key:
            logger.error("API Key认证配置缺少api_key参数")
            return None

        return ApiKeyAuthProvider(
            api_key=api_key,
            key_name=key_name,
            in_header=in_header,
            in_query=in_query,
            query_param_name=query_param_name
        )

    elif auth_type == "oauth2":
        token_url = auth_config.get("token_url")
        client_id = auth_config.get("client_id")
        client_secret = auth_config.get("client_secret")
        scope = auth_config.get("scope")
        grant_type = auth_config.get("grant_type", "client_credentials")
        username = auth_config.get("username")
        password = auth_config.get("password")
        token_refresh_window = auth_config.get("token_refresh_window", 60)

        if not token_url or not client_id or not client_secret:
            logger.error("OAuth2认证配置缺少必要参数")
            return None

        return OAuth2Provider(
            token_url, client_id, client_secret, scope, grant_type,
            username, password, token_refresh_window
        )

    elif auth_type == "custom":
        provider_name = auth_config.get("provider_name")
        if provider_name and provider_name in auth_provider_registry:
            return auth_provider_registry[provider_name]
        else:
            logger.error(f"未找到名为'{provider_name}'的自定义认证提供者")
            return None

    else:
        logger.error(f"不支持的认证类型: {auth_type}")
        return None
|
@@ -0,0 +1,181 @@
|
|
1
|
+
"""自动测试装饰器模块
|
2
|
+
|
3
|
+
该模块提供装饰器功能,用于将指定目录下的.auto文件动态添加为测试方法到被装饰的类中。
|
4
|
+
这种方式更贴合pytest的设计理念,可以充分利用pytest的fixture、参数化等功能。
|
5
|
+
"""
|
6
|
+
|
7
|
+
import os
|
8
|
+
import inspect
|
9
|
+
import functools
|
10
|
+
from pathlib import Path
|
11
|
+
import pytest
|
12
|
+
from typing import Optional, Union, List, Dict, Any, Callable, Type
|
13
|
+
|
14
|
+
from pytest_dsl.core.dsl_executor import DSLExecutor
|
15
|
+
from pytest_dsl.core.dsl_executor_utils import read_file, execute_dsl_file, extract_metadata_from_ast
|
16
|
+
from pytest_dsl.core.lexer import get_lexer
|
17
|
+
from pytest_dsl.core.parser import get_parser
|
18
|
+
from pytest_dsl.core.auto_directory import SETUP_FILE_NAME, TEARDOWN_FILE_NAME, execute_hook_file
|
19
|
+
|
20
|
+
# Module-level lexer and parser instances, created once at import time and
# reused by _add_test_method for every .auto file it parses.
lexer = get_lexer()
parser = get_parser()
|
23
|
+
|
24
|
+
|
25
|
+
def auto_dsl(directory: Union[str, Path], is_file: bool = False):
    """Class decorator that adds test methods generated from .auto DSL files.

    Relative paths are resolved against the directory of the module that calls
    ``auto_dsl``. If the caller has no ``__file__`` (e.g. a REPL), they are
    resolved against the current working directory.

    Args:
        directory: Directory containing .auto files, or a single .auto file
            when ``is_file`` is True. May be relative or absolute.
        is_file: Treat ``directory`` as a single file instead of a directory.

    Returns:
        A class decorator.

    Raises:
        ValueError: If the resolved path does not exist or is not a file
            (``is_file=True``) / directory (``is_file=False``).
    """
    path = Path(directory)
    if not path.is_absolute():
        # Resolve relative to the caller's source file.
        current_frame = inspect.currentframe()
        caller_frame = current_frame.f_back if current_frame is not None else None
        caller_file = caller_frame.f_globals.get('__file__') if caller_frame is not None else None
        # Release frame references promptly to avoid reference cycles.
        del current_frame, caller_frame
        caller_dir = Path(caller_file).parent if caller_file else Path.cwd()
        path = (caller_dir / path).resolve()

    if is_file:
        if not path.exists() or not path.is_file():
            raise ValueError(f"文件不存在或不是有效文件: {path}")
        file_path = path
    else:
        if not path.exists() or not path.is_dir():
            raise ValueError(f"目录不存在或不是有效目录: {path}")
        directory_path = path

    def decorator(cls):
        if is_file:
            # Single-file mode: add only this file's test method.
            _add_test_method(cls, file_path)
        else:
            # setup/teardown hook files live next to the test files.
            setup_file = directory_path / SETUP_FILE_NAME
            teardown_file = directory_path / TEARDOWN_FILE_NAME

            # NOTE(review): these fixtures reuse pytest's xunit hook names
            # (setup_class / teardown_class); confirm pytest does not also
            # invoke them as xunit-style hooks.
            if setup_file.exists():
                @classmethod
                @pytest.fixture(scope="class", autouse=True)
                def setup_class(cls, request):
                    execute_hook_file(setup_file, True, str(directory_path))

                setattr(cls, "setup_class", setup_class)

            if teardown_file.exists():
                @classmethod
                @pytest.fixture(scope="class", autouse=True)
                def teardown_class(cls, request):
                    request.addfinalizer(lambda: execute_hook_file(teardown_file, False, str(directory_path)))

                setattr(cls, "teardown_class", teardown_class)

            # Sort so test methods are added (and collected) in a stable order
            # regardless of the filesystem's directory-listing order.
            for auto_file in sorted(directory_path.glob("*.auto")):
                if auto_file.name not in [SETUP_FILE_NAME, TEARDOWN_FILE_NAME]:
                    _add_test_method(cls, auto_file)

        return cls

    return decorator
|
89
|
+
|
90
|
+
|
91
|
+
def _add_test_method(cls: Type, auto_file: Path) -> None:
    """Parse *auto_file* and attach a generated test method to *cls*.

    The method is named ``test_<file stem>``. If ``extract_metadata_from_ast``
    reports a data source, the file gets a parametrized data-driven test;
    otherwise it gets a plain test.

    Args:
        cls: Class that receives the test method.
        auto_file: Path of the .auto file.
    """
    ast = parser.parse(read_file(str(auto_file)), lexer=lexer)
    data_source, test_title = extract_metadata_from_ast(ast)

    generated = (
        _create_data_driven_test(auto_file, data_source, test_title)
        if data_source
        else _create_simple_test(auto_file)
    )
    setattr(cls, f"test_{auto_file.stem}", generated)
|
115
|
+
|
116
|
+
|
117
|
+
def _create_simple_test(auto_file: Path) -> Callable:
    """Build a plain (non-parametrized) test method for *auto_file*.

    Args:
        auto_file: Path of the .auto file.

    Returns:
        A function taking only ``self`` that runs the file through ``execute_dsl_file``.
    """
    dsl_path = str(auto_file)

    def test_method(self):
        execute_dsl_file(dsl_path)

    return test_method
|
131
|
+
|
132
|
+
|
133
|
+
def _create_data_driven_test(auto_file: Path, data_source: Dict, test_title: Optional[str]) -> Callable:
    """Build a data-driven test method for *auto_file*, parametrized over its data rows.

    Each run uses a fresh DSLExecutor with the row installed via
    ``set_current_data``. The rows are loaded immediately, when this function
    is called, not when the tests run.

    Args:
        auto_file: Path of the .auto file.
        data_source: Data-source description taken from the file's metadata.
        test_title: Optional title used as the id prefix; defaults to the file stem.

    Returns:
        The test method, decorated with ``pytest.mark.parametrize('test_data', ...)``.
    """
    dsl_path = str(auto_file)

    def test_method(self, test_data):
        runner = DSLExecutor()
        runner.set_current_data(test_data)
        execute_dsl_file(dsl_path, runner)

    # A throwaway executor is used only to load the rows.
    loader = DSLExecutor()
    rows = loader._load_test_data(data_source)
    ids = _generate_test_ids(rows, test_title or auto_file.stem)

    parametrize = pytest.mark.parametrize('test_data', rows, ids=ids)
    return parametrize(test_method)
|
163
|
+
|
164
|
+
|
165
|
+
def _generate_test_ids(test_data_list: List[Dict[str, Any]], base_name: str) -> List[str]:
    """Generate one readable pytest id per data-driven test row.

    Non-empty mapping rows become ``"<base_name>-k1=v1-k2=v2"``, in the row's
    key order. Rows that are empty or not mappings cannot be described that
    way. They get ``"<base_name>-<index>"`` instead of crashing on
    ``.items()`` or producing a dangling ``"<base_name>-"``.

    Args:
        test_data_list: Test data rows. None or empty yields no ids.
        base_name: Prefix for every id (test title or file stem).

    Returns:
        A list of ids, one per row, in row order.
    """
    from collections.abc import Mapping

    if not test_data_list:
        return []

    test_ids = []
    for index, data in enumerate(test_data_list):
        if isinstance(data, Mapping) and data:
            details = '-'.join(f'{key}={value}' for key, value in data.items())
            test_ids.append(f"{base_name}-{details}")
        else:
            test_ids.append(f"{base_name}-{index}")
    return test_ids
|