aiointercept 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiointercept/__init__.py +3 -0
- aiointercept/compat.py +19 -0
- aiointercept/core.py +563 -0
- aiointercept-0.1.0.dist-info/METADATA +70 -0
- aiointercept-0.1.0.dist-info/RECORD +6 -0
- aiointercept-0.1.0.dist-info/WHEEL +4 -0
aiointercept/__init__.py
ADDED
aiointercept/compat.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from multidict import MultiDict
|
|
2
|
+
from yarl import URL
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def normalize_url(url: "URL | str") -> URL:
    """Return a canonical form of *url* suitable for equality comparisons.

    The fragment (anything after ``#``) is removed and the query parameters
    are sorted, so URLs that differ only in fragment or parameter order
    compare equal.
    """
    parsed = URL(url)
    if parsed.fragment:
        parsed = parsed.with_fragment(None)
    ordered_query = sorted(parsed.query.items())
    return parsed.with_query(ordered_query)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def merge_params(url: "URL | str", params: "dict[str, str] | None" = None) -> URL:
    """Return *url* with *params* appended to its existing query string.

    Query parameters already present on *url* are kept. The entries from
    *params* are added after them; duplicates are preserved, not replaced.
    When *params* is empty or None, the URL is returned unchanged.
    """
    base = URL(url)
    if not params:
        return base
    # Let yarl encode ``params`` through ``with_query``, then append the
    # resulting pairs to the query already present on the URL.
    extra = base.with_query(params).query
    combined = MultiDict(base.query)
    combined.extend(extra)
    return base.with_query(combined)
|
aiointercept/core.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
import socket
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from re import Pattern
|
|
4
|
+
import typing
|
|
5
|
+
from unittest.mock import patch
|
|
6
|
+
import json as json_module
|
|
7
|
+
import inspect
|
|
8
|
+
import warnings
|
|
9
|
+
import gc
|
|
10
|
+
from urllib.parse import parse_qs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
import aiohttp
|
|
14
|
+
from aiohttp import ClientRequest, ClientResponse, web, hdrs
|
|
15
|
+
from aiohttp.abc import AbstractResolver, ResolveResult
|
|
16
|
+
from aiohttp.connector import SSLContext, TCPConnector
|
|
17
|
+
from aiohttp.formdata import FormData
|
|
18
|
+
from aiohttp.resolver import ThreadedResolver, AsyncResolver
|
|
19
|
+
from aiohttp.test_utils import TestServer
|
|
20
|
+
from aiohttp.web_request import Request
|
|
21
|
+
from yarl import URL
|
|
22
|
+
from typing import Any, Awaitable, Callable, Type
|
|
23
|
+
from .compat import merge_params, normalize_url
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CallbackResult:
    """Response description returned by a ``callback`` registered with ``aiointercept.add``.

    Each constructor argument is stored as an attribute of the same name.
    aiointercept's handler reads ``status``, ``body``, ``headers``,
    ``payload`` (JSON-encoded and used instead of ``body`` when not None),
    ``content_type`` and ``reason``. ``method`` and ``response_class`` are
    stored but not consulted by that handler.
    """

    def __init__(
        self,
        method: str = hdrs.METH_GET,
        status: int = 200,
        body: str | bytes = "",
        content_type: str = "application/json",
        payload: dict[str, str] | None = None,
        headers: dict[str, str] | None = None,
        response_class: Type[ClientResponse] | None = None,
        reason: str | None = None,
    ):
        # Assigned in the same order as the parameters are declared.
        self.method, self.status, self.body = method, status, body
        self.content_type, self.payload = content_type, payload
        self.headers, self.response_class = headers, response_class
        self.reason = reason
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Signature of the per-route handlers stored by aiointercept: an async
# callable that receives the incoming ``web.Request`` and returns the
# ``web.Response`` to send back to the client.
handler_type = Callable[[web.Request], Awaitable[web.Response]]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class aiointercept:
    """
    Mock aiohttp requests by redirecting DNS to a local aiohttp.web test server.

    While active (``async with aiointercept() as m:`` or as a decorator):

    * ``resolve`` is patched on ``ThreadedResolver`` and ``AsyncResolver``, so
      mocked host names resolve to a local ``TestServer``;
    * ``TCPConnector._get_ssl_context`` is patched, so ``https://`` requests to
      mocked hosts are sent without TLS to that plain-HTTP server;
    * every request reaching the server is recorded in ``self.requests`` and
      answered by the handler registered with :meth:`add` (or the
      ``get``/``post``/... shortcuts).

    Because interception happens at DNS level, requests addressed directly to
    an IP address are not intercepted.
    """

    def __init__(
        self,
        passthrough: list[str] | None = None,
        passthrough_unmatched: bool = False,
        param: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create an interceptor.

        Args:
            passthrough: URLs (or bare host names) whose host keeps resolving
                through the real resolver. Only the host part is used.
            passthrough_unmatched: when True, hosts that are neither mocked nor
                covered by a registered regex pattern are resolved normally
                instead of being sent to the local test server.
            param: when used as a decorator, inject the interceptor as this
                keyword argument instead of positionally.
            **kwargs: deprecated; passing anything emits a DeprecationWarning
                and is otherwise ignored.
        """
        if kwargs:
            warnings.warn(
                "Passing extra parameters to aiointercept via kwargs is deprecated and will be removed in a future release.",
                DeprecationWarning,
                stacklevel=2,
            )

        self._passthrough_urls = passthrough or []
        self._passthrough_hosts: list[str] = []

        for p in self._passthrough_urls:
            # Accept full URLs as well as bare host names: fall back to the
            # raw string when yarl cannot extract a host from it.
            try:
                host = URL(p).host
                self._passthrough_hosts.append(host if host else p)
            except Exception:
                self._passthrough_hosts.append(p)

        self.param = param
        self.passthrough_unmatched = passthrough_unmatched

        # Host names registered through concrete URLs, and regex patterns.
        self._host_list: list[str] = []
        self._patterns_list: list[Pattern[str]] = []

        # handlers map (path, method) to a single handler (repeat=True) or to
        # a list of one-shot handlers consumed in order (otherwise).
        self.handlers: dict[tuple[str, str], handler_type | list[handler_type]] = {}
        # patterns_handler maps (pattern, method) to a handler or a list of
        # handlers, following the same rules.
        self.patterns_handler: dict[
            tuple[Pattern[str], str], handler_type | list[handler_type]
        ] = {}

        # recorded requests: {(METHOD, URL): [web.Request, ...]}
        self.requests: dict[tuple[str, URL], list[Request]] = {}

        self.server: TestServer | None = None
        self._patchers: list[Any] = []

    async def __aenter__(self) -> "aiointercept":
        """Start the local test server and install the resolver/SSL patches."""
        app = web.Application()
        # A single catch-all route; _dispatch selects the registered handler.
        app.router.add_route("*", "/{tail:.*}", self._dispatch)
        self.server = TestServer(app)
        await self.server.start_server()

        assert isinstance(self.server.host, str) and isinstance(self.server.port, int)  # pyright: ignore[reportUnknownMemberType]
        self.server_host = self.server.host
        self.server_port = self.server.port

        # Patch resolve() on BOTH resolver classes at the class level.
        # This affects every existing and future instance automatically,
        # including instances of subclasses that do not override resolve().
        self._originals_resolver: dict[
            Type[AbstractResolver],
            Callable[[Any, str, int, Any], Awaitable[list[ResolveResult]]],
        ] = {}

        for resolver_cls in (ThreadedResolver, AsyncResolver):
            # Capture the original class method
            original_resolve = resolver_cls.resolve
            self._originals_resolver[resolver_cls] = original_resolve

            # The closure captures this aiointercept instance as ``self``
            # while receiving the resolver instance as ``resolver_self``.
            async def mock_resolve(
                resolver_self: AbstractResolver,
                host: str,
                port: int = 0,
                family: socket.AddressFamily = socket.AF_INET,
            ) -> list[ResolveResult]:
                return await self._fake_resolve(resolver_self, host, port, family)

            p = patch.object(resolver_cls, "resolve", mock_resolve)
            p.start()
            self._patchers.append(p)

        # Patch _get_ssl_context so that https:// requests to mocked hosts
        # connect to our plain-HTTP TestServer without TLS.
        original_get_ssl_context = TCPConnector._get_ssl_context  # pyright: ignore[reportPrivateUsage]
        self._original_ssl_context = original_get_ssl_context

        def mock_get_ssl_context(
            connector_self: TCPConnector, req: ClientRequest
        ) -> SSLContext | None:
            return self._fake_ssl_context(connector_self, req)

        p_ssl = patch.object(TCPConnector, "_get_ssl_context", mock_get_ssl_context)
        p_ssl.start()
        self._patchers.append(p_ssl)

        # Clear the DNS cache on every open connector so cached entries
        # from before our patch was applied cannot bypass us.
        self._clear_all_connector_caches()

        return self

    async def __aexit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Remove the patches, stop the test server and drop registered mocks.

        Recorded ``requests`` are not cleared here.
        """
        for p in self._patchers:
            p.stop()
        self._patchers.clear()
        if self.server:
            await self.server.close()
            self.server = None
        self._host_list.clear()
        self._patterns_list.clear()
        self.handlers.clear()
        # Pattern handlers must go too, or they would leak into a later use
        # of this same instance (e.g. a decorated test that runs twice).
        self.patterns_handler.clear()

    # Decorator support
    def __call__(
        self, f: Callable[..., Awaitable[Any]]
    ) -> Callable[..., Awaitable[Any]]:
        """Wrap async function *f* so it runs inside this interceptor.

        The interceptor is passed as keyword ``self.param`` when set.
        Otherwise it is inserted after the first argument when that argument
        has an attribute named like *f* (a bound-method heuristic), or it is
        appended to the positional arguments.
        """

        @wraps(f)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            async with self as m:
                if self.param:
                    kwargs[self.param] = m
                else:
                    if args and hasattr(args[0], f.__name__):
                        args = (args[0], m) + args[1:]
                    else:
                        args = args + (m,)
                return await f(*args, **kwargs)

        return wrapper

    def _fake_ssl_context(
        self, connector_self: TCPConnector, req: ClientRequest
    ) -> SSLContext | None:
        """Return None (no TLS) for intercepted requests, the real SSL context otherwise.

        Uses the same priority as ``_fake_resolve``: hosts mocked with a
        concrete URL are always intercepted, and explicit passthrough hosts
        keep real TLS even when a registered pattern matches their URL.
        """
        host = req.url.raw_host
        if host in self._host_list:
            # Our TestServer is plain HTTP — disable TLS for mocked hosts.
            return None
        if host not in self._passthrough_hosts and self._match_pattern(str(req.url)):
            return None
        # For unmocked hosts, use the original method to get the correct context.
        return self._original_ssl_context(connector_self, req)

    def _match_pattern(self, url: str) -> bool:
        """Return True when any registered regex pattern matches *url* (``re.match``)."""
        return any(pattern.match(url) for pattern in self._patterns_list)

    def _local_resolution(
        self, host: str, family: socket.AddressFamily
    ) -> list[ResolveResult]:
        """Resolve *host* to the address of the local test server."""
        return [
            ResolveResult(
                hostname=host,
                host=self.server_host,
                port=self.server_port,
                family=family,
                proto=0,
                flags=0,
            )
        ]

    async def _real_resolve(
        self,
        resolver_self: AbstractResolver,
        host: str,
        port: int,
        family: socket.AddressFamily,
    ) -> list[ResolveResult]:
        """Delegate to the saved, unpatched ``resolve`` for *resolver_self*.

        The MRO is walked so that instances of subclasses of
        ThreadedResolver/AsyncResolver (which inherit the patched method)
        find the saved original instead of failing with KeyError.
        """
        for klass in type(resolver_self).__mro__:
            original = self._originals_resolver.get(klass)
            if original is not None:
                return await original(resolver_self, host, port, family)
        raise KeyError(type(resolver_self))

    async def _fake_resolve(
        self,
        resolver_self: AbstractResolver,
        host: str,
        port: int = 0,
        family: socket.AddressFamily = socket.AF_INET,
    ) -> list[ResolveResult]:
        """Replacement for resolver.resolve() on both resolver classes.

        Decision order:
          1. hosts mocked with a concrete URL -> local test server;
          2. explicit passthrough hosts -> real resolver;
          3. any registered pattern -> local test server (only the host is
             known here, not the full URL the pattern needs);
          4. passthrough_unmatched -> real resolver;
          5. otherwise -> local test server. Unless some handler matches
             there, the connection is closed, which is intended to surface
             as a ClientConnectionError on the client side.
        """
        if host in self._host_list:
            return self._local_resolution(host, family)
        if host in self._passthrough_hosts:
            return await self._real_resolve(resolver_self, host, port, family)
        if self._patterns_list:
            return self._local_resolution(host, family)
        if self.passthrough_unmatched:
            return await self._real_resolve(resolver_self, host, port, family)
        return self._local_resolution(host, family)

    @staticmethod
    def _clear_all_connector_caches() -> None:
        """
        Walk every TCPConnector instance tracked by the garbage collector and
        clear its DNS cache, so resolutions made before patching are not reused.

        Best effort: errors raised by an individual connector are ignored.
        """
        for obj in gc.get_objects():
            # type() rather than isinstance(): avoids consulting a possibly
            # overridden __class__ on arbitrary objects in the heap.
            if not issubclass(type(obj), aiohttp.TCPConnector):
                continue
            try:
                obj.clear_dns_cache()
            except Exception:
                pass

    @staticmethod
    def _set_request_attr(request: web.Request, name: str, value: Any) -> None:
        """Attach a custom attribute to an aiohttp request object.

        NOTE(review): aiohttp's request ``__setattr__`` is believed to emit a
        DeprecationWarning for custom attributes; it is silenced so suites
        running with ``-W error`` are not broken. Confirm against the pinned
        aiohttp version.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            setattr(request, name, value)

    async def _dispatch(self, request: web.Request) -> web.Response:
        """Record *request* and answer it with the matching registered handler.

        Exact (path, method) handlers are tried first. Then regex patterns are
        matched against the reconstructed original URL; both https:// and
        http:// are tried because the plain-HTTP server cannot see the scheme
        the client used. If nothing matches, the connection is closed.
        """
        key = (request.method.upper(), normalize_url(request.url))
        # Read the body eagerly, before the handler runs, because aiohttp sets
        # PayloadAccessError on the stream once the response cycle completes;
        # assert_called_with() needs the body afterwards.
        captured_body = await request.read() if request.can_read_body else b""
        self._set_request_attr(request, "_captured_body", captured_body)
        try:
            json = json_module.loads(captured_body) if captured_body else None
        except Exception:
            json = None
        # this kwargs will be removed, should be deprecated in the future
        self._set_request_attr(
            request,
            "kwargs",
            {
                "headers": request.headers,
                "query": dict(request.query),
                "json": json,
            },
        )
        self.requests.setdefault(key, []).append(request)

        selected_handler = self.handlers.get((request.path, request.method))
        handler: handler_type | None
        if isinstance(selected_handler, list):
            # One-shot handlers are consumed in registration order.
            handler = selected_handler.pop(0) if selected_handler else None
        else:
            handler = selected_handler

        if handler is None:
            # Reconstruct the original URL from the Host header the client
            # sent; the scheme is unknown here, so try both.
            original_host = request.headers.get("Host", request.url.host)
            original_urls = [
                f"https://{original_host}{request.path_qs}",
                f"http://{original_host}{request.path_qs}",
            ]
            # Check if there's a pattern handler for this request
            for (pattern, method), pattern_handler in self.patterns_handler.items():
                if (
                    any(pattern.match(u) for u in original_urls)
                    and method == request.method
                ):
                    if isinstance(pattern_handler, list):
                        handler = pattern_handler[0]
                        remaining = pattern_handler[1:]
                        if remaining:
                            self.patterns_handler[pattern, request.method] = remaining
                        else:
                            del self.patterns_handler[pattern, request.method]
                    else:
                        handler = pattern_handler
                    # Break right away: the dict was possibly mutated above.
                    break

        if handler is None:
            # this should raise ClientConnectionError on the other side
            if request.transport:
                request.transport.close()
            return web.Response(
                status=502, text="No handler registered for this request."
            )
        return await handler(request)

    def add(
        self,
        url: URL | str | Pattern[str],
        method: str = hdrs.METH_GET,
        status: int = 200,
        body: str | bytes = b"",
        json: Any = None,
        payload: dict | None = None,
        headers: dict | None = None,
        repeat: bool | int = False,
        content_type: str | None = None,
        callback: Callable[[URL | Pattern[str]], CallbackResult] | None = None,
        reason: str | None = None,
        exception: Exception | None = None,
        **kwargs,
    ) -> None:
        """Register a mocked response for *url* and *method*.

        Args:
            url: absolute URL (str or ``yarl.URL``) or a compiled regex. A
                concrete URL intercepts its host and is dispatched by path;
                a pattern is matched against the full original URL.
            method: HTTP method (case-insensitive).
            status, reason: response status code and reason phrase.
            body: raw response body; ``str`` is encoded as UTF-8.
            json, payload: objects JSON-encoded as the body. ``json`` takes
                precedence over ``payload``, and both take precedence over
                ``body``.
            headers: response headers.
            repeat: ``True`` serves every matching request and replaces any
                previous entry for the key. ``False`` serves one request; an
                ``int`` serves that many. Non-True values queue after handlers
                already registered for the same key.
            content_type: response content type. Defaults to
                ``application/json`` when there is a body and no Content-Type
                header (any letter case) was given.
            callback: sync or async callable invoked as
                ``callback(url, headers=..., query=..., json=...)``. The
                returned CallbackResult replaces the static response fields.
            exception: cannot be raised on the client side. The URL is still
                intercepted, but no handler is registered for it, so (unless
                another handler matches) the connection is closed and the
                client sees a connection error.
            **kwargs: ignored.

        Raises:
            AssertionError: when called outside ``async with``, or when no
                host can be extracted from a concrete URL.
            ValueError: when queueing onto a key registered with ``repeat=True``.
        """
        method = method.upper()
        if isinstance(url, str):
            url = URL(url)

        if isinstance(url, Pattern):
            self._patterns_list.append(url)

        # Explicit raise (not assert) so the check survives ``python -O``.
        if self.server is None:
            raise AssertionError(
                "Server not started — use `async with aiointercept() as m:` first."
            )
        if isinstance(url, URL):
            host = url.host
            if not host:
                raise AssertionError(f"Cannot extract host from {url!r}")

            # Map this host → our test server
            self._host_list.append(host)

        if exception is not None:
            # The exception cannot be raised across a real connection. Keep
            # the URL intercepted (so it never reaches the network, even with
            # passthrough_unmatched=True) but register no handler for it.
            return

        if json is not None:
            body = json_module.dumps(json).encode()
        elif payload is not None:
            body = json_module.dumps(payload).encode()
        elif isinstance(body, str):
            body = body.encode()

        resp_headers = dict(headers or {})
        # Header names are case-insensitive: any spelling of Content-Type
        # must suppress the JSON default. Otherwise web.Response receives both
        # a Content-Type header and a content_type argument, which it refuses.
        has_content_type_header = any(
            str(name).lower() == "content-type" for name in resp_headers
        )
        if not content_type and body and not has_content_type_header:
            content_type = "application/json"

        async def handler(request: web.Request) -> web.Response:
            if callable(callback):
                if inspect.iscoroutinefunction(callback):
                    result = await callback(url, **request.kwargs)  # type: ignore[attr-defined]
                else:
                    result = callback(url, **request.kwargs)  # type: ignore[attr-defined]
                _status = result.status
                _body = result.body
                _headers = result.headers or {}
                if result.payload is not None:
                    _body = json_module.dumps(result.payload).encode()
                _content_type = result.content_type
                _reason = result.reason
            else:
                _status = status
                _body = body
                _headers = headers
                _content_type = content_type
                _reason = reason

            return web.Response(
                status=_status,
                body=_body,
                headers=_headers,
                reason=_reason,
                content_type=_content_type,
            )

        if repeat is True:
            if isinstance(url, Pattern):
                self.patterns_handler[url, method] = handler
                return
            path = url.path or "/"
            self.handlers[path, method] = handler
        else:
            if repeat is False:
                repeat = 1
            handlers: list[handler_type] = [handler] * repeat
            if isinstance(url, Pattern):
                if (url, method) in self.patterns_handler:
                    list_pattern_handler = self.patterns_handler[(url, method)]
                    if isinstance(list_pattern_handler, list):
                        list_pattern_handler = typing.cast(
                            list[handler_type], list_pattern_handler
                        )
                        list_pattern_handler += handlers
                    else:
                        raise ValueError(
                            f"Existing handler for pattern {url} {method} has repeat=True, cannot add more handlers to it."
                        )

                else:
                    self.patterns_handler[url, method] = handlers
                return
            path = url.path or "/"
            if (path, method) in self.handlers:
                handlers_list = self.handlers[(path, method)]
                if isinstance(handlers_list, list):
                    handlers_list = typing.cast(list[handler_type], handlers_list)
                    handlers_list += handlers
                else:
                    raise ValueError(
                        f"Existing handler for {path} {method} has repeat=True, cannot add more handlers to it."
                    )
            else:
                self.handlers[path, method] = handlers

    def get(self, url, **kwargs):
        """Shortcut for ``add(url, method="GET", ...)``."""
        self.add(url, method=hdrs.METH_GET, **kwargs)

    def post(self, url, **kwargs):
        """Shortcut for ``add(url, method="POST", ...)``."""
        self.add(url, method=hdrs.METH_POST, **kwargs)

    def put(self, url, **kwargs):
        """Shortcut for ``add(url, method="PUT", ...)``."""
        self.add(url, method=hdrs.METH_PUT, **kwargs)

    def patch(self, url, **kwargs):
        """Shortcut for ``add(url, method="PATCH", ...)``."""
        self.add(url, method=hdrs.METH_PATCH, **kwargs)

    def delete(self, url, **kwargs):
        """Shortcut for ``add(url, method="DELETE", ...)``."""
        self.add(url, method=hdrs.METH_DELETE, **kwargs)

    def head(self, url, **kwargs):
        """Shortcut for ``add(url, method="HEAD", ...)``."""
        self.add(url, method=hdrs.METH_HEAD, **kwargs)

    def options(self, url, **kwargs):
        """Shortcut for ``add(url, method="OPTIONS", ...)``."""
        self.add(url, method=hdrs.METH_OPTIONS, **kwargs)

    def clear(self):
        """Forget recorded requests and every registered mock.

        The intercepted hosts and patterns are cleared as well. Otherwise a
        leftover pattern would keep redirecting all DNS lookups to the test
        server after clearing.
        """
        self.requests.clear()
        self.handlers.clear()
        self.patterns_handler.clear()
        self._host_list.clear()
        self._patterns_list.clear()

    def assert_called(self):
        """Assert that at least one request was recorded."""
        if not self.requests:
            raise AssertionError("Expected at least one call, got none.")

    def assert_not_called(self):
        """Assert that no request was recorded."""
        if self.requests:
            raise AssertionError(
                f"Expected no calls, got {sum(len(v) for v in self.requests.values())}."
            )

    def assert_called_once(self):
        """Assert that exactly one request was recorded, across all URLs."""
        count = sum(len(v) for v in self.requests.values())
        if count != 1:
            raise AssertionError(f"Expected exactly 1 call, got {count}.")

    def assert_any_call(
        self,
        url: URL | str,
        method: str = hdrs.METH_GET,
        params: dict[str, str] | None = None,
    ):
        """Assert that at least one request was made to *url* (with *params*) using *method*."""
        url = normalize_url(merge_params(url, params))
        key = (method.upper(), url)
        if key not in self.requests:
            raise AssertionError(f"No calls to {method.upper()} {url}")

    def assert_called_with(
        self,
        url: URL | str,
        method: str = hdrs.METH_GET,
        params: dict[str, str] | None = None,
        data: str | bytes | dict[str, Any] | None = None,
        json: Any = None,
        headers: dict[str, str] | None = None,
    ):
        """Assert that the most recent request to *url*/*method* matches.

        As with ``unittest.mock``'s ``assert_called_with``, the LAST recorded
        call for the key is checked.

        - ``json`` is compared with the ``json.dumps`` encoding.
        - ``data`` is compared as raw bytes/str, or, for a dict, as
          form-encoded fields (order-insensitive).
        - ``headers`` must equal the request headers, after removing the
          automatically added ones that *headers* does not mention.

        Raises:
            AssertionError: on any mismatch (raised explicitly, so the checks
                also run under ``python -O``).
        """
        url = normalize_url(merge_params(url, params))
        key = (method.upper(), url)
        if key not in self.requests:
            raise AssertionError(f"No calls to {method.upper()} {url}")
        request = self.requests[key][-1]  # check the most recent call
        actual_body = getattr(request, "_captured_body", b"")
        if json is not None:
            # aiohttp sends json= as JSON-encoded bytes with application/json
            expected_body = json_module.dumps(json).encode()
            if actual_body != expected_body:
                raise AssertionError(
                    f"Expected JSON body {json!r}, got {actual_body!r}"
                )
        elif data is not None:
            if isinstance(data, dict):
                # aiohttp sends data=dict via FormData as application/x-www-form-urlencoded.
                # Use FormData to produce the exact same encoding aiohttp does.
                form_encoded = FormData(data)()._value  # type: ignore[attr-defined]
                # Accept order-insensitive comparison via parse_qs
                actual_qs = parse_qs(actual_body.decode(errors="replace"))
                expected_qs = parse_qs(form_encoded.decode())
                match = actual_body == form_encoded or actual_qs == expected_qs
                if not match:
                    raise AssertionError(
                        f"Expected body {data!r} (form encoded), got {actual_body!r}"
                    )
            else:
                if isinstance(data, str):
                    expected_body = data.encode()
                else:
                    expected_body = data
                if actual_body != expected_body:
                    raise AssertionError(
                        f"Expected body {expected_body!r}, got {actual_body!r}"
                    )
        actual_headers = dict(request.headers)
        # Strip headers that aiohttp adds automatically, unless the caller
        # explicitly wants to assert them.
        for header in (
            "Content-Length",
            "Content-Type",
            "Transfer-Encoding",
            "Host",
            "Accept",
            "Accept-Encoding",
            "User-Agent",
        ):
            if header not in (headers or {}):
                # this should be deprecated in the future, but for now we want to avoid breaking existing tests that don't specify these headers
                actual_headers.pop(header, None)
        expected_headers = headers or {}
        if expected_headers != actual_headers:
            raise AssertionError(
                f"Expected headers {expected_headers!r}, got {actual_headers!r}"
            )

    def assert_called_once_with(
        self,
        url: URL | str,
        method: str = hdrs.METH_GET,
        params: dict[str, str] | None = None,
        data: str | bytes | dict[str, Any] | None = None,
        json: Any = None,
        headers: dict[str, str] | None = None,
    ):
        """Assert exactly one request was recorded and that it matches (see assert_called_with)."""
        self.assert_called_once()
        self.assert_called_with(url, method, params, data, json, headers)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: aiointercept
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Author-email: Pablo Estevez <pablo22estevez@gmail.com>
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Requires-Dist: aiohttp<4.0.0,>=3.13.3
|
|
7
|
+
Requires-Dist: pytest-asyncio<2.0.0,>=1.3.0
|
|
8
|
+
Requires-Dist: yarl<2.0.0,>=1.23.0
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
|
|
11
|
+
Proof of concept for an aioresponses-like library that uses a real aiohttp server
|
|
12
|
+
related to this discussion: https://github.com/orgs/aio-libs/discussions/45
|
|
13
|
+
|
|
14
|
+
This is a proof of concept of how a library could expose an aioresponses-like API
|
|
15
|
+
while using a real aiohttp server under the hood. It works only with GET requests,
|
|
16
|
+
only compares the URL, params and method, and only implements assert_called_with.
|
|
17
|
+
|
|
18
|
+
The breaking changes will be:
|
|
19
|
+
- The context manager will be async, as it will need to start the server
|
|
20
|
+
- the assert_called_* methods will not work with arbitrary kwargs, instead they will only work with specific args that the request was made with (e.g. url, method, headers, params, etc.)
|
|
21
|
+
- URL fragments (after #) are not sent with the request, so they will not be compared
|
|
22
|
+
- on a connector exception, aiohttp may retry, so the number of requests will not be the same
|
|
23
|
+
- passing `timeout` does not work
|
|
24
|
+
- you can't raise arbitrary exceptions, except ClientError
|
|
25
|
+
- the decorator needs to decorate an async function
|
|
26
|
+
- as this mocks DNS, it will not work to mock requests to external IPs
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
Added a perf test to compare the performance of the POC with the original aioresponses implementation.
|
|
30
|
+
|
|
31
|
+
Results:
|
|
32
|
+
|
|
33
|
+
aioresponses (main) : 0.7223 seconds for 1000 iter (0.72 ms/iter)
|
|
34
|
+
aioresponses_2 (POC) : 0.8272 seconds for 1000 iter (0.83 ms/iter)
|
|
35
|
+
|
|
36
|
+
This overhead is probably tolerable, and it could be optimized.
|
|
37
|
+
However, once we start adding more features, the performance will degrade.
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
Changing:
|
|
41
|
+
|
|
42
|
+
@pytest.fixture
|
|
43
|
+
def mock_aioresponse():
|
|
44
|
+
with aioresponses() as m:
|
|
45
|
+
|
|
46
|
+
to:
|
|
47
|
+
|
|
48
|
+
@pytest_asyncio.fixture
|
|
49
|
+
async def mock_aioresponse():
|
|
50
|
+
async with aioresponses() as m:
|
|
51
|
+
yield m
|
|
52
|
+
|
|
53
|
+
is passing the tests on:
|
|
54
|
+
|
|
55
|
+
https://github.com/cortega26/PDF-Text-Analyzer
|
|
56
|
+
https://github.com/mxr/reconciler-for-ynab
|
|
57
|
+
https://github.com/pratik-choudhari/Financial-news-scraper (with some unrelated changes, because it is broken)
|
|
58
|
+
https://github.com/dorianrod/GitReviewLens
|
|
59
|
+
|
|
60
|
+
https://github.com/natekspencer/pylitterbot required some changes:
|
|
61
|
+
- they were raising a custom exception on request in test_litter_robot_5_dispatch_command_failure, which is not supported.
|
|
62
|
+
  - However, they were trying to mimic a 500 response with certain JSON, which is supported
|
|
63
|
+
- There were some calls to localhost that should have been to https://localhost. This was
|
|
64
|
+
also a fixable error in the tests
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
Broken on main
|
|
68
|
+
https://github.com/symphony-youri/symphony-api-client-python.git
|
|
69
|
+
|
|
70
|
+
too complex to work: https://github.com/mguidon/osparc-simcore
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
aiointercept/__init__.py,sha256=EmhlxjHUM7ewTEJQJn7sEt-GwsGzHpv_gk5s5mU2mqo,93
|
|
2
|
+
aiointercept/compat.py,sha256=dA9u1765GP8HtA-a8oapxRlNv3gcvw8cNp7Wnz9C1Ew,555
|
|
3
|
+
aiointercept/core.py,sha256=8gx0qAre3W8vOy58Af5RXrIMnLi0QCzFpPbEJd553Nk,21634
|
|
4
|
+
aiointercept-0.1.0.dist-info/METADATA,sha256=nCvhEt5jEvqNjOX9_d1i4U-pSk-ZRqUKa_oVTgQxd0g,2756
|
|
5
|
+
aiointercept-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
6
|
+
aiointercept-0.1.0.dist-info/RECORD,,
|