pytest-dsl 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,7 @@ class AuthProvider(abc.ABC):
19
19
  """认证提供者基类"""
20
20
 
21
21
  @abc.abstractmethod
22
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
22
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
23
23
  """将认证信息应用到请求参数
24
24
 
25
25
  Args:
@@ -57,6 +57,46 @@ class AuthProvider(abc.ABC):
57
57
  request_kwargs.pop('auth', None)
58
58
 
59
59
  return request_kwargs
60
+
61
+ def process_response(self, response: requests.Response) -> None:
62
+ """处理响应以更新认证状态
63
+
64
+ 此方法允许认证提供者在响应返回后处理响应数据,例如从响应中提取
65
+ CSRF令牌、刷新令牌或其他认证信息,并更新内部状态用于后续请求。
66
+
67
+ Args:
68
+ response: 请求响应对象
69
+ """
70
+ # 默认实现:不做任何处理
71
+ pass
72
+
73
+ def pre_request_hook(self, method: str, url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
74
+ """请求发送前的钩子
75
+
76
+ 此方法在请求被发送前调用,允许执行额外的请求预处理。
77
+
78
+ Args:
79
+ method: HTTP方法
80
+ url: 请求URL
81
+ request_kwargs: 请求参数字典
82
+
83
+ Returns:
84
+ 更新后的请求参数字典
85
+ """
86
+ # 默认实现:不做任何预处理
87
+ return request_kwargs
88
+
89
+ def post_response_hook(self, response: requests.Response, request_kwargs: Dict[str, Any]) -> None:
90
+ """响应接收后的钩子
91
+
92
+ 此方法在响应被接收后调用,允许执行额外的响应后处理。
93
+
94
+ Args:
95
+ response: 响应对象
96
+ request_kwargs: 原始请求参数
97
+ """
98
+ # 调用process_response以保持向后兼容
99
+ self.process_response(response)
60
100
 
61
101
  @property
62
102
  def name(self) -> str:
@@ -77,7 +117,7 @@ class BasicAuthProvider(AuthProvider):
77
117
  self.username = username
78
118
  self.password = password
79
119
 
80
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
120
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
81
121
  """应用基本认证
82
122
 
83
123
  Args:
@@ -106,7 +146,7 @@ class TokenAuthProvider(AuthProvider):
106
146
  self.scheme = scheme
107
147
  self.header = header
108
148
 
109
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
149
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
110
150
  """应用令牌认证
111
151
 
112
152
  Args:
@@ -148,7 +188,7 @@ class ApiKeyAuthProvider(AuthProvider):
148
188
  self.in_query = in_query
149
189
  self.query_param_name = query_param_name or key_name
150
190
 
151
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
191
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
152
192
  """应用API Key认证
153
193
 
154
194
  Args:
@@ -204,7 +244,7 @@ class OAuth2Provider(AuthProvider):
204
244
  self._access_token = None
205
245
  self._token_expires_at = 0
206
246
 
207
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
247
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
208
248
  """应用OAuth2认证
209
249
 
210
250
  Args:
@@ -272,13 +312,14 @@ class CustomAuthProvider(AuthProvider):
272
312
  """自定义认证提供者基类
273
313
 
274
314
  用户可以通过继承此类并实现apply_auth方法来创建自定义认证提供者。
315
+ 此外,还可以实现process_response方法来处理响应数据,例如提取CSRF令牌。
275
316
  """
276
317
  def __init__(self):
277
318
  """初始化自定义认证提供者"""
278
319
  pass
279
320
 
280
321
  @abc.abstractmethod
281
- def apply_auth(self, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
322
+ def apply_auth(self, base_url: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
282
323
  """应用自定义认证
283
324
 
284
325
  Args:
@@ -294,7 +335,7 @@ class CustomAuthProvider(AuthProvider):
294
335
  auth_provider_registry = {}
295
336
 
296
337
 
297
- def register_auth_provider(name: str, provider_class: Type[AuthProvider], *args, **kwargs) -> None:
338
+ def register_auth_provider(name: str, provider_class: Type[AuthProvider]) -> None:
298
339
  """注册认证提供者
299
340
 
300
341
  Args:
@@ -306,8 +347,7 @@ def register_auth_provider(name: str, provider_class: Type[AuthProvider], *args,
306
347
  if not issubclass(provider_class, AuthProvider):
307
348
  raise ValueError(f"Provider class must be a subclass of AuthProvider, got {provider_class.__name__}")
308
349
 
309
- provider = provider_class(*args, **kwargs)
310
- auth_provider_registry[name] = provider
350
+ auth_provider_registry[name] = provider_class
311
351
  logger.info(f"Registered auth provider '{name}' with class {provider_class.__name__}")
312
352
 
313
353
 
@@ -399,7 +439,7 @@ def create_auth_provider(auth_config: Dict[str, Any]) -> Optional[AuthProvider]:
399
439
  elif auth_type == "custom":
400
440
  provider_name = auth_config.get("provider_name")
401
441
  if provider_name and provider_name in auth_provider_registry:
402
- return auth_provider_registry[provider_name]
442
+ return auth_provider_registry[provider_name](**auth_config)
403
443
  else:
404
444
  logger.error(f"未找到名为'{provider_name}'的自定义认证提供者")
405
445
  return None
@@ -1,13 +1,10 @@
1
- import os
2
1
  import json
3
- import time
4
2
  import logging
5
- from typing import Dict, List, Any, Optional, Union, Tuple
3
+ from typing import Dict, Any
6
4
  import requests
7
- from requests.exceptions import RequestException
8
5
  from urllib.parse import urljoin
9
6
  from pytest_dsl.core.yaml_vars import yaml_vars
10
- from pytest_dsl.core.auth_provider import AuthProvider, create_auth_provider
7
+ from pytest_dsl.core.auth_provider import create_auth_provider
11
8
 
12
9
  logger = logging.getLogger(__name__)
13
10
 
@@ -125,11 +122,15 @@ class HTTPClient:
125
122
 
126
123
  elif self.auth_provider and 'auth' not in request_kwargs:
127
124
  # 应用认证提供者
128
- request_kwargs = self.auth_provider.apply_auth(request_kwargs)
125
+ request_kwargs = self.auth_provider.apply_auth(self.base_url, request_kwargs)
129
126
  # 如果使用会话,更新会话头
130
127
  if self.use_session and 'headers' in request_kwargs:
131
128
  self._session.headers.update(request_kwargs['headers'])
132
129
 
130
+ # 调用认证提供者的请求前钩子
131
+ if self.auth_provider and not disable_auth:
132
+ request_kwargs = self.auth_provider.pre_request_hook(method, url, request_kwargs)
133
+
133
134
  # 记录请求详情
134
135
  logger.debug(f"=== HTTP请求详情 ===")
135
136
  logger.debug(f"方法: {method}")
@@ -186,6 +187,10 @@ class HTTPClient:
186
187
  if not hasattr(response, 'elapsed_ms'):
187
188
  response.elapsed_ms = response.elapsed.total_seconds() * 1000
188
189
 
190
+ # 调用认证提供者的响应处理钩子
191
+ if self.auth_provider and not disable_auth:
192
+ self.auth_provider.post_response_hook(response, request_kwargs)
193
+
189
194
  return response
190
195
  except requests.exceptions.RequestException as e:
191
196
  logger.error(f"HTTP请求异常: {type(e).__name__}: {str(e)}")