pytest-dsl 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytest_dsl/core/__init__.py +7 -0
- pytest_dsl/core/auth_provider.py +50 -10
- pytest_dsl/core/custom_keyword_manager.py +213 -0
- pytest_dsl/core/dsl_executor.py +39 -2
- pytest_dsl/core/http_client.py +11 -6
- pytest_dsl/core/http_request.py +517 -119
- pytest_dsl/core/lexer.py +14 -1
- pytest_dsl/core/parser.py +45 -2
- pytest_dsl/core/parsetab.py +50 -38
- pytest_dsl/core/variable_utils.py +1 -1
- pytest_dsl/examples/custom/test_advanced_keywords.auto +31 -0
- pytest_dsl/examples/custom/test_custom_keywords.auto +37 -0
- pytest_dsl/examples/custom/test_default_values.auto +34 -0
- pytest_dsl/examples/http/http_retry_assertions.auto +2 -2
- pytest_dsl/examples/http/http_retry_assertions_enhanced.auto +2 -2
- pytest_dsl/examples/quickstart/api_basics.auto +55 -0
- pytest_dsl/examples/quickstart/assertions.auto +31 -0
- pytest_dsl/examples/quickstart/loops.auto +24 -0
- pytest_dsl/examples/test_custom_keyword.py +9 -0
- pytest_dsl/examples/test_http.py +0 -139
- pytest_dsl/examples/test_quickstart.py +14 -0
- pytest_dsl/keywords/http_keywords.py +290 -102
- pytest_dsl/parsetab.py +69 -0
- pytest_dsl-0.2.0.dist-info/METADATA +504 -0
- {pytest_dsl-0.1.0.dist-info → pytest_dsl-0.2.0.dist-info}/RECORD +29 -24
- {pytest_dsl-0.1.0.dist-info → pytest_dsl-0.2.0.dist-info}/WHEEL +1 -1
- pytest_dsl/core/custom_auth_example.py +0 -425
- pytest_dsl/examples/http/csrf_auth_test.auto +0 -64
- pytest_dsl/examples/http/custom_auth_test.auto +0 -76
- pytest_dsl/examples/http_clients.yaml +0 -48
- pytest_dsl/examples/keyword_example.py +0 -70
- pytest_dsl-0.1.0.dist-info/METADATA +0 -537
- {pytest_dsl-0.1.0.dist-info → pytest_dsl-0.2.0.dist-info}/entry_points.txt +0 -0
- {pytest_dsl-0.1.0.dist-info → pytest_dsl-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {pytest_dsl-0.1.0.dist-info → pytest_dsl-0.2.0.dist-info}/top_level.txt +0 -0
pytest_dsl/core/__init__.py
CHANGED
@@ -0,0 +1,7 @@
|
|
1
|
+
# 导入关键模块确保它们被初始化
|
2
|
+
from pytest_dsl.core.keyword_manager import keyword_manager
|
3
|
+
from pytest_dsl.core.global_context import global_context
|
4
|
+
from pytest_dsl.core.yaml_vars import yaml_vars
|
5
|
+
|
6
|
+
# 导入自定义关键字管理器
|
7
|
+
from pytest_dsl.core.custom_keyword_manager import custom_keyword_manager
|
pytest_dsl/core/auth_provider.py
CHANGED
@@ -19,7 +19,7 @@ class AuthProvider(abc.ABC):
|
|
19
19
|
"""认证提供者基类"""
|
20
20
|
|
21
21
|
@abc.abstractmethod
|
22
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
22
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
23
23
|
"""将认证信息应用到请求参数
|
24
24
|
|
25
25
|
Args:
|
@@ -57,6 +57,46 @@ class AuthProvider(abc.ABC):
|
|
57
57
|
request_kwargs.pop('auth', None)
|
58
58
|
|
59
59
|
return request_kwargs
|
60
|
+
|
61
|
+
def process_response(self, response: requests.Response) -> None:
|
62
|
+
"""处理响应以更新认证状态
|
63
|
+
|
64
|
+
此方法允许认证提供者在响应返回后处理响应数据,例如从响应中提取
|
65
|
+
CSRF令牌、刷新令牌或其他认证信息,并更新内部状态用于后续请求。
|
66
|
+
|
67
|
+
Args:
|
68
|
+
response: 请求响应对象
|
69
|
+
"""
|
70
|
+
# 默认实现:不做任何处理
|
71
|
+
pass
|
72
|
+
|
73
|
+
def pre_request_hook(self, method: str, url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
74
|
+
"""请求发送前的钩子
|
75
|
+
|
76
|
+
此方法在请求被发送前调用,允许执行额外的请求预处理。
|
77
|
+
|
78
|
+
Args:
|
79
|
+
method: HTTP方法
|
80
|
+
url: 请求URL
|
81
|
+
request_kwargs: 请求参数字典
|
82
|
+
|
83
|
+
Returns:
|
84
|
+
更新后的请求参数字典
|
85
|
+
"""
|
86
|
+
# 默认实现:不做任何预处理
|
87
|
+
return request_kwargs
|
88
|
+
|
89
|
+
def post_response_hook(self, response: requests.Response, request_kwargs: Dict[str, Any]) -> None:
|
90
|
+
"""响应接收后的钩子
|
91
|
+
|
92
|
+
此方法在响应被接收后调用,允许执行额外的响应后处理。
|
93
|
+
|
94
|
+
Args:
|
95
|
+
response: 响应对象
|
96
|
+
request_kwargs: 原始请求参数
|
97
|
+
"""
|
98
|
+
# 调用process_response以保持向后兼容
|
99
|
+
self.process_response(response)
|
60
100
|
|
61
101
|
@property
|
62
102
|
def name(self) -> str:
|
@@ -77,7 +117,7 @@ class BasicAuthProvider(AuthProvider):
|
|
77
117
|
self.username = username
|
78
118
|
self.password = password
|
79
119
|
|
80
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
120
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
81
121
|
"""应用基本认证
|
82
122
|
|
83
123
|
Args:
|
@@ -106,7 +146,7 @@ class TokenAuthProvider(AuthProvider):
|
|
106
146
|
self.scheme = scheme
|
107
147
|
self.header = header
|
108
148
|
|
109
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
149
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
110
150
|
"""应用令牌认证
|
111
151
|
|
112
152
|
Args:
|
@@ -148,7 +188,7 @@ class ApiKeyAuthProvider(AuthProvider):
|
|
148
188
|
self.in_query = in_query
|
149
189
|
self.query_param_name = query_param_name or key_name
|
150
190
|
|
151
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
191
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
152
192
|
"""应用API Key认证
|
153
193
|
|
154
194
|
Args:
|
@@ -204,7 +244,7 @@ class OAuth2Provider(AuthProvider):
|
|
204
244
|
self._access_token = None
|
205
245
|
self._token_expires_at = 0
|
206
246
|
|
207
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
247
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
208
248
|
"""应用OAuth2认证
|
209
249
|
|
210
250
|
Args:
|
@@ -272,13 +312,14 @@ class CustomAuthProvider(AuthProvider):
|
|
272
312
|
"""自定义认证提供者基类
|
273
313
|
|
274
314
|
用户可以通过继承此类并实现apply_auth方法来创建自定义认证提供者。
|
315
|
+
此外,还可以实现process_response方法来处理响应数据,例如提取CSRF令牌。
|
275
316
|
"""
|
276
317
|
def __init__(self):
|
277
318
|
"""初始化自定义认证提供者"""
|
278
319
|
pass
|
279
320
|
|
280
321
|
@abc.abstractmethod
|
281
|
-
def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
322
|
+
def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
282
323
|
"""应用自定义认证
|
283
324
|
|
284
325
|
Args:
|
@@ -294,7 +335,7 @@ class CustomAuthProvider(AuthProvider):
|
|
294
335
|
auth_provider_registry = {}
|
295
336
|
|
296
337
|
|
297
|
-
def register_auth_provider(name: str, provider_class: Type[AuthProvider]
|
338
|
+
def register_auth_provider(name: str, provider_class: Type[AuthProvider]) -> None:
|
298
339
|
"""注册认证提供者
|
299
340
|
|
300
341
|
Args:
|
@@ -306,8 +347,7 @@ def register_auth_provider(name: str, provider_class: Type[AuthProvider], *args,
|
|
306
347
|
if not issubclass(provider_class, AuthProvider):
|
307
348
|
raise ValueError(f"Provider class must be a subclass of AuthProvider, got {provider_class.__name__}")
|
308
349
|
|
309
|
-
|
310
|
-
auth_provider_registry[name] = provider
|
350
|
+
auth_provider_registry[name] = provider_class
|
311
351
|
logger.info(f"Registered auth provider '{name}' with class {provider_class.__name__}")
|
312
352
|
|
313
353
|
|
@@ -399,7 +439,7 @@ def create_auth_provider(auth_config: Dict[str, Any]) -> Optional[AuthProvider]:
|
|
399
439
|
elif auth_type == "custom":
|
400
440
|
provider_name = auth_config.get("provider_name")
|
401
441
|
if provider_name and provider_name in auth_provider_registry:
|
402
|
-
return auth_provider_registry[provider_name]
|
442
|
+
return auth_provider_registry[provider_name](**auth_config)
|
403
443
|
else:
|
404
444
|
logger.error(f"未找到名为'{provider_name}'的自定义认证提供者")
|
405
445
|
return None
|
@@ -0,0 +1,213 @@
|
|
1
|
+
from typing import Dict, Any, List, Optional
|
2
|
+
import os
|
3
|
+
import pathlib
|
4
|
+
from pytest_dsl.core.lexer import get_lexer
|
5
|
+
from pytest_dsl.core.parser import get_parser, Node
|
6
|
+
from pytest_dsl.core.dsl_executor import DSLExecutor
|
7
|
+
from pytest_dsl.core.keyword_manager import keyword_manager
|
8
|
+
from pytest_dsl.core.context import TestContext
|
9
|
+
|
10
|
+
|
11
|
+
class CustomKeywordManager:
|
12
|
+
"""自定义关键字管理器
|
13
|
+
|
14
|
+
负责加载和注册自定义关键字
|
15
|
+
"""
|
16
|
+
|
17
|
+
def __init__(self):
|
18
|
+
"""初始化自定义关键字管理器"""
|
19
|
+
self.resource_cache = {} # 缓存已加载的资源文件
|
20
|
+
self.resource_paths = [] # 资源文件搜索路径
|
21
|
+
|
22
|
+
def add_resource_path(self, path: str) -> None:
|
23
|
+
"""添加资源文件搜索路径
|
24
|
+
|
25
|
+
Args:
|
26
|
+
path: 资源文件路径
|
27
|
+
"""
|
28
|
+
if path not in self.resource_paths:
|
29
|
+
self.resource_paths.append(path)
|
30
|
+
|
31
|
+
def load_resource_file(self, file_path: str) -> None:
|
32
|
+
"""加载资源文件
|
33
|
+
|
34
|
+
Args:
|
35
|
+
file_path: 资源文件路径
|
36
|
+
"""
|
37
|
+
# 规范化路径,解决路径叠加的问题
|
38
|
+
file_path = os.path.normpath(file_path)
|
39
|
+
|
40
|
+
# 如果已经缓存,则跳过
|
41
|
+
absolute_path = os.path.abspath(file_path)
|
42
|
+
if absolute_path in self.resource_cache:
|
43
|
+
return
|
44
|
+
|
45
|
+
# 读取文件内容
|
46
|
+
if not os.path.exists(file_path):
|
47
|
+
# 尝试在资源路径中查找
|
48
|
+
for resource_path in self.resource_paths:
|
49
|
+
full_path = os.path.join(resource_path, file_path)
|
50
|
+
if os.path.exists(full_path):
|
51
|
+
file_path = full_path
|
52
|
+
absolute_path = os.path.abspath(file_path)
|
53
|
+
break
|
54
|
+
else:
|
55
|
+
# 如果文件不存在,尝试在根项目目录中查找
|
56
|
+
# 一般情况下文件路径可能是相对于项目根目录的
|
57
|
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
58
|
+
full_path = os.path.join(project_root, file_path)
|
59
|
+
if os.path.exists(full_path):
|
60
|
+
file_path = full_path
|
61
|
+
absolute_path = os.path.abspath(file_path)
|
62
|
+
else:
|
63
|
+
raise FileNotFoundError(f"资源文件不存在: {file_path}")
|
64
|
+
|
65
|
+
try:
|
66
|
+
with open(file_path, 'r', encoding='utf-8') as f:
|
67
|
+
content = f.read()
|
68
|
+
|
69
|
+
# 解析资源文件
|
70
|
+
lexer = get_lexer()
|
71
|
+
parser = get_parser()
|
72
|
+
ast = parser.parse(content, lexer=lexer)
|
73
|
+
|
74
|
+
# 标记为已加载
|
75
|
+
self.resource_cache[absolute_path] = True
|
76
|
+
|
77
|
+
# 处理导入指令
|
78
|
+
self._process_imports(ast, os.path.dirname(file_path))
|
79
|
+
|
80
|
+
# 注册关键字
|
81
|
+
self._register_keywords(ast, file_path)
|
82
|
+
except Exception as e:
|
83
|
+
print(f"资源文件 {file_path} 加载失败: {str(e)}")
|
84
|
+
raise
|
85
|
+
|
86
|
+
def _process_imports(self, ast: Node, base_dir: str) -> None:
|
87
|
+
"""处理资源文件中的导入指令
|
88
|
+
|
89
|
+
Args:
|
90
|
+
ast: 抽象语法树
|
91
|
+
base_dir: 基础目录
|
92
|
+
"""
|
93
|
+
if ast.type != 'Start' or not ast.children:
|
94
|
+
return
|
95
|
+
|
96
|
+
metadata_node = ast.children[0]
|
97
|
+
if metadata_node.type != 'Metadata':
|
98
|
+
return
|
99
|
+
|
100
|
+
for item in metadata_node.children:
|
101
|
+
if item.type == '@import':
|
102
|
+
imported_file = item.value
|
103
|
+
# 处理相对路径
|
104
|
+
if not os.path.isabs(imported_file):
|
105
|
+
imported_file = os.path.join(base_dir, imported_file)
|
106
|
+
|
107
|
+
# 规范化路径,避免路径叠加问题
|
108
|
+
imported_file = os.path.normpath(imported_file)
|
109
|
+
|
110
|
+
# 递归加载导入的资源文件
|
111
|
+
self.load_resource_file(imported_file)
|
112
|
+
|
113
|
+
def _register_keywords(self, ast: Node, file_path: str) -> None:
|
114
|
+
"""从AST中注册关键字
|
115
|
+
|
116
|
+
Args:
|
117
|
+
ast: 抽象语法树
|
118
|
+
file_path: 文件路径
|
119
|
+
"""
|
120
|
+
if ast.type != 'Start' or len(ast.children) < 2:
|
121
|
+
return
|
122
|
+
|
123
|
+
# 遍历语句节点
|
124
|
+
statements_node = ast.children[1]
|
125
|
+
if statements_node.type != 'Statements':
|
126
|
+
return
|
127
|
+
|
128
|
+
for node in statements_node.children:
|
129
|
+
if node.type == 'CustomKeyword':
|
130
|
+
self._register_custom_keyword(node, file_path)
|
131
|
+
|
132
|
+
def _register_custom_keyword(self, node: Node, file_path: str) -> None:
|
133
|
+
"""注册自定义关键字
|
134
|
+
|
135
|
+
Args:
|
136
|
+
node: 关键字节点
|
137
|
+
file_path: 资源文件路径
|
138
|
+
"""
|
139
|
+
# 提取关键字信息
|
140
|
+
keyword_name = node.value
|
141
|
+
params_node = node.children[0]
|
142
|
+
body_node = node.children[1]
|
143
|
+
|
144
|
+
# 构建参数列表
|
145
|
+
parameters = []
|
146
|
+
param_mapping = {}
|
147
|
+
param_defaults = {} # 存储参数默认值
|
148
|
+
|
149
|
+
for param in params_node if params_node else []:
|
150
|
+
param_name = param.value
|
151
|
+
param_default = None
|
152
|
+
|
153
|
+
# 检查是否有默认值
|
154
|
+
if param.children and param.children[0]:
|
155
|
+
param_default = param.children[0].value
|
156
|
+
param_defaults[param_name] = param_default # 保存默认值
|
157
|
+
|
158
|
+
# 添加参数定义
|
159
|
+
parameters.append({
|
160
|
+
'name': param_name,
|
161
|
+
'mapping': param_name, # 中文参数名和内部参数名相同
|
162
|
+
'description': f'自定义关键字参数 {param_name}'
|
163
|
+
})
|
164
|
+
|
165
|
+
param_mapping[param_name] = param_name
|
166
|
+
|
167
|
+
# 注册自定义关键字到关键字管理器
|
168
|
+
@keyword_manager.register(keyword_name, parameters)
|
169
|
+
def custom_keyword_executor(**kwargs):
|
170
|
+
"""自定义关键字执行器"""
|
171
|
+
# 创建一个新的DSL执行器
|
172
|
+
executor = DSLExecutor()
|
173
|
+
|
174
|
+
# 获取传递的上下文
|
175
|
+
context = kwargs.get('context')
|
176
|
+
if context:
|
177
|
+
executor.test_context = context
|
178
|
+
|
179
|
+
# 先应用默认值
|
180
|
+
for param_name, default_value in param_defaults.items():
|
181
|
+
executor.variables[param_name] = default_value
|
182
|
+
executor.test_context.set(param_name, default_value)
|
183
|
+
|
184
|
+
# 然后应用传入的参数值(覆盖默认值)
|
185
|
+
for param_name, param_mapping_name in param_mapping.items():
|
186
|
+
if param_mapping_name in kwargs:
|
187
|
+
# 确保参数值在标准变量和测试上下文中都可用
|
188
|
+
executor.variables[param_name] = kwargs[param_mapping_name]
|
189
|
+
executor.test_context.set(param_name, kwargs[param_mapping_name])
|
190
|
+
|
191
|
+
# 执行关键字体中的语句
|
192
|
+
result = None
|
193
|
+
try:
|
194
|
+
for stmt in body_node.children:
|
195
|
+
# 检查是否是return语句
|
196
|
+
if stmt.type == 'Return':
|
197
|
+
# 对表达式求值
|
198
|
+
result = executor.eval_expression(stmt.children[0])
|
199
|
+
break
|
200
|
+
else:
|
201
|
+
# 执行普通语句
|
202
|
+
executor.execute(stmt)
|
203
|
+
except Exception as e:
|
204
|
+
print(f"执行自定义关键字 {keyword_name} 时发生错误: {str(e)}")
|
205
|
+
raise
|
206
|
+
|
207
|
+
return result
|
208
|
+
|
209
|
+
print(f"已注册自定义关键字: {keyword_name} 来自文件: {file_path}")
|
210
|
+
|
211
|
+
|
212
|
+
# 创建全局自定义关键字管理器实例
|
213
|
+
custom_keyword_manager = CustomKeywordManager()
|
pytest_dsl/core/dsl_executor.py
CHANGED
@@ -26,6 +26,7 @@ class DSLExecutor:
|
|
26
26
|
self.test_context = TestContext()
|
27
27
|
self.test_context.executor = self # 让 test_context 能够访问到 executor
|
28
28
|
self.variable_replacer = VariableReplacer(self.variables, self.test_context)
|
29
|
+
self.imported_files = set() # 跟踪已导入的文件,避免循环导入
|
29
30
|
|
30
31
|
def set_current_data(self, data):
|
31
32
|
"""设置当前测试数据集"""
|
@@ -101,7 +102,7 @@ class DSLExecutor:
|
|
101
102
|
return self.eval_expression(value)
|
102
103
|
elif isinstance(value, str):
|
103
104
|
# 定义变量引用模式
|
104
|
-
pattern = r'\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}'
|
105
|
+
pattern = r'\$\{([a-zA-Z_\u4e00-\u9fa5][a-zA-Z0-9_\u4e00-\u9fa5]*)\}'
|
105
106
|
# 检查整个字符串是否完全匹配单一变量引用模式
|
106
107
|
match = re.fullmatch(pattern, value)
|
107
108
|
if match:
|
@@ -133,6 +134,9 @@ class DSLExecutor:
|
|
133
134
|
if child.type == 'Metadata':
|
134
135
|
for item in child.children:
|
135
136
|
metadata[item.type] = item.value
|
137
|
+
# 处理导入指令
|
138
|
+
if item.type == '@import':
|
139
|
+
self._handle_import(item.value)
|
136
140
|
elif child.type == 'Teardown':
|
137
141
|
teardown_node = child
|
138
142
|
|
@@ -154,6 +158,25 @@ class DSLExecutor:
|
|
154
158
|
# 测试用例执行完成后清空上下文
|
155
159
|
self.test_context.clear()
|
156
160
|
|
161
|
+
def _handle_import(self, file_path):
|
162
|
+
"""处理导入指令
|
163
|
+
|
164
|
+
Args:
|
165
|
+
file_path: 资源文件路径
|
166
|
+
"""
|
167
|
+
# 防止循环导入
|
168
|
+
if file_path in self.imported_files:
|
169
|
+
return
|
170
|
+
|
171
|
+
try:
|
172
|
+
# 导入自定义关键字文件
|
173
|
+
from pytest_dsl.core.custom_keyword_manager import custom_keyword_manager
|
174
|
+
custom_keyword_manager.load_resource_file(file_path)
|
175
|
+
self.imported_files.add(file_path)
|
176
|
+
except Exception as e:
|
177
|
+
print(f"导入资源文件失败: {file_path}, 错误: {str(e)}")
|
178
|
+
raise
|
179
|
+
|
157
180
|
def _execute_test_iteration(self, metadata, node, teardown_node):
|
158
181
|
"""执行测试迭代"""
|
159
182
|
try:
|
@@ -305,6 +328,19 @@ class DSLExecutor:
|
|
305
328
|
"""处理清理操作"""
|
306
329
|
self.execute(node.children[0])
|
307
330
|
|
331
|
+
@allure.step("执行返回语句")
|
332
|
+
def _handle_return(self, node):
|
333
|
+
"""处理return语句
|
334
|
+
|
335
|
+
Args:
|
336
|
+
node: Return节点
|
337
|
+
|
338
|
+
Returns:
|
339
|
+
表达式求值结果
|
340
|
+
"""
|
341
|
+
expr_node = node.children[0]
|
342
|
+
return self.eval_expression(expr_node)
|
343
|
+
|
308
344
|
def execute(self, node):
|
309
345
|
"""执行AST节点"""
|
310
346
|
handlers = {
|
@@ -315,7 +351,8 @@ class DSLExecutor:
|
|
315
351
|
'AssignmentKeywordCall': self._handle_assignment_keyword_call,
|
316
352
|
'ForLoop': self._handle_for_loop,
|
317
353
|
'KeywordCall': self._execute_keyword_call,
|
318
|
-
'Teardown': self._handle_teardown
|
354
|
+
'Teardown': self._handle_teardown,
|
355
|
+
'Return': self._handle_return
|
319
356
|
}
|
320
357
|
|
321
358
|
handler = handlers.get(node.type)
|
pytest_dsl/core/http_client.py
CHANGED
@@ -1,13 +1,10 @@
|
|
1
|
-
import os
|
2
1
|
import json
|
3
|
-
import time
|
4
2
|
import logging
|
5
|
-
from typing import Dict,
|
3
|
+
from typing import Dict, Any
|
6
4
|
import requests
|
7
|
-
from requests.exceptions import RequestException
|
8
5
|
from urllib.parse import urljoin
|
9
6
|
from pytest_dsl.core.yaml_vars import yaml_vars
|
10
|
-
from pytest_dsl.core.auth_provider import
|
7
|
+
from pytest_dsl.core.auth_provider import create_auth_provider
|
11
8
|
|
12
9
|
logger = logging.getLogger(__name__)
|
13
10
|
|
@@ -125,11 +122,15 @@ class HTTPClient:
|
|
125
122
|
|
126
123
|
elif self.auth_provider and 'auth' not in request_kwargs:
|
127
124
|
# 应用认证提供者
|
128
|
-
request_kwargs = self.auth_provider.apply_auth(request_kwargs)
|
125
|
+
request_kwargs = self.auth_provider.apply_auth(self.base_url, request_kwargs)
|
129
126
|
# 如果使用会话,更新会话头
|
130
127
|
if self.use_session and 'headers' in request_kwargs:
|
131
128
|
self._session.headers.update(request_kwargs['headers'])
|
132
129
|
|
130
|
+
# 调用认证提供者的请求前钩子
|
131
|
+
if self.auth_provider and not disable_auth:
|
132
|
+
request_kwargs = self.auth_provider.pre_request_hook(method, url, request_kwargs)
|
133
|
+
|
133
134
|
# 记录请求详情
|
134
135
|
logger.debug(f"=== HTTP请求详情 ===")
|
135
136
|
logger.debug(f"方法: {method}")
|
@@ -186,6 +187,10 @@ class HTTPClient:
|
|
186
187
|
if not hasattr(response, 'elapsed_ms'):
|
187
188
|
response.elapsed_ms = response.elapsed.total_seconds() * 1000
|
188
189
|
|
190
|
+
# 调用认证提供者的响应处理钩子
|
191
|
+
if self.auth_provider and not disable_auth:
|
192
|
+
self.auth_provider.post_response_hook(response, request_kwargs)
|
193
|
+
|
189
194
|
return response
|
190
195
|
except requests.exceptions.RequestException as e:
|
191
196
|
logger.error(f"HTTP请求异常: {type(e).__name__}: {str(e)}")
|